Filter utterances with number_tokens > number_feature_frames. (#604)

Fangjun Kuang 2022-11-12 07:57:58 +08:00 committed by GitHub
parent 2f43e4508b
commit e334e570d8
12 changed files with 481 additions and 19 deletions
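
The core requirement behind this change: in pruned RNN-T training, the number of feature frames after convolutional subsampling (T) must be at least the number of BPE tokens in the transcript (S), so cuts violating T >= S (or the existing 1-20 second duration bounds) are now dropped before training. A minimal standalone sketch of the criterion, assuming roughly 100 fbank frames per second and conformer-style subsampling; the helper name is_trainable is illustrative and not part of this patch:

import sentencepiece as spm


def is_trainable(duration: float, text: str, sp: spm.SentencePieceProcessor) -> bool:
    """Return True if an utterance would survive the filter added in this commit."""
    # Duration bounds used throughout the patch.
    if duration < 1.0 or duration > 20.0:
        return False
    num_frames = int(duration * 100)  # assumption: ~100 fbank frames per second
    T = ((num_frames - 1) // 2 - 1) // 2  # conformer-style subsampling (~4x)
    S = len(sp.encode(text, out_type=str))  # number of BPE tokens
    return T >= S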

View File

@@ -23,11 +23,15 @@ It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import sentencepiece as spm
import torch
from filter_cuts import filter_cuts
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
@@ -41,12 +45,29 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_librispeech():
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to the bpe.model. If not None, we will remove short and
long utterances before extracting features""",
)
return parser.parse_args()
def compute_fbank_librispeech(bpe_model: Optional[str] = None):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
if bpe_model:
logging.info(f"Loading {bpe_model}")
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
dataset_parts = (
"dev-clean",
"dev-other",
@@ -86,6 +107,9 @@ def compute_fbank_librispeech():
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if bpe_model:
cut_set = filter_cuts(cut_set, sp)
if "train" in partition:
cut_set = (
cut_set
@@ -109,5 +133,6 @@ if __name__ == "__main__":
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_librispeech()
args = get_args()
logging.info(vars(args))
compute_fbank_librispeech(bpe_model=args.bpe_model)
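
With this change, feature extraction can optionally drop unusable utterances up front; filtering only happens when --bpe-model is given. A usage sketch, assuming the standard egs/librispeech/ASR layout and the BPE model path used in the filter_cuts.py docstring below:

python3 ./local/compute_fbank_librispeech.py \
  --bpe-model data/lang_bpe_500/bpe.model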

View File

@@ -0,0 +1,161 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script removes short and long utterances from a cutset.
Caution:
You may need to tune the thresholds for your own dataset.
Usage example:
python3 ./local/filter_cuts.py \
--bpe-model data/lang_bpe_500/bpe.model \
--in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \
--out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=Path,
help="Path to the bpe.model",
)
parser.add_argument(
"--in-cuts",
type=Path,
help="Path to the input cutset",
)
parser.add_argument(
"--out-cuts",
type=Path,
help="Path to the output cutset",
)
return parser.parse_args()
def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
total = 0 # number of total utterances before removal
removed = 0 # number of removed utterances
def remove_short_and_long_utterances(c: Cut):
"""Return False to exclude the input cut"""
nonlocal removed, total
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ./display_manifest_statistics.py
#
# You should use ./display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
total += 1
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
removed += 1
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./pruned_transducer_stateless2/conformer.py, the
# conv module uses the following expression
# for subsampling
if c.num_frames is None:
num_frames = c.duration * 100 # approximate
else:
num_frames = c.num_frames
T = ((num_frames - 1) // 2 - 1) // 2
# Note: for ./lstm_transducer_stateless/lstm.py, the formula is
# T = ((num_frames - 3) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
removed += 1
return False
return True
# We use to_eager() here so that we can print out the value of total
# and removed below.
ans = cut_set.filter(remove_short_and_long_utterances).to_eager()
ratio = removed / total * 100
logging.info(
f"Removed {removed} cuts from {total} cuts. "
f"{ratio:.3f}% data is removed."
)
return ans
def main():
args = get_args()
logging.info(vars(args))
if args.out_cuts.is_file():
logging.info(f"{args.out_cuts} already exists - skipping")
return
assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist"
assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist"
sp = spm.SentencePieceProcessor()
sp.load(str(args.bpe_model))
cut_set = load_manifest_lazy(args.in_cuts)
assert isinstance(cut_set, CutSet)
cut_set = filter_cuts(cut_set, sp)
logging.info(f"Saving to {args.out_cuts}")
args.out_cuts.parent.mkdir(parents=True, exist_ok=True)
cut_set.to_file(args.out_cuts)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
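
To make the two subsampling formulas concrete, here is a short worked sketch (the frame count is illustrative, assuming roughly 100 fbank frames per second):

# A 2-second utterance has roughly 200 feature frames.
num_frames = 200

# Conformer/Emformer front ends: T = ((num_frames - 1) // 2 - 1) // 2
print(((num_frames - 1) // 2 - 1) // 2)  # 49

# LSTM front end: T = ((num_frames - 3) // 2 - 1) // 2
print(((num_frames - 3) // 2 - 1) // 2)  # 48

# A cut is kept only if T >= number of BPE tokens, so a 2-second cut can
# carry at most 49 (conformer) or 48 (lstm) tokens for pruned RNN-T training.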

View File

@@ -987,7 +987,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./lstm.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 3) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@@ -991,7 +991,10 @@ def train_one_epoch(
params.best_train_loss = params.train_loss
def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
def filter_short_and_long_utterances(
cuts: CutSet,
sp: spm.SentencePieceProcessor,
) -> CutSet:
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
@@ -1001,7 +1004,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./lstm.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 3) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
cuts = cuts.filter(remove_short_and_long_utt)
@@ -1104,7 +1134,7 @@ def run(rank, world_size, args):
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts)
train_cuts = filter_short_and_long_utterances(train_cuts, sp)
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
# XL 10k hours
@@ -1121,7 +1151,7 @@ def run(rank, world_size, args):
logging.info("Using the S subset of GigaSpeech (250 hours)")
train_giga_cuts = gigaspeech.train_S_cuts()
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp)
train_giga_cuts = train_giga_cuts.repeat(times=None)
if args.enable_musan:

View File

@@ -1007,7 +1007,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./lstm.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 3) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@@ -906,7 +906,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./emformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@@ -895,7 +895,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@@ -961,7 +961,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@@ -952,7 +952,10 @@ def train_one_epoch(
params.best_train_loss = params.train_loss
def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
def filter_short_and_long_utterances(
cuts: CutSet,
sp: spm.SentencePieceProcessor,
) -> CutSet:
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
@@ -962,7 +965,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
cuts = cuts.filter(remove_short_and_long_utt)
@@ -1058,7 +1088,7 @@ def run(rank, world_size, args):
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts)
train_cuts = filter_short_and_long_utterances(train_cuts, sp)
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
# XL 10k hours
@@ -1075,7 +1105,7 @@ def run(rank, world_size, args):
logging.info("Using the S subset of GigaSpeech (250 hours)")
train_giga_cuts = gigaspeech.train_S_cuts()
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp)
train_giga_cuts = train_giga_cuts.repeat(times=None)
if args.enable_musan:

View File

@@ -1011,7 +1011,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@@ -1043,7 +1043,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@@ -1005,7 +1005,34 @@ def run(rank, world_size, args):
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./conformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)