mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add speechio results
This commit is contained in:
parent
b422e7a97f
commit
a00c0c5279
@ -29,7 +29,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -42,7 +49,9 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
|
def compute_fbank_aishell2(
|
||||||
|
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(8, os.cpu_count())
|
num_jobs = min(8, os.cpu_count())
|
||||||
@ -69,7 +78,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
|
|||||||
dataset_parts,
|
dataset_parts,
|
||||||
)
|
)
|
||||||
if whisper_fbank:
|
if whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
@ -84,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
|
|||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition and perturb_speed:
|
if "train" in partition and perturb_speed:
|
||||||
logging.info(f"Doing speed perturb")
|
logging.info("Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -129,5 +140,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell2(
|
compute_fbank_aishell2(
|
||||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
perturb_speed=args.perturb_speed,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
nj=30
|
nj=30
|
||||||
stage=1
|
stage=0
|
||||||
stop_stage=1
|
stop_stage=7
|
||||||
perturb_speed=true
|
perturb_speed=true
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,7 +29,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import LilcomChunkyWriter, CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -42,7 +49,9 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
|
def compute_fbank_aishell4(
|
||||||
|
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests/aishell4")
|
src_dir = Path("data/manifests/aishell4")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(8, os.cpu_count())
|
num_jobs = min(8, os.cpu_count())
|
||||||
@ -71,7 +80,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
|
|||||||
)
|
)
|
||||||
|
|
||||||
if whisper_fbank:
|
if whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
@ -87,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
|
|||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition and perturb_speed:
|
if "train" in partition and perturb_speed:
|
||||||
logging.info(f"Doing speed perturb")
|
logging.info("Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -140,5 +151,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell4(
|
compute_fbank_aishell4(
|
||||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
perturb_speed=args.perturb_speed,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
|
|
||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
stage=20
|
stage=-1
|
||||||
stop_stage=20
|
stop_stage=7
|
||||||
perturb_speed=true
|
perturb_speed=true
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,7 +29,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -42,7 +49,9 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
|
def compute_fbank_alimeeting(
|
||||||
|
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests/alimeeting")
|
src_dir = Path("data/manifests/alimeeting")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(8, os.cpu_count())
|
num_jobs = min(8, os.cpu_count())
|
||||||
@ -71,7 +80,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
|||||||
)
|
)
|
||||||
|
|
||||||
if whisper_fbank:
|
if whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
@ -86,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
|||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition and perturb_speed:
|
if "train" in partition and perturb_speed:
|
||||||
logging.info(f"Doing speed perturb")
|
logging.info("Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -140,5 +151,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_alimeeting(
|
compute_fbank_alimeeting(
|
||||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
perturb_speed=args.perturb_speed,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
|
|
||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
stage=20
|
stage=-1
|
||||||
stop_stage=20
|
stop_stage=7
|
||||||
perturb_speed=true
|
perturb_speed=true
|
||||||
|
|
||||||
# We assume dl_dir (download dir) contains the following
|
# We assume dl_dir (download dir) contains the following
|
||||||
|
@ -17,14 +17,22 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import argparse
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
KaldifeatFbank,
|
||||||
|
KaldifeatFbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
# Do this outside of main() in case it needs to take effect
|
# Do this outside of main() in case it needs to take effect
|
||||||
@ -32,6 +40,7 @@ from icefall.utils import str2bool
|
|||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -52,6 +61,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_kespeech_dev_test(args):
|
def compute_fbank_kespeech_dev_test(args):
|
||||||
in_out_dir = Path("data/fbank/kespeech")
|
in_out_dir = Path("data/fbank/kespeech")
|
||||||
# number of workers in dataloader
|
# number of workers in dataloader
|
||||||
@ -70,7 +80,9 @@ def compute_fbank_kespeech_dev_test(args):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", 0)
|
||||||
if args.whisper_fbank:
|
if args.whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||||
|
|
||||||
|
@ -25,16 +25,17 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from lhotse import (
|
from lhotse import (
|
||||||
CutSet,
|
CutSet,
|
||||||
WhisperFbank,
|
|
||||||
WhisperFbankConfig,
|
|
||||||
KaldifeatFbank,
|
KaldifeatFbank,
|
||||||
KaldifeatFbankConfig,
|
KaldifeatFbankConfig,
|
||||||
LilcomChunkyWriter,
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
set_audio_duration_mismatch_tolerance,
|
set_audio_duration_mismatch_tolerance,
|
||||||
set_caching_enabled,
|
set_caching_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
# Do this outside of main() in case it needs to take effect
|
# Do this outside of main() in case it needs to take effect
|
||||||
@ -129,7 +130,9 @@ def compute_fbank_kespeech_splits(args):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", 0)
|
||||||
if args.whisper_fbank:
|
if args.whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
@ -30,7 +30,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -42,6 +49,7 @@ from icefall.utils import get_executor, str2bool
|
|||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -62,7 +70,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
|
||||||
|
def compute_fbank_magicdata(
|
||||||
|
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests/magicdata")
|
src_dir = Path("data/manifests/magicdata")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(8, os.cpu_count())
|
num_jobs = min(8, os.cpu_count())
|
||||||
@ -86,7 +97,9 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False,
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.whisper_fbank:
|
if args.whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda"))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
@ -145,5 +158,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_magicdata(
|
compute_fbank_magicdata(
|
||||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
speed_perturb=args.speed_perturb,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -30,7 +30,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -43,7 +50,9 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
def compute_fbank_primewords(
|
||||||
|
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests/primewords")
|
src_dir = Path("data/manifests/primewords")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -67,7 +76,9 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
|
|||||||
)
|
)
|
||||||
|
|
||||||
if whisper_fbank:
|
if whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
@ -128,5 +139,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_primewords(
|
compute_fbank_primewords(
|
||||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
speed_perturb=args.speed_perturb,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -30,7 +30,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -43,7 +50,9 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
def compute_fbank_stcmds(
|
||||||
|
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests/stcmds")
|
src_dir = Path("data/manifests/stcmds")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -67,7 +76,9 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, wh
|
|||||||
)
|
)
|
||||||
|
|
||||||
if whisper_fbank:
|
if whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
@ -126,5 +137,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_stcmds(
|
compute_fbank_stcmds(
|
||||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
speed_perturb=args.speed_perturb,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -30,7 +30,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -43,7 +50,9 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
def compute_fbank_thchs30(
|
||||||
|
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests/thchs30")
|
src_dir = Path("data/manifests/thchs30")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -71,7 +80,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, w
|
|||||||
)
|
)
|
||||||
|
|
||||||
if whisper_fbank:
|
if whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
@ -132,5 +143,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_thchs30(
|
compute_fbank_thchs30(
|
||||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
speed_perturb=args.speed_perturb,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
|
|
||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
stage=121
|
stage=-1
|
||||||
stop_stage=121
|
stop_stage=100
|
||||||
num_splits=100
|
num_splits=100
|
||||||
|
|
||||||
dl_dir=$PWD/download
|
dl_dir=$PWD/download
|
||||||
@ -482,5 +482,3 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
|
|||||||
python ./local/compile_lg.py --lang-dir $lang_dir
|
python ./local/compile_lg.py --lang-dir $lang_dir
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,14 +52,14 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
|
from lhotse.cut import Cut
|
||||||
|
from multi_dataset import MultiDataset
|
||||||
from tn.chinese.normalizer import Normalizer
|
from tn.chinese.normalizer import Normalizer
|
||||||
from whisper.normalizers import BasicTextNormalizer
|
from whisper.normalizers import BasicTextNormalizer
|
||||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||||
from zhconv import convert
|
from zhconv import convert
|
||||||
from lhotse.cut import Cut
|
|
||||||
from multi_dataset import MultiDataset
|
|
||||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
|
@ -34,10 +34,10 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
|
|||||||
--model-name medium
|
--model-name medium
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -52,13 +52,13 @@ import torch.multiprocessing as mp
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import whisper
|
import whisper
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from multi_dataset import MultiDataset
|
|
||||||
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
||||||
from label_smoothing import LabelSmoothingLoss
|
from label_smoothing import LabelSmoothingLoss
|
||||||
from lhotse import CutSet, load_manifest
|
from lhotse import CutSet, load_manifest
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
|
from multi_dataset import MultiDataset
|
||||||
from optim import Eden, ScaledAdam
|
from optim import Eden, ScaledAdam
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
@ -626,7 +626,9 @@ def train_one_epoch(
|
|||||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||||
)
|
)
|
||||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}")
|
os.system(
|
||||||
|
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||||
@ -761,9 +763,7 @@ def run(rank, world_size, args):
|
|||||||
del model.alignment_heads
|
del model.alignment_heads
|
||||||
|
|
||||||
if params.pretrained_model_path:
|
if params.pretrained_model_path:
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
||||||
params.pretrained_model_path, map_location="cpu"
|
|
||||||
)
|
|
||||||
if "model" not in checkpoint:
|
if "model" not in checkpoint:
|
||||||
model.load_state_dict(checkpoint, strict=True)
|
model.load_state_dict(checkpoint, strict=True)
|
||||||
else:
|
else:
|
||||||
|
15
egs/speechio/ASR/README.md
Normal file
15
egs/speechio/ASR/README.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
|
||||||
|
# Introduction
|
||||||
|
|
||||||
|
This recipe contains the decoding results of several pretrained ASR models on the [SpeechIO](https://github.com/SpeechColab/Leaderboard) test sets.
|
||||||
|
|
||||||
|
[./RESULTS.md](./RESULTS.md) contains the latest results.
|
||||||
|
|
||||||
|
# Pretrained Models
|
||||||
|
|
||||||
|
The following table lists the pretrained models.
|
||||||
|
|
||||||
|
| | Huggingface | Comment |
|
||||||
|
|---------------------------------------|--------------------|-----------------------------|
|
||||||
|
| `zipformer` | zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 | Using [multi_zh-hans recipe](../../multi_zh-hans/ASR/zipformer/) training |
|
||||||
|
| `whisper` | yuekai/icefall_asr_wenetspeech_whisper | Using [wenetspeech recipe](../../wenetspeech/ASR/whisper/) training |
|
92
egs/speechio/ASR/RESULTS.md
Normal file
92
egs/speechio/ASR/RESULTS.md
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
## Results
|
||||||
|
|
||||||
|
### SpeechIO Test Set Decoding Results
|
||||||
|
|
||||||
|
##### Decoding results using pretrained [multi-zh-hans zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whisper-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper).
|
||||||
|
|
||||||
|
| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
|
||||||
|
|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
|
||||||
|
| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
|
||||||
|
| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
|
||||||
|
| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
|
||||||
|
| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
|
||||||
|
| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
|
||||||
|
| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
|
||||||
|
| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
|
||||||
|
| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
|
||||||
|
| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
|
||||||
|
| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
|
||||||
|
| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
|
||||||
|
| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
|
||||||
|
| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
|
||||||
|
| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
|
||||||
|
| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
|
||||||
|
| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
|
||||||
|
| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
|
||||||
|
| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
|
||||||
|
| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
|
||||||
|
| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
|
||||||
|
| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
|
||||||
|
| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
|
||||||
|
| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
|
||||||
|
| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
|
||||||
|
| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
|
||||||
|
| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
|
||||||
|
| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
|
||||||
|
| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Command for decoding using fine-tuned whisper:
|
||||||
|
```bash
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper
|
||||||
|
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-4-avg3.pt whisper/exp_large_v2_wenetspeech/epoch-999.pt
|
||||||
|
|
||||||
|
python3 ./whisper/decode.py \
|
||||||
|
--exp-dir whisper/exp_large_v2_wenetspeech \
|
||||||
|
--model-name large-v2 \
|
||||||
|
--epoch 999 --avg 1 \
|
||||||
|
--start-index 0 --end-index 26 \
|
||||||
|
--remove-whisper-encoder-input-length-restriction True \
|
||||||
|
--manifest-dir data/fbank \
|
||||||
|
--beam-size 1 --max-duration 50
|
||||||
|
```
|
||||||
|
Command for decoding using pretrained zipformer:
|
||||||
|
```bash
|
||||||
|
git lfs install
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
|
||||||
|
cd icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
git lfs pull --include "data/lang_bpe_2000/*"
|
||||||
|
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/exp/pretrained.pt zipformer/exp_pretrain/epoch-999.pt
|
||||||
|
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/data/lang_bpe_2000/ ./data
|
||||||
|
wget https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615/resolve/main/data/lang_char/words.txt
|
||||||
|
mv words.txt ./data/lang_bpe_2000/
|
||||||
|
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 999 \
|
||||||
|
--avg 1 \
|
||||||
|
--blank-penalty 2.0 \
|
||||||
|
--use-averaged-model false \
|
||||||
|
--exp-dir ./zipformer/exp_pretrain \
|
||||||
|
--max-duration 600 \
|
||||||
|
--start-index 0 --end-index 26 \
|
||||||
|
--manifest-dir data/fbank_kaldi \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
```
|
||||||
|
Command for fusing the above decoding results from whisper and zipformer:
|
||||||
|
```bash
|
||||||
|
python local/whisper_zipformer_fusion.py \
|
||||||
|
--whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
|
||||||
|
--zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
|
||||||
|
--output-log-dir ./results_fusion
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
See why the fusion helps [here](./local/whisper_zipformer_fusion.py).
|
||||||
|
|
||||||
|
SpeechIO fbank features, decoding scripts, logs, and decoding results
|
||||||
|
are available at
|
||||||
|
<https://huggingface.co/yuekai/icefall_asr_speechio>
|
@ -30,7 +30,14 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
@ -44,7 +51,13 @@ torch.set_num_interop_threads(1)
|
|||||||
|
|
||||||
SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source.
|
SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source.
|
||||||
|
|
||||||
def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False, fbank_dir: str = "data/fbank", whisper_fbank: bool = False):
|
|
||||||
|
def compute_fbank_speechio(
|
||||||
|
num_mel_bins: int = 80,
|
||||||
|
speed_perturb: bool = False,
|
||||||
|
fbank_dir: str = "data/fbank",
|
||||||
|
whisper_fbank: bool = False,
|
||||||
|
):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path(fbank_dir)
|
output_dir = Path(fbank_dir)
|
||||||
num_jobs = min(8, os.cpu_count())
|
num_jobs = min(8, os.cpu_count())
|
||||||
@ -72,7 +85,9 @@ def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False,
|
|||||||
)
|
)
|
||||||
|
|
||||||
if whisper_fbank:
|
if whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||||
|
|
||||||
@ -127,5 +142,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_speechio(
|
compute_fbank_speechio(
|
||||||
num_mel_bins=args.num_mel_bins, fbank_dir=args.fbank_dir, whisper_fbank=args.whisper_fbank
|
num_mel_bins=args.num_mel_bins,
|
||||||
|
fbank_dir=args.fbank_dir,
|
||||||
|
whisper_fbank=args.whisper_fbank,
|
||||||
)
|
)
|
||||||
|
@ -35,15 +35,18 @@ def main():
|
|||||||
idx = f"{i}".zfill(2)
|
idx = f"{i}".zfill(2)
|
||||||
dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
|
dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
|
||||||
|
|
||||||
prefix="speechio"
|
prefix = "speechio"
|
||||||
suffix="jsonl.gz"
|
suffix = "jsonl.gz"
|
||||||
|
|
||||||
for partition in dataset_parts:
|
for partition in dataset_parts:
|
||||||
path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}"
|
path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}"
|
||||||
cuts = load_manifest_lazy(path)
|
cuts = load_manifest_lazy(path)
|
||||||
print(f"===================Duration statistics of {partition}===================")
|
print(
|
||||||
|
f"===================Duration statistics of {partition}==================="
|
||||||
|
)
|
||||||
cuts.describe()
|
cuts.describe()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
217
egs/speechio/ASR/local/whisper_zipformer_fusion.py
Normal file
217
egs/speechio/ASR/local/whisper_zipformer_fusion.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2024 Author: Yuekai Zhang
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This file uses whisper and zipformer decoding results to generate fusion decoding results.
|
||||||
|
Since whisper model is more likely to make deletion errors and zipformer model is more likely to make substitution and insertion errors,
|
||||||
|
we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import kaldialign
|
||||||
|
|
||||||
|
from icefall.utils import store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
    """Build the command-line parser of the fusion script.

    Three string options are registered: the directory holding the whisper
    recognition logs, the directory holding the zipformer recognition logs,
    and the directory that receives the fused results.

    Returns:
      Return the configured argparse.ArgumentParser.
    """
    # (flag, default value, help text), listed in registration order.
    options = (
        (
            "--whisper-log-dir",
            "./recogs_whisper",
            "The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
        ),
        (
            "--zipformer-log-dir",
            "./recogs_zipformer",
            "The directory to store the zipformer logs",
        ),
        (
            "--output-log-dir",
            "./results_fusion",
            "The directory to store the fusion logs",
        ),
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, default_dir, description in options:
        cli.add_argument(flag, type=str, default=default_dir, help=description)
    return cli
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
    res_dir: Path,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    For every entry of `results_dict` a `recogs-*.txt` file (via
    `store_transcripts`) and an `errs-*.txt` file (via `write_error_stats`)
    are written to `res_dir`. Afterwards a single `wer-summary-*.txt` file
    listing all settings, sorted by the value returned from
    `write_error_stats`, is written and the same table is printed.

    Args:
      res_dir:
        Directory that receives all output files.
      test_set_name:
        Name of the test set; it is embedded in every output filename.
      results_dict:
        Maps a setting name (e.g. "fusion") to the list of
        (cut_id, reference, hypothesis) tuples handed to
        `store_transcripts` and `write_error_stats`.
    """
    if not results_dict:
        # The summary filename below is derived from a setting name, so
        # there is nothing meaningful to write without at least one setting.
        print(f"No results to save for {test_set_name}")
        return

    suffix = "epoch-999-avg-1"
    wer_by_setting = dict()
    last_key = None

    for key, results in results_dict.items():
        last_key = key
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        print(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        # The statistics contain the (Chinese) transcripts, so do not rely on
        # the locale's default encoding.
        with open(errs_filename, "w", encoding="utf-8") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            wer_by_setting[key] = wer

        print("Wrote detailed error stats to {}".format(errs_filename))

    sorted_wers = sorted(wer_by_setting.items(), key=lambda x: x[1])

    # The summary file is named after the last setting that was processed.
    errs_info = res_dir / f"wer-summary-{test_set_name}-{last_key}-{suffix}.txt"
    with open(errs_info, "w", encoding="utf-8") as f:
        print("settings\tWER", file=f)
        for setting, val in sorted_wers:
            print("{}\t{}".format(setting, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (lowest-valued) row is tagged as the best one.
    note = "\tbest for {}".format(test_set_name)
    for setting, val in sorted_wers:
        s += "{}\t{}{}\n".format(setting, val, note)
        note = ""
    print(s)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_hyp_ref_wavname(filename):
    """Parse a recogs-*.txt file into hypotheses, references and utterance ids.

    Every utterance is expected to occupy two lines, for example:

      0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升']
      0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升

    Args:
      filename:
        Path of the recognition log to read. It is decoded as UTF-8.
    Returns:
      Return a tuple (hyps, refs, wav_name) of three lists:
        - hyps: the text following `hyp=`, stripped, one entry per hyp line.
        - refs: the tokens of the list following `ref=`, concatenated into
          one string, one entry per ref line.
        - wav_name: the text before the first ":" of every hyp line.
      Lines carrying neither marker are ignored.
    """
    import ast
    import re

    # Only accept "ref=" / "hyp=" directly after the "<id>:" prefix, so that
    # an utterance id or a transcript that merely contains the letters "ref"
    # or "hyp" cannot be mistaken for the marker.
    marker = re.compile(r":\s*(ref|hyp)=(.*)")

    hyps, refs, wav_name = [], [], []
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            found = marker.search(line)
            if found is None:
                continue
            kind, text = found.group(1), found.group(2).strip()
            if kind == "ref":
                # The reference is written as the repr of a list of strings;
                # evaluating the literal also copes with tokens that contain
                # quotes or backslashes.
                try:
                    tokens = ast.literal_eval(text)
                except (ValueError, SyntaxError):
                    tokens = None
                if isinstance(tokens, (list, tuple)) and all(
                    isinstance(tok, str) for tok in tokens
                ):
                    ref = "".join(tokens)
                else:
                    # Not a list literal: fall back to stripping the leading
                    # "['" and trailing "']" textually.
                    ref = "".join(text[2:-2].split("', '"))
                refs.append(ref)
            else:
                hyps.append(text)
                wav_name.append(line.split(":")[0])
    return hyps, refs, wav_name
|
||||||
|
|
||||||
|
|
||||||
|
def get_pair_filenames(
    whisper_log_dir,
    zipformer_log_dir,
    whisper_suffix="beam-search-epoch-999-avg-1",
    zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
):
    """List the (whisper, zipformer) recogs file pairs for all SpeechIO sets.

    Args:
      whisper_log_dir:
        Directory containing the whisper recogs-*.txt files.
      zipformer_log_dir:
        Directory containing the zipformer recogs-*.txt files.
      whisper_suffix:
        Filename part between the test set name and ".txt" for whisper.
      zipformer_suffix:
        Filename part between the test set name and ".txt" for zipformer.
    Returns:
      Return a list of (whisper_filename, zipformer_filename) string tuples,
      one per test set SPEECHIO_ASR_ZH00000 .. SPEECHIO_ASR_ZH00026, in order.
    """
    first_set, last_set = 0, 26
    partitions = [
        "SPEECHIO_ASR_ZH000" + str(set_id).zfill(2)
        for set_id in range(first_set, last_set + 1)
    ]
    return [
        (
            f"{whisper_log_dir}/recogs-{name}-{whisper_suffix}.txt",
            f"{zipformer_log_dir}/recogs-{name}-{zipformer_suffix}.txt",
        )
        for name in partitions
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def fusion_hyps_trust_substituion_insertion(
    hyps_whisper, hyps_zipformer, refs, ERR="*"
):
    """Fuse hypotheses, preferring whisper wherever it produced a symbol.

    Each whisper hypothesis is aligned against the matching zipformer
    hypothesis with kaldialign.align, e.g.

      [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]

    where the left element comes from whisper, the right one from zipformer
    and ERR marks a gap. The fused output keeps the left symbol unless it is
    ERR, in which case the right symbol is taken:
      - whisper substitution -> left
      - whisper insertion    -> left
      - whisper deletion     -> right

    Args:
      hyps_whisper:
        Whisper hypotheses, one per utterance.
      hyps_zipformer:
        Zipformer hypotheses, in the same order.
      refs:
        References; only used to bound the iteration (it is zipped with the
        two hypothesis lists), the texts themselves are not read.
      ERR:
        Gap symbol passed to kaldialign.align.
    Returns:
      Return the list of fused hypothesis strings.
    """
    fused = []
    for whisper_hyp, zipformer_hyp, _ in zip(hyps_whisper, hyps_zipformer, refs):
        pairs = kaldialign.align(whisper_hyp, zipformer_hyp, ERR)
        fused.append(
            "".join(right if left == ERR else left for left, right in pairs)
        )
    return fused
|
||||||
|
|
||||||
|
|
||||||
|
def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"):
    """Fuse hypotheses, trusting whisper only where both models emit a symbol.

    Each whisper hypothesis is aligned against the matching zipformer
    hypothesis with kaldialign.align, e.g.

      [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]

    where the left element comes from whisper, the right one from zipformer
    and ERR marks a gap:
      - whisper substitution -> left
      - whisper insertion (right is ERR) -> dropped, i.e. follow zipformer
      - whisper deletion (left is ERR)   -> right

    Args:
      hyps_whisper:
        Whisper hypotheses, one per utterance.
      hyps_zipformer:
        Zipformer hypotheses, in the same order.
      refs:
        References; only used to bound the iteration (it is zipped with the
        two hypothesis lists), the texts themselves are not read.
      ERR:
        Gap symbol passed to kaldialign.align.
    Returns:
      Return the list of fused hypothesis strings.
    """
    fused = []
    for whisper_hyp, zipformer_hyp, _ in zip(hyps_whisper, hyps_zipformer, refs):
        pairs = kaldialign.align(whisper_hyp, zipformer_hyp, ERR)
        kept = []
        for left, right in pairs:
            if left == ERR:
                kept.append(right)
            elif right != ERR:
                kept.append(left)
            # Otherwise only whisper emitted a symbol here: skip it.
        fused.append("".join(kept))
    return fused
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Fuse whisper and zipformer decoding results for every SpeechIO set.

    For each pair of recogs files returned by get_pair_filenames, the
    hypotheses of both models are fused with
    fusion_hyps_trust_substituion_insertion and the result is written to
    --output-log-dir through save_results under the setting name "fusion".
    """
    args = get_parser().parse_args()

    # Make sure the output directory exists before anything is written.
    Path(args.output_log_dir).mkdir(parents=True, exist_ok=True)

    log_pairs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
    for whisper_log, zipformer_log in log_pairs:
        # References and utterance ids are taken from the whisper log.
        hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(whisper_log)
        hyps_zipformer, _, _ = extract_hyp_ref_wavname(zipformer_log)

        hyps_fusion = fusion_hyps_trust_substituion_insertion(
            hyps_whisper, hyps_zipformer, refs
        )

        # "recogs-<partition>-<suffix>.txt" -> "<partition>"
        basename = whisper_log.split("/")[-1]
        partition_name = basename.split("-")[1]

        fused_results = list(zip(wav_name, refs, hyps_fusion))
        save_results(
            Path(args.output_log_dir),
            partition_name,
            {"fusion": fused_results},
        )

        print(f"Processed {partition_name}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -34,9 +34,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|||||||
SimpleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples
|
||||||
AudioSamples,
|
|
||||||
)
|
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
@ -53,14 +53,14 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
|
from lhotse.cut import Cut
|
||||||
|
from multi_dataset import MultiDataset
|
||||||
from tn.chinese.normalizer import Normalizer
|
from tn.chinese.normalizer import Normalizer
|
||||||
from whisper.normalizers import BasicTextNormalizer
|
from whisper.normalizers import BasicTextNormalizer
|
||||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||||
from zhconv import convert
|
from zhconv import convert
|
||||||
from lhotse.cut import Cut
|
|
||||||
from multi_dataset import MultiDataset
|
|
||||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
|
@ -45,17 +45,15 @@ class MultiDataset:
|
|||||||
idx = f"{i}".zfill(2)
|
idx = f"{i}".zfill(2)
|
||||||
dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
|
dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
|
||||||
|
|
||||||
prefix="speechio"
|
prefix = "speechio"
|
||||||
suffix="jsonl.gz"
|
suffix = "jsonl.gz"
|
||||||
|
|
||||||
results_dict = {}
|
results_dict = {}
|
||||||
for partition in dataset_parts:
|
for partition in dataset_parts:
|
||||||
path = f"{prefix}_cuts_{partition}.{suffix}"
|
path = f"{prefix}_cuts_{partition}.{suffix}"
|
||||||
|
|
||||||
logging.info(f"Loading {path} set in lazy mode")
|
logging.info(f"Loading {path} set in lazy mode")
|
||||||
test_cuts = load_manifest_lazy(
|
test_cuts = load_manifest_lazy(self.fbank_dir / path)
|
||||||
self.fbank_dir / path
|
|
||||||
)
|
|
||||||
results_dict[partition] = test_cuts
|
results_dict[partition] = test_cuts
|
||||||
|
|
||||||
return results_dict
|
return results_dict
|
@ -303,6 +303,17 @@ def get_parser():
|
|||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -431,6 +442,7 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -455,6 +467,7 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
elif params.decoding_method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
@ -468,8 +481,9 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
|
||||||
|
key = f"blank_penalty_{params.blank_penalty}"
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search_" + key: hyps}
|
||||||
elif "fast_beam_search" in params.decoding_method:
|
elif "fast_beam_search" in params.decoding_method:
|
||||||
key = f"beam_{params.beam}_"
|
key = f"beam_{params.beam}_"
|
||||||
key += f"max_contexts_{params.max_contexts}_"
|
key += f"max_contexts_{params.max_contexts}_"
|
||||||
@ -657,6 +671,7 @@ def main():
|
|||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
1
egs/speechio/ASR/zipformer/train.py
Symbolic link
1
egs/speechio/ASR/zipformer/train.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../multi_zh-hans/ASR/zipformer/train.py
|
@ -16,11 +16,19 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import argparse
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
KaldifeatFbank,
|
||||||
|
KaldifeatFbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -32,6 +40,7 @@ torch.multiprocessing.set_sharing_strategy("file_system")
|
|||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -52,6 +61,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_wenetspeech_dev_test(args):
|
def compute_fbank_wenetspeech_dev_test(args):
|
||||||
in_out_dir = Path("data/fbank")
|
in_out_dir = Path("data/fbank")
|
||||||
# number of workers in dataloader
|
# number of workers in dataloader
|
||||||
@ -66,7 +76,9 @@ def compute_fbank_wenetspeech_dev_test(args):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", 0)
|
||||||
if args.whisper_fbank:
|
if args.whisper_fbank:
|
||||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||||
|
|
||||||
|
@ -22,20 +22,19 @@ from datetime import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import (
|
from lhotse import ( # KaldifeatWhisperFbank,; KaldifeatWhisperFbankConfig,
|
||||||
CutSet,
|
CutSet,
|
||||||
WhisperFbank,
|
|
||||||
WhisperFbankConfig,
|
|
||||||
# KaldifeatWhisperFbank,
|
|
||||||
# KaldifeatWhisperFbankConfig,
|
|
||||||
KaldifeatFbank,
|
KaldifeatFbank,
|
||||||
KaldifeatFbankConfig,
|
KaldifeatFbankConfig,
|
||||||
LilcomChunkyWriter,
|
LilcomChunkyWriter,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
set_audio_duration_mismatch_tolerance,
|
set_audio_duration_mismatch_tolerance,
|
||||||
set_caching_enabled,
|
set_caching_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
from icefall.utils import str2bool, get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
# Do this outside of main() in case it needs to take effect
|
# Do this outside of main() in case it needs to take effect
|
||||||
@ -148,7 +147,7 @@ def compute_fbank_wenetspeech_splits(args):
|
|||||||
|
|
||||||
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
||||||
set_caching_enabled(False)
|
set_caching_enabled(False)
|
||||||
#with get_executor() as ex: # Initialize the executor only once.
|
# with get_executor() as ex: # Initialize the executor only once.
|
||||||
for i in range(start, stop):
|
for i in range(start, stop):
|
||||||
idx = f"{i}".zfill(num_digits)
|
idx = f"{i}".zfill(num_digits)
|
||||||
logging.info(f"Processing {i+1}/{num_splits}")
|
logging.info(f"Processing {i+1}/{num_splits}")
|
||||||
@ -177,13 +176,6 @@ def compute_fbank_wenetspeech_splits(args):
|
|||||||
storage_type=LilcomChunkyWriter,
|
storage_type=LilcomChunkyWriter,
|
||||||
overwrite=True,
|
overwrite=True,
|
||||||
)
|
)
|
||||||
# cut_set = cut_set.compute_and_store_features(
|
|
||||||
# extractor=extractor,
|
|
||||||
# storage_path=f"{output_dir}/feats_{subset}_{idx}",
|
|
||||||
# num_jobs=args.num_workers,
|
|
||||||
# executor=ex,
|
|
||||||
# storage_type=LilcomChunkyWriter,
|
|
||||||
# )
|
|
||||||
logging.info(f"Saving to {cuts_path}")
|
logging.info(f"Saving to {cuts_path}")
|
||||||
cut_set.to_file(cuts_path)
|
cut_set.to_file(cuts_path)
|
||||||
|
|
||||||
|
@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
nj=15
|
nj=15
|
||||||
stage=131
|
stage=0
|
||||||
stop_stage=131
|
stop_stage=100
|
||||||
|
|
||||||
# Split L subset to this number of pieces
|
# Split L subset to this number of pieces
|
||||||
# This is to avoid OOM during feature extraction.
|
# This is to avoid OOM during feature extraction.
|
||||||
|
@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from icefall import is_module_available
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
import torch
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -52,13 +52,13 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
from asr_datamodule import WenetSpeechAsrDataModule
|
from asr_datamodule import WenetSpeechAsrDataModule
|
||||||
|
from lhotse.cut import Cut
|
||||||
from tn.chinese.normalizer import Normalizer
|
from tn.chinese.normalizer import Normalizer
|
||||||
from whisper.normalizers import BasicTextNormalizer
|
from whisper.normalizers import BasicTextNormalizer
|
||||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||||
from zhconv import convert
|
from zhconv import convert
|
||||||
from lhotse.cut import Cut
|
|
||||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
|
@ -834,6 +834,7 @@ def run(rank, world_size, args):
|
|||||||
# )
|
# )
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
train_cuts = wenetspeech.train_cuts()
|
train_cuts = wenetspeech.train_cuts()
|
||||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
train_dl = wenetspeech.train_dataloaders(train_cuts)
|
train_dl = wenetspeech.train_dataloaders(train_cuts)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user