add speechio results

This commit is contained in:
Yuekai Zhang 2024-03-07 14:44:38 +08:00
parent b422e7a97f
commit a00c0c5279
36 changed files with 600 additions and 1518 deletions

View File

@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc
| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 | | fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |
```bash ```bash
./prepare.sh ./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1" export CUDA_VISIBLE_DEVICES="0,1"

View File

@ -1,6 +1,6 @@
## Results ## Results
### Aishell2 char-based training results ### Aishell2 char-based training results
#### Pruned transducer stateless 5 #### Pruned transducer stateless 5

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,7 +49,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False): def compute_fbank_aishell2(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests") src_dir = Path("data/manifests")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(8, os.cpu_count()) num_jobs = min(8, os.cpu_count())
@ -69,7 +78,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
dataset_parts, dataset_parts,
) )
if whisper_fbank: if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda')) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else: else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -84,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
supervisions=m["supervisions"], supervisions=m["supervisions"],
) )
if "train" in partition and perturb_speed: if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb") logging.info("Doing speed perturb")
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
) )
@ -129,5 +140,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_aishell2( compute_fbank_aishell2(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail set -eou pipefail
nj=30 nj=30
stage=1 stage=0
stop_stage=1 stop_stage=7
perturb_speed=true perturb_speed=true

View File

@ -3,7 +3,7 @@
This recipe contains various ASR models trained with Aishell4 (including its three subsets: S, M and L). This recipe contains various ASR models trained with Aishell4 (including its three subsets: S, M and L).
The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks. The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.
(From [Open Speech and Language Resources](https://www.openslr.org/111/)) (From [Open Speech and Language Resources](https://www.openslr.org/111/))

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import LilcomChunkyWriter, CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,7 +49,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False): def compute_fbank_aishell4(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/aishell4") src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(8, os.cpu_count()) num_jobs = min(8, os.cpu_count())
@ -71,7 +80,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
) )
if whisper_fbank: if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda')) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else: else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -87,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
supervisions=m["supervisions"], supervisions=m["supervisions"],
) )
if "train" in partition and perturb_speed: if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb") logging.info("Doing speed perturb")
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
) )
@ -100,7 +111,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
executor=ex, executor=ex,
storage_type=LilcomChunkyWriter, storage_type=LilcomChunkyWriter,
) )
logging.info("About splitting cuts into smaller chunks") logging.info("About splitting cuts into smaller chunks")
cut_set = cut_set.trim_to_supervisions( cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False, keep_overlapping=False,
@ -140,5 +151,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_aishell4( compute_fbank_aishell4(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail set -eou pipefail
stage=20 stage=-1
stop_stage=20 stop_stage=7
perturb_speed=true perturb_speed=true

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,7 +49,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False): def compute_fbank_alimeeting(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/alimeeting") src_dir = Path("data/manifests/alimeeting")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(8, os.cpu_count()) num_jobs = min(8, os.cpu_count())
@ -71,7 +80,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
) )
if whisper_fbank: if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda')) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else: else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -86,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
supervisions=m["supervisions"], supervisions=m["supervisions"],
) )
if "train" in partition and perturb_speed: if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb") logging.info("Doing speed perturb")
cut_set = ( cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
) )
@ -140,5 +151,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_alimeeting( compute_fbank_alimeeting(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail set -eou pipefail
stage=20 stage=-1
stop_stage=20 stop_stage=7
perturb_speed=true perturb_speed=true
# We assume dl_dir (download dir) contains the following # We assume dl_dir (download dir) contains the following

View File

@ -36,4 +36,4 @@ This recipe includes scripts for training Zipformer model using multiple Chinese
3. AliMeeting 3. AliMeeting
4. MagicData 4. MagicData
5. KeSpeech-ASR 5. KeSpeech-ASR
6. WeNetSpeech 6. WeNetSpeech

View File

@ -17,14 +17,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import argparse
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from icefall.utils import str2bool from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or # Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down. # it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect # Do this outside of main() in case it needs to take effect
@ -32,6 +40,7 @@ from icefall.utils import str2bool
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -52,6 +61,7 @@ def get_parser():
) )
return parser return parser
def compute_fbank_kespeech_dev_test(args): def compute_fbank_kespeech_dev_test(args):
in_out_dir = Path("data/fbank/kespeech") in_out_dir = Path("data/fbank/kespeech")
# number of workers in dataloader # number of workers in dataloader
@ -70,7 +80,9 @@ def compute_fbank_kespeech_dev_test(args):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
if args.whisper_fbank: if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
)
else: else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

View File

@ -25,16 +25,17 @@ from pathlib import Path
import torch import torch
from lhotse import ( from lhotse import (
CutSet, CutSet,
WhisperFbank,
WhisperFbankConfig,
KaldifeatFbank, KaldifeatFbank,
KaldifeatFbankConfig, KaldifeatFbankConfig,
LilcomChunkyWriter, LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
set_audio_duration_mismatch_tolerance, set_audio_duration_mismatch_tolerance,
set_caching_enabled, set_caching_enabled,
) )
from icefall.utils import str2bool from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or # Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down. # it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect # Do this outside of main() in case it needs to take effect
@ -129,7 +130,9 @@ def compute_fbank_kespeech_splits(args):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
if args.whisper_fbank: if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else: else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}") logging.info(f"device: {device}")

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -42,6 +49,7 @@ from icefall.utils import get_executor, str2bool
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -62,7 +70,10 @@ def get_parser():
) )
return parser return parser
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
def compute_fbank_magicdata(
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/magicdata") src_dir = Path("data/manifests/magicdata")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(8, os.cpu_count()) num_jobs = min(8, os.cpu_count())
@ -84,9 +95,11 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False,
list(manifests.keys()), list(manifests.keys()),
dataset_parts, dataset_parts,
) )
if args.whisper_fbank: if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else: else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -145,5 +158,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_magicdata( compute_fbank_magicdata(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank num_mel_bins=args.num_mel_bins,
speed_perturb=args.speed_perturb,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -43,7 +50,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): def compute_fbank_primewords(
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/primewords") src_dir = Path("data/manifests/primewords")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(15, os.cpu_count())
@ -65,9 +74,11 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
list(manifests.keys()), list(manifests.keys()),
dataset_parts, dataset_parts,
) )
if whisper_fbank: if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else: else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -128,5 +139,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_primewords( compute_fbank_primewords(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank num_mel_bins=args.num_mel_bins,
speed_perturb=args.speed_perturb,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -43,7 +50,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): def compute_fbank_stcmds(
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/stcmds") src_dir = Path("data/manifests/stcmds")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(15, os.cpu_count())
@ -67,7 +76,9 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, wh
) )
if whisper_fbank: if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else: else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -126,5 +137,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_stcmds( compute_fbank_stcmds(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank num_mel_bins=args.num_mel_bins,
speed_perturb=args.speed_perturb,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -43,7 +50,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False): def compute_fbank_thchs30(
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/thchs30") src_dir = Path("data/manifests/thchs30")
output_dir = Path("data/fbank") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(15, os.cpu_count())
@ -71,7 +80,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, w
) )
if whisper_fbank: if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else: else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -132,5 +143,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_thchs30( compute_fbank_thchs30(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank num_mel_bins=args.num_mel_bins,
speed_perturb=args.speed_perturb,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail set -eou pipefail
stage=121 stage=-1
stop_stage=121 stop_stage=100
num_splits=100 num_splits=100
dl_dir=$PWD/download dl_dir=$PWD/download
@ -95,10 +95,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_dev.jsonl.gz) . ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_dev.jsonl.gz) .
ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_test.jsonl.gz) . ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts_test.jsonl.gz) .
cd ../.. cd ../..
else else
log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
exit 1 exit 1
fi fi
fi fi
log "Dataset: AISHELL-4" log "Dataset: AISHELL-4"
@ -115,10 +115,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
cd ../.. cd ../..
else else
log "Abort! Please run ../../aishell4/ASR/prepare.sh --stage 3 --stop-stage 3" log "Abort! Please run ../../aishell4/ASR/prepare.sh --stage 3 --stop-stage 3"
exit 1 exit 1
fi fi
fi fi
log "Dataset: ST-CMDS" log "Dataset: ST-CMDS"
@ -261,7 +261,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
if [ ! -f data/manifests/.kespeech.done ]; then if [ ! -f data/manifests/.kespeech.done ]; then
mkdir -p data/manifests mkdir -p data/manifests
lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech lhotse prepare kespeech -j 16 $dl_dir/KeSpeech data/manifests/kespeech
touch data/manifests/.kespeech.done touch data/manifests/.kespeech.done
fi fi
@ -272,8 +272,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then
python3 ./local/preprocess_kespeech.py python3 ./local/preprocess_kespeech.py
touch data/fbank/.kespeech_preprocess_complete touch data/fbank/.kespeech_preprocess_complete
fi fi
if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then
log "Spliting KeSpeech train_phase1" log "Spliting KeSpeech train_phase1"
lhotse split ${num_splits} \ lhotse split ${num_splits} \
@ -281,7 +281,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
data/fbank/kespeech/train_phase1_split_${num_splits} data/fbank/kespeech/train_phase1_split_${num_splits}
touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done
fi fi
if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then
log "Spliting KeSpeech train_phase2" log "Spliting KeSpeech train_phase2"
lhotse split ${num_splits} \ lhotse split ${num_splits} \
@ -289,7 +289,7 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
data/fbank/kespeech/train_phase2_split_${num_splits} data/fbank/kespeech/train_phase2_split_${num_splits}
touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done
fi fi
log "Compute KeSpeech fbank for train_phase1" log "Compute KeSpeech fbank for train_phase1"
./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase1 ./local/compute_fbank_kespeech_splits.py --speed-perturb true --num-splits ${num_splits} --training-subset train_phase1
@ -314,7 +314,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
if [ ! -f data/manifests/.kespeech.done ]; then if [ ! -f data/manifests/.kespeech.done ]; then
mkdir -p data/manifests mkdir -p data/manifests
lhotse prepare kespeech -j 8 $dl_dir/KeSpeech data/manifests/kespeech lhotse prepare kespeech -j 8 $dl_dir/KeSpeech data/manifests/kespeech
touch data/manifests/.kespeech.done touch data/manifests/.kespeech.done
fi fi
@ -325,8 +325,8 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then if [ ! -f data/fbank/.kespeech_preprocess_complete ]; then
python3 ./local/preprocess_kespeech.py --speed-perturb true python3 ./local/preprocess_kespeech.py --speed-perturb true
touch data/fbank/.kespeech_preprocess_complete touch data/fbank/.kespeech_preprocess_complete
fi fi
if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then if [ ! -f data/fbank/.kespeech.train_phase1.split.${num_splits}.done ]; then
log "Spliting KeSpeech train_phase1" log "Spliting KeSpeech train_phase1"
lhotse split ${num_splits} \ lhotse split ${num_splits} \
@ -334,7 +334,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
data/fbank/kespeech/train_phase1_split_${num_splits} data/fbank/kespeech/train_phase1_split_${num_splits}
touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done touch data/fbank/.kespeech.train_phase1.split.${num_splits}.done
fi fi
if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then if [ ! -f data/fbank/.kespeech.train_phase2.split.${num_splits}.done ]; then
log "Spliting KeSpeech train_phase2" log "Spliting KeSpeech train_phase2"
lhotse split ${num_splits} \ lhotse split ${num_splits} \
@ -342,7 +342,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
data/fbank/kespeech/train_phase2_split_${num_splits} data/fbank/kespeech/train_phase2_split_${num_splits}
touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done touch data/fbank/.kespeech.train_phase2.split.${num_splits}.done
fi fi
log "Compute KeSpeech fbank for train_phase1" log "Compute KeSpeech fbank for train_phase1"
./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true ./local/compute_fbank_kespeech_splits.py --num-splits ${num_splits} --training-subset train_phase1 --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
@ -351,7 +351,7 @@ if [ $stage -le 120 ] && [ $stop_stage -ge 120 ]; then
log "Compute KeSpeech fbank for test/dev" log "Compute KeSpeech fbank for test/dev"
# ./local/compute_fbank_kespeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true # ./local/compute_fbank_kespeech_dev_test.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then
pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz") pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz")
lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
@ -422,7 +422,7 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
for vocab_size in ${vocab_sizes[@]}; do for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size} lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir mkdir -p $lang_dir
if [ ! -f $lang_dir/bpe.model ]; then if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \ ./local/train_bpe_model.py \
@ -442,7 +442,7 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
--lexicon $lang_dir/lexicon.txt \ --lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bpe.model --bpe-model $lang_dir/bpe.model
fi fi
if [ ! -f $lang_dir/L.fst ]; then if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst" log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \ ./shared/convert-k2-to-openfst.py \
@ -463,7 +463,7 @@ fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "Stage 14: Prepare G (note that we use ngram lm of wenetspeech only for G preparation)" log "Stage 14: Prepare G (note that we use ngram lm of wenetspeech only for G preparation)"
if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then if [ -d ../../wenetspeech/ASR/data/lang_char/ ]; then
cd data cd data
ln -s ../../../../wenetspeech/ASR/data/lm . ln -s ../../../../wenetspeech/ASR/data/lm .
@ -482,5 +482,3 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
python ./local/compile_lg.py --lang-dir $lang_dir python ./local/compile_lg.py --lang-dir $lang_dir
done done
fi fi

View File

@ -52,14 +52,14 @@ import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
import whisper import whisper
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert from zhconv import convert
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (

View File

@ -34,10 +34,10 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
--model-name medium --model-name medium
""" """
import os
import argparse import argparse
import copy import copy
import logging import logging
import os
import random import random
import warnings import warnings
from pathlib import Path from pathlib import Path
@ -52,13 +52,13 @@ import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import whisper import whisper
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from multi_dataset import MultiDataset
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest from lhotse import CutSet, load_manifest
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from multi_dataset import MultiDataset
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
@ -626,7 +626,9 @@ def train_one_epoch(
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
) )
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}") os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.cuda.amp.autocast(enabled=params.use_fp16):
@ -761,9 +763,7 @@ def run(rank, world_size, args):
del model.alignment_heads del model.alignment_heads
if params.pretrained_model_path: if params.pretrained_model_path:
checkpoint = torch.load( checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
params.pretrained_model_path, map_location="cpu"
)
if "model" not in checkpoint: if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True) model.load_state_dict(checkpoint, strict=True)
else: else:
@ -866,7 +866,7 @@ def run(rank, world_size, args):
valid_cuts = multi_dataset.dev_cuts() valid_cuts = multi_dataset.dev_cuts()
valid_dl = data_module.valid_dataloaders(valid_cuts) valid_dl = data_module.valid_dataloaders(valid_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict") logging.info("Loading grad scaler state dict")

View File

@ -0,0 +1,15 @@
# Introduction
This recipe provides decoding results of several pretrained ASR models on the [SpeechIO](https://github.com/SpeechColab/Leaderboard) test sets.
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Pretrained Models
The following table lists the pretrained models.
| | Huggingface | Comment |
|---------------------------------------|--------------------|-----------------------------|
| `zipformer` | zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 | Using [multi_zh-hans recipe](../../multi_zh-hans/ASR/zipformer/) training |
| `whisper` | yuekai/icefall_asr_wenetspeech_whisper | Using [wenetspeech recipe](../../wenetspeech/ASR/whisper/) training |

View File

@ -0,0 +1,92 @@
## Results
### SpeechIO Test Set Decoding Results
##### Decoding results using pretrained [multi-zh-hans zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whisper-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper).
| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |
Command for decoding using fine-tuned whisper:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper
ln -s icefall_asr_wenetspeech_whisper/exp_large_v2/epoch-4-avg3.pt whisper/exp_large_v2_wenetspeech/epoch-999.pt
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2_wenetspeech \
--model-name large-v2 \
--epoch 999 --avg 1 \
--start-index 0 --end-index 26 \
--remove-whisper-encoder-input-length-restriction True \
--manifest-dir data/fbank \
--beam-size 1 --max-duration 50
```
Command for decoding using pretrained zipformer:
```bash
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
cd icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
git lfs pull --include "exp/pretrained.pt"
git lfs pull --include "data/lang_bpe_2000/*"
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/exp/pretrained.pt zipformer/exp_pretrain/epoch-999.pt
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/data/lang_bpe_2000/ ./data
wget https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615/resolve/main/data/lang_char/words.txt
mv words.txt ./data/lang_bpe_2000/
./zipformer/decode.py \
--epoch 999 \
--avg 1 \
--blank-penalty 2.0 \
--use-averaged-model false \
--exp-dir ./zipformer/exp_pretrain \
--max-duration 600 \
--start-index 0 --end-index 26 \
--manifest-dir data/fbank_kaldi \
--decoding-method greedy_search
```
Command for fusing the above decoding results from whisper and zipformer:
```bash
python local/whisper_zipformer_fusion.py \
--whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
--zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
--output-log-dir ./results_fusion
```
See why the fusion helps [here](./local/whisper_zipformer_fusion.py).
SpeechIO fbank features, decoding scripts, logs, and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_speechio>

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool from icefall.utils import get_executor, str2bool
@ -44,7 +51,13 @@ torch.set_num_interop_threads(1)
SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source. SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source.
def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False, fbank_dir: str = "data/fbank", whisper_fbank: bool = False):
def compute_fbank_speechio(
num_mel_bins: int = 80,
speed_perturb: bool = False,
fbank_dir: str = "data/fbank",
whisper_fbank: bool = False,
):
src_dir = Path("data/manifests") src_dir = Path("data/manifests")
output_dir = Path(fbank_dir) output_dir = Path(fbank_dir)
num_jobs = min(8, os.cpu_count()) num_jobs = min(8, os.cpu_count())
@ -72,7 +85,9 @@ def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False,
) )
if whisper_fbank: if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else: else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -127,5 +142,7 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_speechio( compute_fbank_speechio(
num_mel_bins=args.num_mel_bins, fbank_dir=args.fbank_dir, whisper_fbank=args.whisper_fbank num_mel_bins=args.num_mel_bins,
fbank_dir=args.fbank_dir,
whisper_fbank=args.whisper_fbank,
) )

View File

@ -35,15 +35,18 @@ def main():
idx = f"{i}".zfill(2) idx = f"{i}".zfill(2)
dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
prefix="speechio" prefix = "speechio"
suffix="jsonl.gz" suffix = "jsonl.gz"
for partition in dataset_parts: for partition in dataset_parts:
path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}" path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}"
cuts = load_manifest_lazy(path) cuts = load_manifest_lazy(path)
print(f"===================Duration statistics of {partition}===================") print(
f"===================Duration statistics of {partition}==================="
)
cuts.describe() cuts.describe()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -0,0 +1,217 @@
#!/usr/bin/env python3
#
# Copyright 2024 Author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file uses whisper and zipformer decoding results to generate fusion decoding results.
Since whisper model is more likely to make deletion errors and zipformer model is more likely to make substitution and insertion errors,
we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors.
Usage:
python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
"""
import argparse
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import kaldialign
from icefall.utils import store_transcripts, write_error_stats
def get_parser():
    """Build the command-line parser of the fusion script.

    Returns:
      An ``argparse.ArgumentParser`` with three string options:
      ``--whisper-log-dir``, ``--zipformer-log-dir`` and ``--output-log-dir``.
      Defaults are shown in ``--help`` via ``ArgumentDefaultsHelpFormatter``.
    """
    # One (flag, default, help) entry per directory option.
    directory_options = (
        (
            "--whisper-log-dir",
            "./recogs_whisper",
            "The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
        ),
        (
            "--zipformer-log-dir",
            "./recogs_zipformer",
            "The directory to store the zipformer logs",
        ),
        (
            "--output-log-dir",
            "./results_fusion",
            "The directory to store the fusion logs",
        ),
    )

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, default_dir, help_text in directory_options:
        parser.add_argument(flag, type=str, default=default_dir, help=help_text)
    return parser
def save_results(
    res_dir: Path,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Store transcripts, error statistics and a WER summary for one test set.

    For every ``key`` in ``results_dict`` this writes, inside ``res_dir``:
      - ``recogs-{test_set_name}-{key}-epoch-999-avg-1.txt`` via
        ``store_transcripts``;
      - ``errs-{test_set_name}-{key}-epoch-999-avg-1.txt`` via
        ``write_error_stats``.
    Afterwards a ``wer-summary-...`` file listing all settings sorted by WER
    is written and the same summary is printed.

    Args:
      res_dir:
        Directory that receives the output files (must already exist).
      test_set_name:
        Name of the test set; used in file names and log messages.
      results_dict:
        Maps a setting name to its list of (utterance id, ref, hyp) entries.
    """
    if not results_dict:
        # Nothing to store or summarise.
        return

    suffix = "epoch-999-avg-1"
    test_set_wers = dict()
    # The summary file name embeds the last processed setting name; track it
    # explicitly instead of relying on the loop variable outliving the loop.
    last_key = None
    for key, results in results_dict.items():
        last_key = key
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        print(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        # Explicit UTF-8: the transcripts are Chinese text and must not depend
        # on the platform's default encoding.
        with open(errs_filename, "w", encoding="utf-8") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
        test_set_wers[key] = wer

        print("Wrote detailed error stats to {}".format(errs_filename))

    ranked_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = res_dir / f"wer-summary-{test_set_name}-{last_key}-{suffix}.txt"
    with open(errs_info, "w", encoding="utf-8") as f:
        print("settings\tWER", file=f)
        for setting, wer in ranked_wers:
            print("{}\t{}".format(setting, wer), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (lowest-WER) line carries the "best" note.
    note = "\tbest for {}".format(test_set_name)
    for setting, wer in ranked_wers:
        s += "{}\t{}{}\n".format(setting, wer, note)
        note = ""
    print(s)
def extract_hyp_ref_wavname(filename):
    """Parse a ``recogs-*.txt`` file into hypotheses, references and ids.

    Each utterance is expected to contribute one ``ref`` and one ``hyp`` line:

      0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升']
      0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升

    The reference is a Python list literal whose tokens are concatenated into
    one string; the hypothesis is kept as the raw text after ``hyp=``.

    Args:
      filename:
        Path of the recogs file to read (decoded as UTF-8).

    Returns:
      A tuple ``(hyps, refs, wav_name)`` of three lists. ``wav_name`` holds
      the text before the first ``:`` of every ``hyp`` line. Lines that are
      neither ``ref`` nor ``hyp`` lines are ignored.
    """
    import ast

    hyps, refs, wav_name = [], [], []
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            # Classify the line by the marker that follows "<id>:" instead of
            # searching for "ref"/"hyp" anywhere, so ids or transcripts that
            # happen to contain those letters are not misclassified.
            utt_id, _, payload = line.partition(":")
            payload = payload.strip()
            if payload.startswith("ref="):
                ref = payload[len("ref=") :].strip()
                try:
                    tokens = ast.literal_eval(ref)
                except (ValueError, TypeError, SyntaxError):
                    tokens = None
                if isinstance(tokens, (list, tuple)) and all(
                    isinstance(t, str) for t in tokens
                ):
                    # Proper parse: also handles tokens that repr() rendered
                    # with double quotes or escapes (e.g. "it's").
                    ref = "".join(tokens)
                else:
                    # Fallback: strip the leading "['" and trailing "']" and
                    # split on the "', '" separator.
                    ref = "".join(ref[2:-2].split("', '"))
                refs.append(ref)
            elif payload.startswith("hyp="):
                hyps.append(payload[len("hyp=") :].strip())
                wav_name.append(utt_id)
    return hyps, refs, wav_name
def get_pair_filenames(
    whisper_log_dir,
    zipformer_log_dir,
    whisper_suffix="beam-search-epoch-999-avg-1",
    zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
):
    """List the (whisper, zipformer) recogs file pairs for SPEECHIO 0..26.

    Args:
      whisper_log_dir:
        Directory holding the whisper ``recogs-*.txt`` files.
      zipformer_log_dir:
        Directory holding the zipformer ``recogs-*.txt`` files.
      whisper_suffix:
        Part of the whisper file name after the partition name.
      zipformer_suffix:
        Part of the zipformer file name after the partition name.

    Returns:
      A list of ``(whisper_filename, zipformer_filename)`` string tuples, one
      per partition ``SPEECHIO_ASR_ZH00000`` .. ``SPEECHIO_ASR_ZH00026``.
    """
    first_index, last_index = 0, 26
    partitions = []
    for number in range(first_index, last_index + 1):
        idx = str(number).zfill(2)
        partitions.append(f"SPEECHIO_ASR_ZH000{idx}")

    return [
        (
            f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt",
            f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt",
        )
        for partition in partitions
    ]
def fusion_hyps_trust_substituion_insertion(
    hyps_whisper, hyps_zipformer, refs, ERR="*"
):
    """Fuse hypotheses, keeping whisper's symbol wherever whisper has one.

    Each whisper/zipformer hypothesis pair is aligned with
    ``kaldialign.align``; an aligned pair looks like ``(whisper, zipformer)``
    where ``ERR`` marks a gap on that side.

      - whisper side is ``ERR`` (whisper deletion): take the zipformer symbol;
      - otherwise (match, whisper substitution or whisper insertion): take
        the whisper symbol.

    Args:
      hyps_whisper:
        Whisper hypotheses, one per utterance.
      hyps_zipformer:
        Zipformer hypotheses, in the same utterance order.
      refs:
        References; only their count is used (the three lists are zipped).
      ERR:
        Gap symbol passed to ``kaldialign.align``.

    Returns:
      A list with one fused hypothesis string per utterance.
    """
    fused_hyps = []
    for whisper_hyp, zipformer_hyp, _ in zip(hyps_whisper, hyps_zipformer, refs):
        aligned = kaldialign.align(whisper_hyp, zipformer_hyp, ERR)
        fused_hyps.append(
            "".join(right if left == ERR else left for left, right in aligned)
        )
    return fused_hyps
def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"):
    """Fuse hypotheses, trusting whisper only on matches and substitutions.

    Each whisper/zipformer hypothesis pair is aligned with
    ``kaldialign.align``; an aligned pair looks like ``(whisper, zipformer)``
    where ``ERR`` marks a gap on that side.

      - whisper side is ``ERR`` (whisper deletion): take the zipformer symbol;
      - zipformer side is ``ERR`` (whisper insertion): drop the symbol;
      - otherwise (match or substitution): take the whisper symbol.

    Args:
      hyps_whisper:
        Whisper hypotheses, one per utterance.
      hyps_zipformer:
        Zipformer hypotheses, in the same utterance order.
      refs:
        References; only their count is used (the three lists are zipped).
      ERR:
        Gap symbol passed to ``kaldialign.align``.

    Returns:
      A list with one fused hypothesis string per utterance.
    """
    fused_hyps = []
    for whisper_hyp, zipformer_hyp, _ in zip(hyps_whisper, hyps_zipformer, refs):
        kept = []
        for left, right in kaldialign.align(whisper_hyp, zipformer_hyp, ERR):
            if left == ERR:
                kept.append(right)
            elif right != ERR:
                kept.append(left)
        fused_hyps.append("".join(kept))
    return fused_hyps
def main():
    """Fuse whisper and zipformer decoding results for every SpeechIO set.

    Reads the paired recogs files from the two log directories given on the
    command line, fuses the hypotheses with
    ``fusion_hyps_trust_substituion_insertion`` and stores the results under
    ``--output-log-dir`` with ``save_results``.
    """
    args = get_parser().parse_args()

    # Make sure the output directory exists before anything is written.
    output_dir = Path(args.output_log_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    file_pairs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
    for whisper_file, zipformer_file in file_pairs:
        hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(whisper_file)
        hyps_zipformer, _, _ = extract_hyp_ref_wavname(zipformer_file)

        hyps_fusion = fusion_hyps_trust_substituion_insertion(
            hyps_whisper, hyps_zipformer, refs
        )

        # File names look like "recogs-<partition>-<suffix>.txt".
        partition_name = whisper_file.split("/")[-1].split("-")[1]
        fused_results = {"fusion": list(zip(wav_name, refs, hyps_fusion))}
        save_results(output_dir, partition_name, fused_results)
        print(f"Processed {partition_name}")
if __name__ == "__main__":
main()

View File

@ -12,7 +12,7 @@ stop_stage=3
# - $dl_dir/SPEECHIO_ASR_ZH00000 # - $dl_dir/SPEECHIO_ASR_ZH00000
# This directory contains the following files downloaded from # This directory contains the following files downloaded from
# https://github.com/SpeechColab/Leaderboard # https://github.com/SpeechColab/Leaderboard
# #
# - metadata.tsv # - metadata.tsv
# - wav # - wav
# - wav.scp # - wav.scp

View File

@ -34,9 +34,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
SimpleCutSampler, SimpleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples
AudioSamples,
)
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader from torch.utils.data import DataLoader

View File

@ -53,14 +53,14 @@ import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
import whisper import whisper
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert from zhconv import convert
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (

View File

@ -45,17 +45,15 @@ class MultiDataset:
idx = f"{i}".zfill(2) idx = f"{i}".zfill(2)
dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}") dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
prefix="speechio" prefix = "speechio"
suffix="jsonl.gz" suffix = "jsonl.gz"
results_dict = {} results_dict = {}
for partition in dataset_parts: for partition in dataset_parts:
path = f"{prefix}_cuts_{partition}.{suffix}" path = f"{prefix}_cuts_{partition}.{suffix}"
logging.info(f"Loading {path} set in lazy mode") logging.info(f"Loading {path} set in lazy mode")
test_cuts = load_manifest_lazy( test_cuts = load_manifest_lazy(self.fbank_dir / path)
self.fbank_dir / path
)
results_dict[partition] = test_cuts results_dict[partition] = test_cuts
return results_dict return results_dict

View File

@ -303,6 +303,17 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
) )
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -431,6 +442,7 @@ def decode_one_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -455,6 +467,7 @@ def decode_one_batch(
model=model, model=model,
encoder_out=encoder_out_i, encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame, max_sym_per_frame=params.max_sym_per_frame,
blank_penalty=params.blank_penalty,
) )
elif params.decoding_method == "beam_search": elif params.decoding_method == "beam_search":
hyp = beam_search( hyp = beam_search(
@ -468,8 +481,9 @@ def decode_one_batch(
) )
hyps.append(sp.decode(hyp).split()) hyps.append(sp.decode(hyp).split())
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search_" + key: hyps}
elif "fast_beam_search" in params.decoding_method: elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_" key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_" key += f"max_contexts_{params.max_contexts}_"
@ -657,6 +671,7 @@ def main():
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model: if params.use_averaged_model:
params.suffix += "-use-averaged-model" params.suffix += "-use-averaged-model"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../multi_zh-hans/ASR/zipformer/train.py

View File

@ -16,11 +16,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import argparse
import torch import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter from lhotse import (
CutSet,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
# Torch's multithreaded behavior needs to be disabled or # Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down. # it wastes a lot of CPU and slow things down.
@ -32,6 +40,7 @@ torch.multiprocessing.set_sharing_strategy("file_system")
from icefall.utils import str2bool from icefall.utils import str2bool
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -52,6 +61,7 @@ def get_parser():
) )
return parser return parser
def compute_fbank_wenetspeech_dev_test(args): def compute_fbank_wenetspeech_dev_test(args):
in_out_dir = Path("data/fbank") in_out_dir = Path("data/fbank")
# number of workers in dataloader # number of workers in dataloader
@ -66,7 +76,9 @@ def compute_fbank_wenetspeech_dev_test(args):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
if args.whisper_fbank: if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda')) extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else: else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

View File

@ -22,20 +22,19 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import ( from lhotse import ( # KaldifeatWhisperFbank,; KaldifeatWhisperFbankConfig,
CutSet, CutSet,
WhisperFbank,
WhisperFbankConfig,
# KaldifeatWhisperFbank,
# KaldifeatWhisperFbankConfig,
KaldifeatFbank, KaldifeatFbank,
KaldifeatFbankConfig, KaldifeatFbankConfig,
LilcomChunkyWriter, LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
set_audio_duration_mismatch_tolerance, set_audio_duration_mismatch_tolerance,
set_caching_enabled, set_caching_enabled,
) )
from icefall.utils import str2bool, get_executor from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or # Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down. # it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect # Do this outside of main() in case it needs to take effect
@ -148,11 +147,11 @@ def compute_fbank_wenetspeech_splits(args):
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
set_caching_enabled(False) set_caching_enabled(False)
#with get_executor() as ex: # Initialize the executor only once. # with get_executor() as ex: # Initialize the executor only once.
for i in range(start, stop): for i in range(start, stop):
idx = f"{i}".zfill(num_digits) idx = f"{i}".zfill(num_digits)
logging.info(f"Processing {i+1}/{num_splits}") logging.info(f"Processing {i+1}/{num_splits}")
cuts_path = output_dir / f"cuts_{subset}.{idx}.jsonl.gz" cuts_path = output_dir / f"cuts_{subset}.{idx}.jsonl.gz"
if cuts_path.is_file(): if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping") logging.info(f"{cuts_path} exists - skipping")
@ -177,13 +176,6 @@ def compute_fbank_wenetspeech_splits(args):
storage_type=LilcomChunkyWriter, storage_type=LilcomChunkyWriter,
overwrite=True, overwrite=True,
) )
# cut_set = cut_set.compute_and_store_features(
# extractor=extractor,
# storage_path=f"{output_dir}/feats_{subset}_{idx}",
# num_jobs=args.num_workers,
# executor=ex,
# storage_type=LilcomChunkyWriter,
# )
logging.info(f"Saving to {cuts_path}") logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path) cut_set.to_file(cuts_path)

View File

@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail set -eou pipefail
nj=15 nj=15
stage=131 stage=0
stop_stage=131 stop_stage=100
# Split L subset to this number of pieces # Split L subset to this number of pieces
# This is to avoid OOM during feature extraction. # This is to avoid OOM during feature extraction.
@ -309,7 +309,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
mkdir -p $text_out_dir mkdir -p $text_out_dir
log "Genearating training text data" log "Genearating training text data"
if [ ! -f $text_out_dir/lm_data.pt ]; then if [ ! -f $text_out_dir/lm_data.pt ]; then
./local/prepare_char_lm_training_data.py \ ./local/prepare_char_lm_training_data.py \
--lang-char data/lang_char \ --lang-char data/lang_char \
@ -318,14 +318,14 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
fi fi
log "Generating DEV text data" log "Generating DEV text data"
# prepare validation text data # prepare validation text data
if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then
valid_text=${text_out_dir}/ valid_text=${text_out_dir}/
gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \ gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \
| jq '.text' | sed 's/"//g' \ | jq '.text' | sed 's/"//g' \
| ./local/text2token.py -t "char" > $text_out_dir/valid_text | ./local/text2token.py -t "char" > $text_out_dir/valid_text
python3 ./local/text2segments.py \ python3 ./local/text2segments.py \
--num-process $nj \ --num-process $nj \
--input-file $text_out_dir/valid_text \ --input-file $text_out_dir/valid_text \
@ -337,7 +337,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
--lm-data $text_out_dir/valid_text_words_segmentation \ --lm-data $text_out_dir/valid_text_words_segmentation \
--lm-archive $text_out_dir/lm_data_valid.pt --lm-archive $text_out_dir/lm_data_valid.pt
# prepare TEST text data # prepare TEST text data
if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then
log "Prepare text for test set." log "Prepare text for test set."
for test_set in TEST_MEETING TEST_NET; do for test_set in TEST_MEETING TEST_NET; do
@ -350,7 +350,7 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
--input-file $text_out_dir/${test_set}_text \ --input-file $text_out_dir/${test_set}_text \
--output-file $text_out_dir/${test_set}_text_words_segmentation --output-file $text_out_dir/${test_set}_text_words_segmentation
done done
cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation
fi fi
@ -401,4 +401,4 @@ if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
--lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \
--vocab-size 5537 \ --vocab-size 5537 \
--master-port 12340 --master-port 12340
fi fi

View File

@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp:
import argparse import argparse
import logging import logging
from icefall import is_module_available import torch
from onnx_pretrained import OnnxModel from onnx_pretrained import OnnxModel
import torch from icefall import is_module_available
def get_parser(): def get_parser():

View File

@ -52,13 +52,13 @@ import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
import whisper import whisper
from asr_datamodule import WenetSpeechAsrDataModule from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut
from tn.chinese.normalizer import Normalizer from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert from zhconv import convert
from lhotse.cut import Cut
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (

View File

@ -834,6 +834,7 @@ def run(rank, world_size, args):
# ) # )
return False return False
return True return True
train_cuts = wenetspeech.train_cuts() train_cuts = wenetspeech.train_cuts()
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = wenetspeech.train_dataloaders(train_cuts) train_dl = wenetspeech.train_dataloaders(train_cuts)