mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

add speechio results

parent b422e7a97f
commit a00c0c5279
@@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc
| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |

```bash
./prepare.sh
./prepare.sh

export CUDA_VISIBLE_DEVICES="0,1"
@@ -1,6 +1,6 @@
## Results

### Aishell2 char-based training results
### Aishell2 char-based training results

#### Pruned transducer stateless 5
@@ -29,7 +29,14 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -42,7 +49,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
def compute_fbank_aishell2(
    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = min(8, os.cpu_count())
@@ -69,7 +78,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
        dataset_parts,
    )
    if whisper_fbank:
        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -84,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
                supervisions=m["supervisions"],
            )
            if "train" in partition and perturb_speed:
                logging.info(f"Doing speed perturb")
                logging.info("Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
@@ -129,5 +140,7 @@ if __name__ == "__main__":

    args = get_args()
    compute_fbank_aishell2(
        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
        num_mel_bins=args.num_mel_bins,
        perturb_speed=args.perturb_speed,
        whisper_fbank=args.whisper_fbank,
    )
@@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail

nj=30
stage=1
stop_stage=1
stage=0
stop_stage=7
perturb_speed=true
@@ -3,7 +3,7 @@

This recipe contains various ASR models trained with Aishell4 (including the S, M and L subsets).

The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by an 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pauses, speech overlap, quick speaker turns, noise, etc. Meanwhile, accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows researchers to explore different aspects of meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.

(From [Open Speech and Language Resources](https://www.openslr.org/111/))
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import LilcomChunkyWriter, CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,7 +49,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_aishell4(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/aishell4")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
@ -71,7 +80,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -87,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -100,7 +111,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
|
||||
|
||||
logging.info("About splitting cuts into smaller chunks")
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False,
|
||||
@ -140,5 +151,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_aishell4(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=20
|
||||
stop_stage=20
|
||||
stage=-1
|
||||
stop_stage=7
|
||||
perturb_speed=true
|
||||
|
||||
|
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,7 +49,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_alimeeting(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/alimeeting")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
@ -71,7 +80,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -86,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -140,5 +151,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_alimeeting(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=20
|
||||
stop_stage=20
|
||||
stage=-1
|
||||
stop_stage=7
|
||||
perturb_speed=true
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
|
@@ -36,4 +36,4 @@ This recipe includes scripts for training Zipformer model using multiple Chinese
3. AliMeeting
4. MagicData
5. KeSpeech-ASR
6. WeNetSpeech
6. WeNetSpeech
@ -17,14 +17,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
@ -32,6 +40,7 @@ from icefall.utils import str2bool
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -52,6 +61,7 @@ def get_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank_kespeech_dev_test(args):
|
||||
in_out_dir = Path("data/fbank/kespeech")
|
||||
# number of workers in dataloader
|
||||
@ -70,7 +80,9 @@ def compute_fbank_kespeech_dev_test(args):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
|
||||
)
|
||||
else:
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
|
||||
|
@ -25,16 +25,17 @@ from pathlib import Path
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
set_audio_duration_mismatch_tolerance,
|
||||
set_caching_enabled,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
@ -129,7 +130,9 @@ def compute_fbank_kespeech_splits(args):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
logging.info(f"device: {device}")
|
||||
|
@ -30,7 +30,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,6 +49,7 @@ from icefall.utils import get_executor, str2bool
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -62,7 +70,10 @@ def get_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
|
||||
def compute_fbank_magicdata(
|
||||
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/magicdata")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
@ -84,9 +95,11 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False,
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda"))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -145,5 +158,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_magicdata(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
speed_perturb=args.speed_perturb,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -30,7 +30,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -43,7 +50,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_primewords(
|
||||
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/primewords")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -65,9 +74,11 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -128,5 +139,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_primewords(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
speed_perturb=args.speed_perturb,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -30,7 +30,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -43,7 +50,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_stcmds(
|
||||
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/stcmds")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -67,7 +76,9 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, wh
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -126,5 +137,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_stcmds(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
speed_perturb=args.speed_perturb,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -30,7 +30,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -43,7 +50,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_thchs30(
|
||||
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/thchs30")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -71,7 +80,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, w
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -132,5 +143,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_thchs30(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
speed_perturb=args.speed_perturb,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=121
|
||||
stop_stage=121
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
num_splits=100
|
||||
|
||||
dl_dir=$PWD/download
|
||||
@ -95,10 +95,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_dev.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_test.jsonl.gz) .
|
||||
cd ../..
|
||||
else
|
||||
else
|
||||
log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
log "Dataset: AISHELL-4"
|
||||
@ -115,10 +115,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
|
||||
cd ../..
|
||||
else
|
||||
else
|
||||
log "Abort! Please run ../../aishell4/ASR/prepare.sh --stage 3 --stop-stage 3"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
log "Dataset: ST-CMDS"
|
||||
@ -261,7 +261,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
|
||||
if [ ! -f data/manifests/.kespeech.done ]; then
|
||||
mkdir -p data/manifests
|
||||
lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech
|
||||
lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech
|
||||
touch data/manifests/.kespeech.done
|
||||
fi
|
||||
|
||||
@ -272,8 +272,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then
|
||||
python3 ./local/preprocess_kespeech.py
|
||||
touch data/fbank/.kespeech_preprocess_complete
|
||||
fi
|
||||
|
||||
fi
|
||||
|
||||
if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then
|
||||
log "Spliting KeSpeech train_phase1"
|
||||
lhotse split ${num_splits} \
|
||||
@ -281,7 +281,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
data/fbank/kespeech/train_phase1_split_${num_splits}
|
||||
touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then
|
||||
log "Spliting KeSpeech train_phase2"
|
||||
lhotse split ${num_splits} \
|
||||
@ -289,7 +289,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
data/fbank/kespeech/train_phase2_split_${num_splits}
|
||||
touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done
|
||||
fi
|
||||
|
||||
|
||||
log "Compute KeSpeech fbank for train_phase1"
|
||||
./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase1
|
||||
|
||||
@ -314,7 +314,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
|
||||
|
||||
if [ ! -f data/manifests/.kespeech.done ]; then
|
||||
mkdir -p data/manifests
|
||||
lhotse prepare kespeech -j 8 $dl_dir/KeSpeech data/manifests/kespeech
|
||||
lhotse prepare kespeech -j 8 $dl_dir/KeSpeech data/manifests/kespeech
|
||||
touch data/manifests/.kespeech.done
|
||||
fi
|
||||
|
||||
@ -325,8 +325,8 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
|
||||
if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then
|
||||
python3 ./local/preprocess_kespeech.py --speed-perturb true
|
||||
touch data/fbank/.kespeech_preprocess_complete
|
||||
fi
|
||||
|
||||
fi
|
||||
|
||||
if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then
|
||||
log "Spliting KeSpeech train_phase1"
|
||||
lhotse split ${num_splits} \
|
||||
@ -334,7 +334,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
|
||||
data/fbank/kespeech/train_phase1_split_${num_splits}
|
||||
touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then
|
||||
log "Spliting KeSpeech train_phase2"
|
||||
lhotse split ${num_splits} \
|
||||
@ -342,7 +342,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
|
||||
data/fbank/kespeech/train_phase2_split_${num_splits}
|
||||
touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done
|
||||
fi
|
||||
|
||||
|
||||
log "Compute KeSpeech fbank for train_phase1"
|
||||
./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
|
||||
@ -351,7 +351,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
|
||||
|
||||
log "Compute KeSpeech fbank for test/dev"
|
||||
# ./local/compute_fbank_kespeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
|
||||
|
||||
if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then
|
||||
pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz")
|
||||
lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
|
||||
@ -422,7 +422,7 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
|
||||
|
||||
for vocab_size in ${vocab_sizes[@]}; do
|
||||
lang_dir=data/lang_bpe_${vocab_size}
|
||||
|
||||
|
||||
mkdir -p $lang_dir
|
||||
if [ ! -f $lang_dir/bpe.model ]; then
|
||||
./local/train_bpe_model.py \
|
||||
@ -442,7 +442,7 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
|
||||
--lexicon $lang_dir/lexicon.txt \
|
||||
--bpe-model $lang_dir/bpe.model
|
||||
fi
|
||||
|
||||
|
||||
if [ ! -f $lang_dir/L.fst ]; then
|
||||
log "Converting L.pt to L.fst"
|
||||
./shared/convert-k2-to-openfst.py \
|
||||
@ -463,7 +463,7 @@ fi
|
||||
|
||||
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
|
||||
log "Stage 14: Prepare G (note that we use ngram lm of wenetspeech only for G preparation)"
|
||||
|
||||
|
||||
if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then
|
||||
cd data
|
||||
ln -s ../../../../wenetspeech/ASR/data/lm .
|
||||
@ -482,5 +482,3 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
|
||||
python ./local/compile_lg.py --lang-dir $lang_dir
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
|
@@ -52,14 +52,14 @@ import k2
import torch
import torch.nn as nn
import whisper

from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert
from lhotse.cut import Cut
from multi_dataset import MultiDataset

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
@ -34,10 +34,10 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--model-name medium
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@ -52,13 +52,13 @@ import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import whisper
|
||||
from asr_datamodule import AsrDataModule
|
||||
from multi_dataset import MultiDataset
|
||||
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from lhotse import CutSet, load_manifest
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from multi_dataset import MultiDataset
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
@ -626,7 +626,9 @@ def train_one_epoch(
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
)
|
||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}")
|
||||
os.system(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
)
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
@ -761,9 +763,7 @@ def run(rank, world_size, args):
|
||||
del model.alignment_heads
|
||||
|
||||
if params.pretrained_model_path:
|
||||
checkpoint = torch.load(
|
||||
params.pretrained_model_path, map_location="cpu"
|
||||
)
|
||||
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
||||
if "model" not in checkpoint:
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
else:
|
||||
@ -866,7 +866,7 @@ def run(rank, world_size, args):
|
||||
|
||||
valid_cuts = multi_dataset.dev_cuts()
|
||||
valid_dl = data_module.valid_dataloaders(valid_cuts)
|
||||
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
|
egs/speechio/ASR/README.md (new file, 15 lines)
@@ -0,0 +1,15 @@

# Introduction

This recipe includes decoding results of several pretrained ASR models on the [SpeechIO](https://github.com/SpeechColab/Leaderboard) test sets.

[./RESULTS.md](./RESULTS.md) contains the latest results.

# Pretrained Models

The following table lists the pretrained models.

|             | Huggingface                                              | Comment                                                                      |
|-------------|----------------------------------------------------------|------------------------------------------------------------------------------|
| `zipformer` | zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 | Trained using the [multi-hans-zh recipe](../../multi_zh-hans/ASR/zipformer/) |
| `whisper`   | yuekai/icefall_asr_wenetspeech_whisper                   | Trained using the [wenetspeech recipe](../../wenetspeech/ASR/whisper/)       |
egs/speechio/ASR/RESULTS.md (new file, 92 lines)
@@ -0,0 +1,92 @@

## Results

### SpeechIO Test Set Decoding Results

##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whisper-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), and [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper)

| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |

Command for decoding using fine-tuned whisper:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-4-avg3.pt whisper/exp_large_v2_wenetspeech/epoch-999.pt

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2_wenetspeech \
  --model-name large-v2 \
  --epoch 999 --avg 1 \
  --start-index 0 --end-index 26 \
  --remove-whisper-encoder-input-length-restriction True \
  --manifest-dir data/fbank \
  --beam-size 1 --max-duration 50
```

Command for decoding using pretrained zipformer:
```bash
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
cd icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
git lfs pull --include "exp/pretrained.pt"
git lfs pull --include "data/lang_bpe_2000/*"
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/exp/pretrained.pt zipformer/exp_pretrain/epoch-999.pt
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/data/lang_bpe_2000/ ./data
wget https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615/resolve/main/data/lang_char/words.txt
mv words.txt ./data/lang_bpe_2000/

./zipformer/decode.py \
  --epoch 999 \
  --avg 1 \
  --blank-penalty 2.0 \
  --use-averaged-model false \
  --exp-dir ./zipformer/exp_pretrain \
  --max-duration 600 \
  --start-index 0 --end-index 26 \
  --manifest-dir data/fbank_kaldi \
  --decoding-method greedy_search
```

Command for fusing the above whisper and zipformer decoding results:
```bash
python local/whisper_zipformer_fusion.py \
  --whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
  --zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
  --output-log-dir ./results_fusion
```

See why the fusion helps [here](./local/whisper_zipformer_fusion.py).

SpeechIO fbank features, decoding scripts, logs, and decoding results are available at
<https://huggingface.co/yuekai/icefall_asr_speechio>.
@@ -30,7 +30,14 @@ import os
from pathlib import Path

import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -44,7 +51,13 @@ torch.set_num_interop_threads(1)

SPEECHIO_TESTSET_INDEX = 26  # Currently, from 0 - 26 test sets are open source.

def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False, fbank_dir: str = "data/fbank", whisper_fbank: bool = False):

def compute_fbank_speechio(
    num_mel_bins: int = 80,
    speed_perturb: bool = False,
    fbank_dir: str = "data/fbank",
    whisper_fbank: bool = False,
):
    src_dir = Path("data/manifests")
    output_dir = Path(fbank_dir)
    num_jobs = min(8, os.cpu_count())
@@ -72,7 +85,9 @@ def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False,
    )

    if whisper_fbank:
        extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -127,5 +142,7 @@ if __name__ == "__main__":

    args = get_args()
    compute_fbank_speechio(
        num_mel_bins=args.num_mel_bins, fbank_dir=args.fbank_dir, whisper_fbank=args.whisper_fbank
        num_mel_bins=args.num_mel_bins,
        fbank_dir=args.fbank_dir,
        whisper_fbank=args.whisper_fbank,
    )
@@ -35,15 +35,18 @@ def main():
        idx = f"{i}".zfill(2)
        dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")

    prefix="speechio"
    suffix="jsonl.gz"
    prefix = "speechio"
    suffix = "jsonl.gz"

    for partition in dataset_parts:
        path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}"
        cuts = load_manifest_lazy(path)
        print(f"===================Duration statistics of {partition}===================")
        print(
            f"===================Duration statistics of {partition}==================="
        )
        cuts.describe()


if __name__ == "__main__":
    main()
egs/speechio/ASR/local/whisper_zipformer_fusion.py (new file, 217 lines)
@@ -0,0 +1,217 @@
#!/usr/bin/env python3
#
# Copyright 2024 Author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file uses whisper and zipformer decoding results to generate fusion decoding results.
Since whisper model is more likely to make deletion errors and zipformer model is more likely to make substitution and insertion errors,
we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors.

Usage:
    python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
"""

import argparse
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import kaldialign

from icefall.utils import store_transcripts, write_error_stats


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--whisper-log-dir",
        type=str,
        default="./recogs_whisper",
        help="The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
    )
    parser.add_argument(
        "--zipformer-log-dir",
        type=str,
        default="./recogs_zipformer",
        help="The directory to store the zipformer logs",
    )
    parser.add_argument(
        "--output-log-dir",
        type=str,
        default="./results_fusion",
        help="The directory to store the fusion logs",
    )
    return parser


def save_results(
    res_dir: Path,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    test_set_wers = dict()

    suffix = "epoch-999-avg-1"

    for key, results in results_dict.items():
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        print(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        print("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = res_dir / f"wer-summary-{test_set_name}-{key}-{suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    print(s)


def extract_hyp_ref_wavname(filename):
    """
    0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升']
    0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升
    """
    hyps, refs, wav_name = [], [], []
    with open(filename, "r") as f:
        for line in f:
            if "ref" in line:
                ref = line.split("ref=")[1].strip()
                ref = ref[2:-2]
                list_elements = ref.split("', '")
                ref = "".join(list_elements)
                refs.append(ref)
            elif "hyp" in line:
                hyp = line.split("hyp=")[1].strip()
                hyps.append(hyp)
                wav_name.append(line.split(":")[0])
    return hyps, refs, wav_name


def get_pair_filenames(
    whisper_log_dir,
    zipformer_log_dir,
    whisper_suffix="beam-search-epoch-999-avg-1",
    zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
):
    results = []
    start_index, end_index = 0, 26
    dataset_parts = []
    for i in range(start_index, end_index + 1):
        idx = f"{i}".zfill(2)
        dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
    for partition in dataset_parts:
        whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
        zipformer_filename = (
            f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt"
        )
        results.append((whisper_filename, zipformer_filename))
    return results


def fusion_hyps_trust_substituion_insertion(
    hyps_whisper, hyps_zipformer, refs, ERR="*"
):
    """
    alignment example:
    [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
    left is whisper, right is zipformer
    for whisper substitution, use left
    for whisper insertion, use left
    for whisper deletion, use right
    """
    hyps_fusion = []
    for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
        ali = kaldialign.align(hyp_w, hyp_z, ERR)
        hyp_f = ""
        for a in ali:
            if a[0] == ERR:
                hyp_f += a[1]
            else:
                hyp_f += a[0]
        hyps_fusion.append(hyp_f)
    return hyps_fusion


def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"):
    """
    alignment example:
    [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
    left is whisper, right is zipformer
    for whisper substitution, use left
    for whisper insertion, use right
    for whisper deletion, use right
    """
    hyps_fusion = []
    for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
        ali = kaldialign.align(hyp_w, hyp_z, ERR)
        hyp_f = ""
        for a in ali:
            if a[0] == ERR:
                hyp_f += a[1]
            elif a[1] == ERR:
                pass
            else:
                hyp_f += a[0]
        hyps_fusion.append(hyp_f)
    return hyps_fusion


def main():
    parser = get_parser()
    args = parser.parse_args()
    # mkdir output_log_dir
    Path(args.output_log_dir).mkdir(parents=True, exist_ok=True)
    pair_logs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
    for pair in pair_logs:
        hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(pair[0])
        hyps_zipformer, _, _ = extract_hyp_ref_wavname(pair[1])

        hyps_fusion = fusion_hyps_trust_substituion_insertion(
            hyps_whisper, hyps_zipformer, refs
        )

        partition_name = pair[0].split("/")[-1].split("-")[1]
        save_results(
            Path(args.output_log_dir),
            partition_name,
            {"fusion": list(zip(wav_name, refs, hyps_fusion))},
        )

        print(f"Processed {partition_name}")


if __name__ == "__main__":
    main()
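As a quick usage illustration of the fusion rule implemented above in `fusion_hyps_trust_substituion_insertion`, here is a minimal, hedged sketch; the two hypothesis strings are invented for illustration and only `kaldialign` is required:

```python
import kaldialign

ERR = "*"
# Toy hypotheses (made up): whisper tends to delete symbols,
# zipformer tends to substitute/insert them.
hyp_whisper = list("我在任的时候")
hyp_zipformer = list("你任的时候呢")

ali = kaldialign.align(hyp_whisper, hyp_zipformer, ERR)
# Trust whisper on substitutions/insertions; fall back to zipformer
# only where whisper deleted a symbol (left side of the pair is ERR).
fused = "".join(z if w == ERR else w for w, z in ali)
print(ali)    # e.g. [('我', '你'), ('在', '*'), ('任', '任'), ..., ('*', '呢')]
print(fused)  # 我在任的时候呢
```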
@@ -12,7 +12,7 @@ stop_stage=3
# - $dl_dir/SPEECHIO_ASR_ZH00000
# This directory contains the following files downloaded from
# https://github.com/SpeechColab/Leaderboard
#
#
# - metadata.tsv
# - wav
# - wav.scp
@@ -34,9 +34,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    SimpleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
)
from lhotse.dataset.input_strategies import AudioSamples  # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
@@ -53,14 +53,14 @@ import k2
import torch
import torch.nn as nn
import whisper

from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert
from lhotse.cut import Cut
from multi_dataset import MultiDataset

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
@@ -45,17 +45,15 @@ class MultiDataset:
            idx = f"{i}".zfill(2)
            dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")

        prefix="speechio"
        suffix="jsonl.gz"
        prefix = "speechio"
        suffix = "jsonl.gz"

        results_dict = {}
        for partition in dataset_parts:
            path = f"{prefix}_cuts_{partition}.{suffix}"

            logging.info(f"Loading {path} set in lazy mode")
            test_cuts = load_manifest_lazy(
                self.fbank_dir / path
            )
            test_cuts = load_manifest_lazy(self.fbank_dir / path)
            results_dict[partition] = test_cuts

        return results_dict
        return results_dict
@@ -303,6 +303,17 @@ def get_parser():
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied on blank symbol during decoding.
        Note: It is a positive value that would be applied to logits like
        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
        [batch_size, vocab] and blank id is 0).
        """,
    )
    add_model_arguments(parser)

    return parser
@@ -431,6 +442,7 @@ def decode_one_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            blank_penalty=params.blank_penalty,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@@ -455,6 +467,7 @@ def decode_one_batch(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                    blank_penalty=params.blank_penalty,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
@@ -468,8 +481,9 @@ def decode_one_batch(
                )
            hyps.append(sp.decode(hyp).split())

    key = f"blank_penalty_{params.blank_penalty}"
    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
        return {"greedy_search_" + key: hyps}
    elif "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
@@ -657,6 +671,7 @@ def main():
    params.suffix += f"-context-{params.context_size}"
    params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    params.suffix += f"-blank-penalty-{params.blank_penalty}"
    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"
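For reference, a minimal sketch (toy numbers, not from the recipe) of what the `--blank-penalty` option added above does to the logits before the search selects a symbol, exactly as described in its help text:

```python
import torch

logits = torch.tensor([[1.2, 0.9, 0.3]])  # [batch_size, vocab]; blank id is 0
blank_penalty = 2.0
logits[:, 0] -= blank_penalty  # the operation quoted in the help string
print(logits.argmax(dim=-1))   # tensor([1]): blank is no longer the top choice
```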
(File diff suppressed because it is too large.)

egs/speechio/ASR/zipformer/train.py (new symbolic link, 1 line)
@@ -0,0 +1 @@
../../../multi_zh-hans/ASR/zipformer/train.py
@ -16,11 +16,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
@ -32,6 +40,7 @@ torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -52,6 +61,7 @@ def get_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank_wenetspeech_dev_test(args):
|
||||
in_out_dir = Path("data/fbank")
|
||||
# number of workers in dataloader
|
||||
@ -66,7 +76,9 @@ def compute_fbank_wenetspeech_dev_test(args):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
|
||||
|
@ -22,20 +22,19 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
from lhotse import ( # KaldifeatWhisperFbank,; KaldifeatWhisperFbankConfig,
|
||||
CutSet,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
# KaldifeatWhisperFbank,
|
||||
# KaldifeatWhisperFbankConfig,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
set_audio_duration_mismatch_tolerance,
|
||||
set_caching_enabled,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool, get_executor
|
||||
from icefall.utils import get_executor, str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
@ -148,11 +147,11 @@ def compute_fbank_wenetspeech_splits(args):
|
||||
|
||||
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
||||
set_caching_enabled(False)
|
||||
#with get_executor() as ex: # Initialize the executor only once.
|
||||
# with get_executor() as ex: # Initialize the executor only once.
|
||||
for i in range(start, stop):
|
||||
idx = f"{i}".zfill(num_digits)
|
||||
logging.info(f"Processing {i+1}/{num_splits}")
|
||||
|
||||
|
||||
cuts_path = output_dir / f"cuts_{subset}.{idx}.jsonl.gz"
|
||||
if cuts_path.is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
@ -177,13 +176,6 @@ def compute_fbank_wenetspeech_splits(args):
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
# cut_set = cut_set.compute_and_store_features(
|
||||
# extractor=extractor,
|
||||
# storage_path=f"{output_dir}/feats_{subset}_{idx}",
|
||||
# num_jobs=args.num_workers,
|
||||
# executor=ex,
|
||||
# storage_type=LilcomChunkyWriter,
|
||||
# )
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
|
||||
|
@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
set -eou pipefail
|
||||
|
||||
nj=15
|
||||
stage=131
|
||||
stop_stage=131
|
||||
stage=0
|
||||
stop_stage=100
|
||||
|
||||
# Split L subset to this number of pieces
|
||||
# This is to avoid OOM during feature extraction.
|
||||
@ -309,7 +309,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
mkdir -p $text_out_dir
|
||||
|
||||
log "Genearating training text data"
|
||||
|
||||
|
||||
if [ ! -f $text_out_dir/lm_data.pt ]; then
|
||||
./local/prepare_char_lm_training_data.py \
|
||||
--lang-char data/lang_char \
|
||||
@ -318,14 +318,14 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
fi
|
||||
|
||||
log "Generating DEV text data"
|
||||
# prepare validation text data
|
||||
# prepare validation text data
|
||||
if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then
|
||||
valid_text=${text_out_dir}/
|
||||
|
||||
gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \
|
||||
| jq '.text' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "char" > $text_out_dir/valid_text
|
||||
|
||||
|
||||
python3 ./local/text2segments.py \
|
||||
--num-process $nj \
|
||||
--input-file $text_out_dir/valid_text \
|
||||
@ -337,7 +337,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
--lm-data $text_out_dir/valid_text_words_segmentation \
|
||||
--lm-archive $text_out_dir/lm_data_valid.pt
|
||||
|
||||
# prepare TEST text data
|
||||
# prepare TEST text data
|
||||
if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then
|
||||
log "Prepare text for test set."
|
||||
for test_set in TEST_MEETING TEST_NET; do
|
||||
@ -350,7 +350,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
--input-file $text_out_dir/${test_set}_text \
|
||||
--output-file $text_out_dir/${test_set}_text_words_segmentation
|
||||
done
|
||||
|
||||
|
||||
cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation
|
||||
fi
|
||||
|
||||
@ -401,4 +401,4 @@ if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
|
||||
--lm-data-valid data/lm_char/sorted_lm_data-valid.pt \
|
||||
--vocab-size 5537 \
|
||||
--master-port 12340
|
||||
fi
|
||||
fi
|
||||
|
@@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp:
import argparse
import logging

from icefall import is_module_available
import torch
from onnx_pretrained import OnnxModel

import torch
from icefall import is_module_available


def get_parser():
@@ -52,13 +52,13 @@ import k2
import torch
import torch.nn as nn
import whisper

from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert
from lhotse.cut import Cut

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
@@ -834,6 +834,7 @@ def run(rank, world_size, args):
            # )
            return False
        return True

    train_cuts = wenetspeech.train_cuts()
    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_dl = wenetspeech.train_dataloaders(train_cuts)