Merge pull request #4 from reazon-research/musan-mls-clean-final
Musan mls clean final
Commit 36fc1f1d1e
@@ -545,6 +545,7 @@ class TransformerDecoderLayer(nn.Module):
         memory_mask: Optional[torch.Tensor] = None,
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
@@ -5,7 +5,6 @@
 **Multilingual LibriSpeech (MLS)** is a large multilingual corpus suitable for speech research. The dataset is derived from read audiobooks from LibriVox and consists of 8 languages - English, German, Dutch, Spanish, French, Italian, Portuguese, Polish. It includes about 44.5K hours of English and a total of about 6K hours for other languages. This icefall training recipe was created for the restructured version of the English split of the dataset available on Hugging Face below.

 The dataset is available on Hugging Face. For more details, please visit:

 - Dataset: https://huggingface.co/datasets/parler-tts/mls_eng
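
As a quick sanity check that the Hub dataset is reachable, it can be loaded with the `datasets` library; the sketch below is illustrative only (the `train` split name and streaming behaviour are assumptions, and streaming is untested in this recipe).

```python
# Minimal sketch: peek at the Hugging Face dataset used by this recipe.
# The "train" split name and streaming support are assumptions.
from datasets import load_dataset

ds = load_dataset("parler-tts/mls_eng", split="train", streaming=True)
sample = next(iter(ds))
print(sample.keys())  # audio and transcript fields, depending on the dataset schema
```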
@@ -14,6 +13,7 @@ The dataset is available on Hugging Face. For more details, please visit:

 ## On-the-fly feature computation

-This recipe currently only supports on-the-fly feature bank computation, since `lhotse` manifests and feature banks are not pre-calculated in this recipe. This should mean that the dataset can be streamed from Hugging Face, but we have not tested this yet. We may add a version that supports pre-calculating features to better match existing recipes.
+This recipe currently only supports on-the-fly feature bank computation, since `lhotse` manifests and feature banks are not pre-calculated in this recipe. This should mean that the dataset can be streamed from Hugging Face, but we have not tested this yet. We may add a version that supports pre-calculating features to better match existing recipes.\
+<br>

-<!-- [./RESULTS.md](./RESULTS.md) contains the latest results. -->
+[./RESULTS.md](./RESULTS.md) contains the latest results. This MLS English recipe was primarily developed for use in the ```multi_ja_en``` Japanese-English bilingual pipeline, which is based on MLS English and ReazonSpeech.
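
For reference, this is roughly what on-the-fly feature extraction looks like with `lhotse`; the 80-bin fbank configuration is an assumed default and may differ from the recipe's actual settings.

```python
# Sketch: build a dataset that computes fbank features on the fly, so no
# pre-computed feature banks are needed. num_mel_bins=80 is an assumption.
from lhotse import Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import OnTheFlyFeatures

dataset = K2SpeechRecognitionDataset(
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
    return_cuts=True,
)
```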
egs/mls_english/ASR/RESULTS.md (new file, 41 lines)
@@ -0,0 +1,41 @@
## Results

### MLS-English training results (Non-streaming) on zipformer model

#### Non-streaming

**WER on Test Set (Epoch 20)**

| Type          | Greedy | Beam search |
|---------------|--------|-------------|
| Non-streaming | 6.65   | 6.57        |

The training command:

```
./zipformer/train.py \
  --world-size 8 \
  --num-epochs 20 \
  --start-epoch 9 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --lang-dir data/lang/bpe_2000/
```

The decoding command:

```
./zipformer/decode.py \
  --epoch 20 \
  --exp-dir ./zipformer/exp \
  --lang-dir data/lang/bpe_2000/ \
  --decoding-method greedy_search
```

The pre-trained model is available here: [reazon-research/mls-english](https://huggingface.co/reazon-research/mls-english)

Please note that this recipe was developed primarily as the source of English input in the bilingual Japanese-English recipe `multi_ja_en`, which uses ReazonSpeech and MLS English.
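
To try the released checkpoint, the repository can be fetched with `huggingface_hub`; this is a generic sketch and makes no assumption about the file layout inside the repo.

```python
# Sketch: download the pre-trained MLS English zipformer from Hugging Face.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="reazon-research/mls-english")
print(f"Model files downloaded to {local_dir}")
```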
egs/mls_english/ASR/local/compute_fbank_musan.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
@@ -180,7 +180,10 @@ class MLSEnglishHFAsrDataModule:
         )

     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+        cuts_musan: Optional[CutSet] = None,
     ) -> DataLoader:
         """
         Args:
@@ -191,6 +194,13 @@ class MLSEnglishHFAsrDataModule:
         """

         transforms = []
+        if cuts_musan is not None:
+            logging.info("Enable MUSAN")
+            transforms.append(
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
         input_transforms = []

         if self.args.enable_spec_aug:
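
With this change the caller is expected to load the MUSAN cuts itself and pass them in; below is a minimal sketch of the intended call, mirroring the `train.py` change later in this diff (`mls_english_corpus` and `train_cuts` come from the recipe's data module, and the manifest path is an assumption).

```python
# Sketch: pass MUSAN cuts into the updated train_dataloaders().
from pathlib import Path

from lhotse import load_manifest

musan_path = Path("data/manifests") / "musan_cuts.jsonl.gz"  # assumed location
cuts_musan = load_manifest(musan_path) if musan_path.exists() else None

train_dl = mls_english_corpus.train_dataloaders(
    train_cuts,
    sampler_state_dict=None,
    cuts_musan=cuts_musan,  # None logs "Disable MUSAN" and skips the CutMix transform
)
```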
@@ -337,19 +347,19 @@ class MLSEnglishHFAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "mls_english_cuts_train.jsonl.gz"
+            self.args.manifest_dir / "mls_eng_cuts_train.jsonl.gz"
         )

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "mls_english_cuts_dev.jsonl.gz"
+            self.args.manifest_dir / "mls_eng_cuts_dev.jsonl.gz"
         )

     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "mls_english_cuts_test.jsonl.gz"
+            self.args.manifest_dir / "mls_eng_cuts_test.jsonl.gz"
         )
@@ -16,6 +16,14 @@ vocab_sizes=(2000) # You can add more sizes like (500 1000 2000) for comparison
 # Directory where dataset will be downloaded
 dl_dir=$PWD/download

+# - $dl_dir/musan
+#   This directory contains the following directories downloaded from
+#   http://www.openslr.org/17/
+#
+#   - music
+#   - noise
+#   - speech
+
 . shared/parse_options.sh || exit 1

 # All files generated by this script are saved in "data".
@@ -32,7 +40,7 @@ log() {
 log "Starting MLS English data preparation"

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "Stage 0: Download MLS English dataset"
+  log "Stage 0: Download data"
   # Check if huggingface_hub is installed
   if ! python -c "import huggingface_hub" &> /dev/null; then
     log "huggingface_hub Python library not found. Installing it now..."
@@ -55,6 +63,15 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   else
     log "Dataset already exists at $dl_dir/mls_english. Skipping download."
   fi
+  # If you have predownloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    log "Downloading musan."
+    lhotse download musan $dl_dir
+  fi
 fi

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@@ -73,7 +90,25 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Prepare transcript for BPE training"
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+  if [ ! -e data/manifests/.musan_prep.done ]; then
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan_prep.done
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for musan"
+  if [ ! -e data/manifests/.musan_fbank.done ]; then
+    ./local/compute_fbank_musan.py
+    touch data/manifests/.musan_fbank.done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Prepare transcript for BPE training"
   if [ ! -f data/lang/transcript.txt ]; then
     log "Generating transcripts for BPE training"
     python local/utils/generate_transcript.py \
@@ -83,8 +118,8 @@
   fi
 fi

-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Prepare BPE tokenizer"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare BPE tokenizer"
   for vocab_size in ${vocab_sizes[@]}; do
     log "Training BPE model with vocab_size=${vocab_size}"
     bpe_dir=data/lang/bpe_${vocab_size}
@@ -99,8 +134,8 @@
   done
 fi

-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Show manifest statistics"
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Show manifest statistics"
   python local/display_manifest_statistics.py --manifest-dir data/manifests > data/manifests/manifest_statistics.txt
   cat data/manifests/manifest_statistics.txt
 fi
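
Once stages 2-3 have produced the MUSAN cut manifest, it can be sanity-checked before training; the path below is an assumption and should match the `--manifest-dir` from which `train.py` loads `musan_cuts.jsonl.gz`.

```python
# Sketch: verify the MUSAN cut manifest produced by the new prepare.sh stages.
from lhotse import load_manifest

cuts_musan = load_manifest("data/manifests/musan_cuts.jsonl.gz")  # assumed path
total_hours = sum(c.duration for c in cuts_musan) / 3600
print(f"{len(cuts_musan)} MUSAN cuts, {total_hours:.1f} hours of music/noise/speech")
```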
@@ -1044,13 +1044,13 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True
     mls_english_corpus = MLSEnglishHFAsrDataModule(args)
     mls_english_corpus.load_dataset(args.dataset_path)

     # # dev_cuts = mls_english_corpus.dev_cuts()
     # test_cuts = mls_english_corpus.test_cuts()

-    # dev_dl = mls_english_corpus.test_dataloader()
-    test_dl = mls_english_corpus.test_dataloader()
+    test_cuts = mls_english_corpus.test_cuts()
+    test_dl = mls_english_corpus.test_dataloaders(test_cuts)

     test_sets = ["test"]
     test_dls = [test_dl]
@@ -68,6 +68,7 @@ from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from lhotse import load_manifest
 from model import AsrModel
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
@@ -1215,11 +1216,8 @@ def run(rank, world_size, args):
         return True

     mls_english_corpus = MLSEnglishHFAsrDataModule(args)
-    mls_english_corpus.load_dataset(args.dataset_path)
-
-    # train_cuts = mls_english_corpus.train_cuts()
-
-    # train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = mls_english_corpus.train_cuts()
+    # mls_english_corpus.load_dataset(args.dataset_path)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1227,17 +1225,23 @@ def run(rank, world_size, args):
         sampler_state_dict = checkpoints["sampler"]
     else:
         sampler_state_dict = None

+    if args.enable_musan:
+        musan_path = Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        if musan_path.exists():
+            cuts_musan = load_manifest(musan_path)
+            logging.info(f"Loaded MUSAN manifest from {musan_path}")
+        else:
+            logging.warning(f"MUSAN manifest not found at {musan_path}, disabling MUSAN augmentation")
+            cuts_musan = None
+    else:
+        cuts_musan = None
+
-    # train_dl = mls_english_corpus.train_dataloaders(
-    #     train_cuts, sampler_state_dict=sampler_state_dict
-    # )
-    train_dl = mls_english_corpus.train_dataloader(
-        sampler_state_dict=sampler_state_dict
+    train_dl = mls_english_corpus.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
     )

-    # valid_cuts = mls_english_corpus.valid_cuts()
-    # valid_dl = mls_english_corpus.valid_dataloader(valid_cuts)
-    valid_dl = mls_english_corpus.valid_dataloader()
+    valid_cuts = mls_english_corpus.valid_cuts()
+    valid_dl = mls_english_corpus.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1185,6 +1185,7 @@ def run(rank, world_size, args):
     train_cuts = multi_dataset.train_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances greater than 1 second
         #
         # You should use ../local/display_manifest_statistics.py to get

@@ -1241,6 +1242,7 @@ def run(rank, world_size, args):
     )

     valid_cuts = multi_dataset.dev_cuts()

     valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
@@ -612,6 +612,7 @@ class TransformerDecoderLayer(nn.Module):
         tgt_key_padding_mask: Optional[torch.Tensor] = None,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
         warmup: float = 1.0,
+        **kwargs,
     ) -> torch.Tensor:
         """Pass the inputs (and mask) through the decoder layer.
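
Both decoder-layer variants touched by this commit now take `**kwargs`, so the same call site can pass extra keyword arguments (such as `warmup`) whether or not a particular implementation uses them. A toy illustration of the pattern (class names here are illustrative, not the actual icefall classes):

```python
# Toy sketch: **kwargs lets both layer variants accept the same keyword
# arguments, even though only one of them actually uses `warmup`.
import torch
import torch.nn as nn


class PlainDecoderLayer(nn.Module):
    def forward(self, tgt: torch.Tensor, **kwargs) -> torch.Tensor:
        return tgt  # silently ignores warmup and any other extras


class WarmupDecoderLayer(nn.Module):
    def forward(self, tgt: torch.Tensor, warmup: float = 1.0, **kwargs) -> torch.Tensor:
        return tgt * warmup


x = torch.randn(2, 3)
for layer in (PlainDecoderLayer(), WarmupDecoderLayer()):
    y = layer(x, warmup=0.5)  # the identical call works for both variants
```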
@@ -1391,13 +1391,20 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
     return concat(ragged, eos_id, direction="right")


-def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+def make_pad_mask(
+    lengths: torch.Tensor,
+    max_len: int = 0,
+    pad_left: bool = False,
+) -> torch.Tensor:
     """
     Args:
       lengths:
         A 1-D tensor containing sentence lengths.
       max_len:
         The length of masks.
+      pad_left:
+        If ``False`` (default), padding is on the right.
+        If ``True``, padding is on the left.
     Returns:
       Return a 2-D bool tensor, where masked positions
       are filled with `True` and non-masked positions are

@@ -1414,9 +1421,14 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     max_len = max(max_len, lengths.max())
     n = lengths.size(0)
     seq_range = torch.arange(0, max_len, device=lengths.device)
-    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)

-    return expaned_lengths >= lengths.unsqueeze(-1)
+    if pad_left:
+        mask = expanded_lengths < (max_len - lengths).unsqueeze(1)
+    else:
+        mask = expanded_lengths >= lengths.unsqueeze(-1)
+
+    return mask


 # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
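
A small worked example of the new `pad_left` flag, using the function exactly as defined above (the import path is assumed): for `lengths = [2, 3]` the `True` entries mark padded positions, on the right by default and on the left when `pad_left=True`.

```python
# Worked example for make_pad_mask with the new pad_left flag.
import torch

from icefall.utils import make_pad_mask  # import path assumed

lengths = torch.tensor([2, 3])

print(make_pad_mask(lengths))
# tensor([[False, False,  True],
#         [False, False, False]])   <- padding on the right

print(make_pad_mask(lengths, pad_left=True))
# tensor([[ True, False, False],
#         [False, False, False]])   <- padding on the left
```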