Merge branch 'k2-fsa:master' into master

Author: rickychanhoyin, 2022-08-31 13:00:17 +08:00 (committed by GitHub)
Commit: ebb1dea786
19 changed files with 130 additions and 23 deletions

View File

@@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.
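
The same guard is added to every compute_fbank_* and preprocess_* script touched by this commit. lhotse's read_manifests_if_cached may return fewer entries than were requested (for example, when a manifest file is missing or was never prepared), and previously that went unnoticed until feature extraction failed later. A minimal sketch, not part of the commit, of what the new assertion reports; the dict contents are made up for illustration:

```python
# Illustration only: the extra list(manifests.keys()) / dataset_parts entries
# in the assertion tuple show exactly which part failed to load.
dataset_parts = ("train", "dev", "test")
manifests = {"train": {}, "dev": {}}  # pretend the "test" manifest is missing

try:
    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )
except AssertionError as e:
    print(e)  # (2, 3, ['train', 'dev'], ('train', 'dev', 'test'))
```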

View File

@@ -62,6 +62,13 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -62,6 +62,13 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -62,6 +62,13 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -63,6 +63,13 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -63,6 +63,13 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -62,6 +62,13 @@ def preprocess_giga_speech():
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
         raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"

View File

@@ -66,6 +66,13 @@ def compute_fbank_librispeech():
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -65,6 +65,8 @@ def compute_fbank_musan():
     assert len(manifests) == len(dataset_parts), (
         len(manifests),
         len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
     )

     musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"

View File

@@ -68,6 +68,13 @@ def preprocess_giga_speech():
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
         raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"

View File

@@ -181,7 +181,7 @@ def test_convert_scaled_to_non_scaled():
     y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
     d1 = model.decoder(y)
-    d2 = model.decoder(y)
+    d2 = converted_model.decoder(y)
     assert torch.allclose(d1, d2)
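
This hunk fixes a test that previously compared the original decoder with itself, so it always passed regardless of what the conversion did; now d2 comes from the converted model. A sketch of the comparison the fixed test performs; the helper name below is made up, and `model` / `converted_model` stand for the original network and the output of the conversion:

```python
# Sketch only: verify that converting the model leaves decoder outputs unchanged.
import torch

def assert_decoder_unchanged(model, converted_model, y):
    with torch.no_grad():
        d1 = model.decoder(y)
        d2 = converted_model.decoder(y)  # before the fix this called `model` again
    assert torch.allclose(d1, d2), (d1 - d2).abs().max()
```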

View File

@@ -69,6 +69,13 @@ def compute_fbank_musan():
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     musan_cuts_path = src_dir / "cuts_musan.jsonl.gz"
     if musan_cuts_path.is_file():

View File

@@ -62,6 +62,13 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -62,6 +62,13 @@ def compute_fbank_tedlium():
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -63,6 +63,13 @@ def compute_fbank_timit():
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
     with get_executor() as ex:  # Initialize the executor only once.

View File

@@ -23,6 +23,8 @@ from pathlib import Path
 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached

+from icefall import setup_logger
+
 # Similar text filtering and normalization procedure as in:
 # https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
@@ -48,13 +50,17 @@ def preprocess_wenet_speech():
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)

+    # Note: By default, we preprocess all sub-parts.
+    # You can delete those that you don't need.
+    # For instance, if you don't want to use the L subpart, just remove
+    # the line below containing "L"
     dataset_parts = (
+        "L",
+        "M",
+        "S",
         "DEV",
         "TEST_NET",
         "TEST_MEETING",
-        "S",
-        "M",
-        "L",
     )

     logging.info("Loading manifest (may take 10 minutes)")
@@ -66,6 +72,13 @@ def preprocess_wenet_speech():
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
         raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
@@ -81,10 +94,13 @@ def preprocess_wenet_speech():
         logging.info(f"Normalizing text in {partition}")
         for sup in m["supervisions"]:
             text = str(sup.text)
-            logging.info(f"Original text: {text}")
+            orig_text = text
             sup.text = normalize_text(sup.text)
             text = str(sup.text)
-            logging.info(f"Normalize text: {text}")
+            if len(orig_text) != len(text):
+                logging.info(
+                    f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
+                )

         # Create long-recording cut manifests.
         logging.info(f"Processing {partition}")
@@ -109,12 +125,10 @@ def preprocess_wenet_speech():
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-    logging.basicConfig(format=formatter, level=logging.INFO)
+    setup_logger(log_filename="./log-preprocess-wenetspeech")

     preprocess_wenet_speech()
+    logging.info("Done")


 if __name__ == "__main__":
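
main() now delegates logging setup to icefall's setup_logger instead of a hand-rolled logging.basicConfig, and logs "Done" when preprocessing finishes. A minimal usage sketch mirroring the diff; only the log_filename argument is taken from the commit, and any further behaviour of setup_logger is assumed rather than shown here:

```python
# Mirrors the diff: route log output through icefall's setup_logger.
import logging

from icefall import setup_logger


def main():
    setup_logger(log_filename="./log-preprocess-wenetspeech")
    logging.info("Done")  # captured by whatever handlers setup_logger configured


if __name__ == "__main__":
    main()
```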

View File

@@ -81,7 +81,6 @@ For training with the S subset:
 import argparse
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -120,8 +119,6 @@ LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]

-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-

 def get_parser():
     parser = argparse.ArgumentParser(
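
CUDA_LAUNCH_BLOCKING=1 forces synchronous CUDA kernel launches, which helps attribute device-side errors to the right call but slows training, so hard-coding it in train.py looks like a debugging leftover. If needed it can still be enabled for a single run from the shell; a sketch, not part of the commit:

```python
# Sketch: read the flag instead of forcing it, so ordinary runs stay fast.
# Enable per run with:  CUDA_LAUNCH_BLOCKING=1 python ./train.py ...
import os

if os.environ.get("CUDA_LAUNCH_BLOCKING") == "1":
    print("CUDA kernels will launch synchronously (slow, debug only)")
```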
@@ -162,7 +159,7 @@ def get_parser():
         default=0,
         help="""Resume training from from this epoch.
         If it is positive, it will load checkpoint from
-        transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
         """,
     )
@@ -361,8 +358,8 @@ def get_params() -> AttributeDict:
         "best_valid_loss": float("inf"),
         "best_train_epoch": -1,
         "best_valid_epoch": -1,
-        "batch_idx_train": 10,
-        "log_interval": 1,
+        "batch_idx_train": 0,
+        "log_interval": 50,
         "reset_interval": 200,
         # parameters for conformer
         "feature_dim": 80,
@@ -545,7 +542,7 @@ def compute_loss(
     warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute RNN-T loss given the model and its inputs.
     Args:
       params:
         Parameters for training. See :func:`get_params`.
@@ -573,7 +570,7 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)
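
As the surrounding context shows, graph_compiler.texts_to_ids() can return either a plain list of token-id lists or an already-built k2.RaggedTensor, and isinstance() is the idiomatic way to branch on that: it also matches list subclasses, and linters flag type(...) == ... comparisons. A minimal illustration that does not need k2:

```python
# isinstance matches subclasses; type(y) == list would reject them.
class TokenIds(list):
    pass

y = TokenIds([[2, 5, 7], [3, 4]])
print(isinstance(y, list))   # True
print(type(y) == list)       # False
```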
@@ -697,7 +694,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
-
         batch_size = len(batch["supervisions"]["text"])

View File

@@ -61,7 +61,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
 import argparse
 import copy
 import logging
-import os
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -103,8 +102,6 @@ LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]

-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-

 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
@@ -684,7 +681,7 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = graph_compiler.texts_to_ids(texts)
-    if type(y) == list:
+    if isinstance(y, list):
         y = k2.RaggedTensor(y).to(device)
     else:
         y = y.to(device)

View File

@@ -47,6 +47,13 @@ def compute_fbank_yesno():
     )
     assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
     extractor = Fbank(
         FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
     )