mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 18:12:19 +00:00)

commit 92c3f61f21 (parent 99487d3000)

    change data -> data_{dataset}
@@ -52,12 +52,11 @@ class _SeedWorkers:
         fix_random_seed(self.seed + worker_id)


-class LibriSpeechAsrDataModule:
+class MUCSAsrDataModule:
     """
-    DataModule for k2 ASR experiments.
-    It assumes there is always one train and valid dataloader,
-    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
-    and test-other).
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train, one valid, and one test dataloader.
+    This is modified from the LibriSpeech asr_datamodule.

     It contains all the common data pipeline modules used in ASR
     experiments, e.g.:
@@ -38,7 +38,7 @@ import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import MUCSAsrDataModule
 from conformer import Conformer
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
@@ -687,7 +687,7 @@ def run(rank, world_size, args):
     if checkpoints:
         optimizer.load_state_dict(checkpoints["optimizer"])

-    librispeech = LibriSpeechAsrDataModule(args)
+    librispeech = MUCSAsrDataModule(args)
     # params.full_libri = False
     # if params.full_libri:
     #   train_cuts = librispeech.train_all_shuf_cuts()
@@ -800,7 +800,7 @@ def scan_pessimistic_batches_for_oom(

 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    MUCSAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
The old symlink to the librispeech compile_hlg.py is removed and replaced with
a standalone copy:

@@ -1 +0,0 @@
-../../../librispeech/ASR/local/compile_hlg.py

egs/mucs/ASR/local/compile_hlg.py (new executable file, 167 lines)
@@ -0,0 +1,167 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+    - L, the lexicon, built from lang_dir/L_disambig.pt
+
+      Caution: We use a lexicon that contains disambiguation symbols
+
+    - G, the LM, built from <datadir>/lm/G_n_gram.fst.txt, where <datadir>
+      is the first path component of lang_dir (e.g., data_bn-en)
+
+The generated HLG is saved in $lang_dir/HLG.pt
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lm",
+        type=str,
+        default="G_3_gram",
+        help="""Stem name for LM used in HLG compiling.
+        """,
+    )
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+      lm:
+        The LM stem name, e.g., G_3_gram.
+
+    Return:
+        An FSA representing HLG.
+    """
+    lexicon = Lexicon(lang_dir)
+    # The per-dataset directory (e.g., data_bn-en) is the first path
+    # component of lang_dir; the LM lives in <datapath>/lm.
+    datapath = str(lang_dir).split("/")[0]
+    max_token_id = max(lexicon.tokens)
+    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
+    H = k2.ctc_topo(max_token_id)
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+    if Path(f"{datapath}/lm/{lm}.pt").is_file():
+        logging.info(f"Loading pre-compiled {lm}")
+        d = torch.load(f"{datapath}/lm/{lm}.pt")
+        G = k2.Fsa.from_dict(d)
+    else:
+        logging.info(f"Loading {lm}.fst.txt")
+        with open(f"{datapath}/lm/{lm}.fst.txt") as f:
+            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+            torch.save(G.as_dict(), f"{datapath}/lm/{lm}.pt")
+
+    first_token_disambig_id = lexicon.token_table["#0"]
+    first_word_disambig_id = lexicon.word_table["#0"]
+
+    L = k2.arc_sort(L)
+    G = k2.arc_sort(G)
+
+    logging.info("Intersecting L and G")
+    LG = k2.compose(L, G)
+    logging.info(f"LG shape: {LG.shape}")
+
+    logging.info("Connecting LG")
+    LG = k2.connect(LG)
+    logging.info(f"LG shape after k2.connect: {LG.shape}")
+
+    logging.info(type(LG.aux_labels))
+    logging.info("Determinizing LG")
+
+    LG = k2.determinize(LG)
+    logging.info(type(LG.aux_labels))
+
+    logging.info("Connecting LG after k2.determinize")
+    LG = k2.connect(LG)
+
+    logging.info("Removing disambiguation symbols on LG")
+
+    LG.labels[LG.labels >= first_token_disambig_id] = 0
+    # See https://github.com/k2-fsa/k2/issues/874
+    # for why we need to set LG.properties to None
+    LG.__dict__["_properties"] = None
+
+    assert isinstance(LG.aux_labels, k2.RaggedTensor)
+    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+
+    LG = k2.remove_epsilon(LG)
+    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
+
+    LG = k2.connect(LG)
+    LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+
+    logging.info("Arc sorting LG")
+    LG = k2.arc_sort(LG)
+
+    logging.info("Composing H and LG")
+    # CAUTION: The name of the inner_labels is fixed
+    # to `tokens`. If you want to change it, please
+    # also change other places in icefall that are using
+    # it.
+    HLG = k2.compose(H, LG, inner_labels="tokens")
+
+    logging.info("Connecting HLG")
+    HLG = k2.connect(HLG)
+
+    logging.info("Arc sorting HLG")
+    HLG = k2.arc_sort(HLG)
+    logging.info(f"HLG.shape: {HLG.shape}")
+
+    return HLG
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    if (lang_dir / "HLG.pt").is_file():
+        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+        return
+
+    logging.info(f"Processing {lang_dir}")
+
+    HLG = compile_HLG(lang_dir, args.lm)
+    logging.info(f"Saving HLG.pt to {lang_dir}")
+    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
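A minimal invocation sketch, assuming the per-dataset layout (data_bn-en) that
prepare.sh below creates; --lang-dir and --lm are the flags defined in
get_args() above. Because the script takes the LM directory from the first
path component of --lang-dir, the lang dir should be passed as a relative
path:

    ./local/compile_hlg.py \
        --lang-dir data_bn-en/lang_bpe_400 \
        --lm G_3_gram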
egs/mucs/ASR/local/compute_fbank_mucs.py (14 changes, normal file → executable file)
@@ -59,6 +59,16 @@ def get_args():
         type=str,
         help="""Dataset parts to compute fbank. If None, we will use all""",
     )
+    parser.add_argument(
+        "--manifestpath",
+        type=str,
+        help="""Directory containing the lhotse manifests.""",
+    )
+    parser.add_argument(
+        "--fbankpath",
+        type=str,
+        help="""Output directory for the computed fbank features.""",
+    )

     return parser.parse_args()

@@ -67,8 +77,8 @@ def compute_fbank_mucs(
     bpe_model: Optional[str] = None,
     dataset: Optional[str] = None,
 ):
-    src_dir = Path("data/manifests")
-    output_dir = Path("data/fbank")
+    src_dir = Path(args.manifestpath)
+    output_dir = Path(args.fbankpath)
     num_jobs = min(48, os.cpu_count())
     num_mel_bins = 80

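With the two new flags, the fbank step reads manifests from, and writes
features to, the per-dataset directory. A typical invocation (matching how
prepare.sh calls it below):

    ./local/compute_fbank_mucs.py \
        --manifestpath data_bn-en/manifests/ \
        --fbankpath data_bn-en/fbank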
egs/mucs/ASR/local/filter_scp.pl (new executable file, 87 lines)
@@ -0,0 +1,87 @@
+#!/usr/bin/env perl
+# Copyright 2010-2012 Microsoft Corporation
+#                     Johns Hopkins University (author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script takes a list of utterance-ids or any file whose first field
+# of each line is an utterance-id, and filters an scp
+# file (or any file whose "n-th" field is an utterance id), printing
+# out only those lines whose "n-th" field is in id_list. The index of
+# the "n-th" field is 1, by default, but can be changed by using
+# the -f <n> switch
+
+$exclude = 0;
+$field = 1;
+$shifted = 0;
+
+do {
+  $shifted = 0;
+  if ($ARGV[0] eq "--exclude") {
+    $exclude = 1;
+    shift @ARGV;
+    $shifted = 1;
+  }
+  if ($ARGV[0] eq "-f") {
+    $field = $ARGV[1];
+    shift @ARGV; shift @ARGV;
+    $shifted = 1
+  }
+} while ($shifted);
+
+if (@ARGV < 1 || @ARGV > 2) {
+  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
+      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
+      "Note: only the first field of each line in id_list matters.  With --exclude, prints\n" .
+      "only the lines that were *not* in id_list.\n" .
+      "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
+      "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
+      "-f option, add 1 to the argument.\n" .
+      "See also: utils/filter_scp.pl .\n";
+}
+
+
+$idlist = shift @ARGV;
+open(F, "<$idlist") || die "Could not open id-list file $idlist";
+while(<F>) {
+  @A = split;
+  @A >= 1 || die "Invalid id-list file line $_";
+  $seen{$A[0]} = 1;
+}
+
+if ($field == 1) { # Treat this as special case, since it is common.
+  while(<>) {
+    $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
+    # $1 is what we filter on.
+    if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
+      print $_;
+    }
+  }
+} else {
+  while(<>) {
+    @A = split;
+    @A > 0 || die "Invalid scp file line $_";
+    @A >= $field || die "Invalid scp file line $_";
+    if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
+      print $_;
+    }
+  }
+}
+
+# tests:
+# the following should print "foo 1"
+# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
+# the following should print "bar 2".
+# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
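A short usage sketch, assuming a hypothetical id-list file dev_ids whose first
field is an utterance-id:

    # Keep only the wav.scp lines whose utterance-id appears in dev_ids.
    local/filter_scp.pl dev_ids data/train/wav.scp > data/dev/wav.scp
    # With --exclude, keep the complement instead.
    local/filter_scp.pl --exclude dev_ids data/train/wav.scp > data/train_nodev/wav.scp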
egs/mucs/ASR/local/subset_data_dir.sh (new executable file, 196 lines)
@@ -0,0 +1,196 @@
+#!/usr/bin/env bash
+# Copyright 2010-2011 Microsoft Corporation
+#           2012-2013 Johns Hopkins University (Author: Daniel Povey)
+# Apache 2.0
+
+
+# This script operates on a data directory, such as in data/train/.
+# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data
+# for what these directories contain.
+
+# This script creates a subset of that data, consisting of some specified
+# number of utterances.  (The selected utterances are distributed evenly
+# throughout the file, by the program ./subset_scp.pl).
+
+# There are six options, none compatible with any other.
+
+# If you give the --per-spk option, it will attempt to select the supplied
+# number of utterances for each speaker (typically you would supply a much
+# smaller number in this case).
+
+# If you give the --speakers option, it selects a subset of n randomly
+# selected speakers.
+
+# If you give the --shortest option, it will give you the n shortest utterances.
+
+# If you give the --first option, it will just give you the n first utterances.
+
+# If you give the --last option, it will just give you the n last utterances.
+
+# If you give the --spk-list or --utt-list option, it reads the
+# speakers/utterances to keep from <speaker-list-file>/<utt-list-file> (note,
+# in this case there is no <num-utt> positional parameter; see usage message.)
+
+
+shortest=false
+perspk=false
+speakers=false
+first_opt=
+spk_list=
+utt_list=
+
+expect_args=3
+case $1 in
+  --first|--last) first_opt=$1; shift ;;
+  --per-spk) perspk=true; shift ;;
+  --shortest) shortest=true; shift ;;
+  --speakers) speakers=true; shift ;;
+  --spk-list) shift; spk_list=$1; shift; expect_args=2 ;;
+  --utt-list) shift; utt_list=$1; shift; expect_args=2 ;;
+  --*) echo "$0: invalid option '$1'"; exit 1
+esac
+
+if [ $# != $expect_args ]; then
+  echo "Usage:"
+  echo "  subset_data_dir.sh [--speakers|--shortest|--first|--last|--per-spk] <srcdir> <num-utt> <destdir>"
+  echo "  subset_data_dir.sh [--spk-list <speaker-list-file>] <srcdir> <destdir>"
+  echo "  subset_data_dir.sh [--utt-list <utt-list-file>] <srcdir> <destdir>"
+  echo "By default, randomly selects <num-utt> utterances from the data directory."
+  echo "With --speakers, randomly selects enough speakers that we have <num-utt> utterances"
+  echo "With --per-spk, selects <num-utt> utterances per speaker, if available."
+  echo "With --first, selects the first <num-utt> utterances"
+  echo "With --last, selects the last <num-utt> utterances"
+  echo "With --shortest, selects the shortest <num-utt> utterances."
+  echo "With --spk-list, reads the speakers to keep from <speaker-list-file>"
+  echo "With --utt-list, reads the utterances to keep from <utt-list-file>"
+  exit 1;
+fi
+
+srcdir=$1
+if [[ $spk_list || $utt_list ]]; then
+  numutt=
+  destdir=$2
+else
+  numutt=$2
+  destdir=$3
+fi
+
+export LC_ALL=C
+
+if [ ! -f $srcdir/utt2spk ]; then
+  echo "$0: no such file $srcdir/utt2spk"
+  exit 1
+fi
+
+if [[ $numutt && $numutt -gt $(wc -l <$srcdir/utt2spk) ]]; then
+  echo "$0: cannot subset to more utterances than you originally had."
+  exit 1
+fi
+
+if $shortest && [ ! -f $srcdir/feats.scp ]; then
+  echo "$0: you selected --shortest but no feats.scp exists."
+  exit 1
+fi
+
+mkdir -p $destdir || exit 1
+
+if [[ $spk_list ]]; then
+  local/filter_scp.pl "$spk_list" $srcdir/spk2utt > $destdir/spk2utt || exit 1;
+  utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk || exit 1;
+elif [[ $utt_list ]]; then
+  local/filter_scp.pl "$utt_list" $srcdir/utt2spk > $destdir/utt2spk || exit 1;
+  local/utt2spk_to_spk2utt.pl < $destdir/utt2spk > $destdir/spk2utt || exit 1;
+elif $speakers; then
+  utils/shuffle_list.pl < $srcdir/spk2utt |
+    awk -v numutt=$numutt '{ if (tot < numutt){ print; } tot += (NF-1); }' |
+    sort > $destdir/spk2utt
+  utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk
+elif $perspk; then
+  awk '{ n='$numutt'; printf("%s ",$1);
+         skip=1; while(n*(skip+1) <= NF-1) { skip++; }
+         for(x=2; x<=NF && x <= (n*skip+1); x += skip) { printf("%s ", $x); }
+         printf("\n"); }' <$srcdir/spk2utt >$destdir/spk2utt
+  utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk
+else
+  if $shortest; then
+    # Select $numutt shortest utterances.
+    . ./path.sh
+    if [ -f $srcdir/utt2num_frames ]; then
+      ln -sf $(utils/make_absolute.sh $srcdir)/utt2num_frames $destdir/tmp.len
+    else
+      feat-to-len scp:$srcdir/feats.scp ark,t:$destdir/tmp.len || exit 1;
+    fi
+    sort -n -k2 $destdir/tmp.len |
+      awk '{print $1}' |
+      head -$numutt >$destdir/tmp.uttlist
+    local/filter_scp.pl $destdir/tmp.uttlist $srcdir/utt2spk >$destdir/utt2spk
+    rm $destdir/tmp.uttlist $destdir/tmp.len
+  else
+    # Select $numutt utterances (evenly spread, or first/last with $first_opt).
+    local/subset_scp.pl $first_opt $numutt $srcdir/utt2spk > $destdir/utt2spk || exit 1;
+  fi
+  local/utt2spk_to_spk2utt.pl < $destdir/utt2spk > $destdir/spk2utt
+fi
+
+# Perform filtering.  The utt2spk and spk2utt files already exist by this point.
+# Filter by utterance.
+[ -f $srcdir/feats.scp ] &&
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/feats.scp >$destdir/feats.scp
+[ -f $srcdir/vad.scp ] &&
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/vad.scp >$destdir/vad.scp
+[ -f $srcdir/utt2lang ] &&
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2lang >$destdir/utt2lang
+[ -f $srcdir/utt2dur ] &&
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2dur >$destdir/utt2dur
+[ -f $srcdir/utt2num_frames ] &&
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2num_frames >$destdir/utt2num_frames
+[ -f $srcdir/utt2uniq ] &&
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2uniq >$destdir/utt2uniq
+[ -f $srcdir/wav.scp ] &&
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/wav.scp >$destdir/wav.scp
+[ -f $srcdir/utt2warp ] &&
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2warp >$destdir/utt2warp
+[ -f $srcdir/text ] &&
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/text >$destdir/text
+
+# Filter by speaker.
+[ -f $srcdir/spk2warp ] &&
+  local/filter_scp.pl $destdir/spk2utt <$srcdir/spk2warp >$destdir/spk2warp
+[ -f $srcdir/spk2gender ] &&
+  local/filter_scp.pl $destdir/spk2utt <$srcdir/spk2gender >$destdir/spk2gender
+[ -f $srcdir/cmvn.scp ] &&
+  local/filter_scp.pl $destdir/spk2utt <$srcdir/cmvn.scp >$destdir/cmvn.scp
+
+# Filter by recording-id.
+if [ -f $srcdir/segments ]; then
+  local/filter_scp.pl $destdir/utt2spk <$srcdir/segments >$destdir/segments
+  # Recording-ids are in segments.
+  awk '{print $2}' $destdir/segments | sort | uniq >$destdir/reco
+  # The next line overrides the command above for wav.scp, which would be incorrect.
+  [ -f $srcdir/wav.scp ] &&
+    local/filter_scp.pl $destdir/reco <$srcdir/wav.scp >$destdir/wav.scp
+else
+  # No segments; recording-ids are in wav.scp.
+  awk '{print $1}' $destdir/wav.scp | sort | uniq >$destdir/reco
+fi
+
+[ -f $srcdir/reco2file_and_channel ] &&
+  local/filter_scp.pl $destdir/reco <$srcdir/reco2file_and_channel >$destdir/reco2file_and_channel
+[ -f $srcdir/reco2dur ] &&
+  local/filter_scp.pl $destdir/reco <$srcdir/reco2dur >$destdir/reco2dur
+
+# Filter the STM file for proper sclite scoring.
+# Copy over the comments from STM file.
+[ -f $srcdir/stm ] &&
+  (grep "^;;" $srcdir/stm
+   local/filter_scp.pl $destdir/reco $srcdir/stm) >$destdir/stm
+
+rm $destdir/reco
+
+# Copy frame_shift if present.
+[ -f $srcdir/frame_shift ] && cp $srcdir/frame_shift $destdir
+
+srcutts=$(wc -l <$srcdir/utt2spk)
+destutts=$(wc -l <$destdir/utt2spk)
+echo "$0: reducing #utt from $srcutts to $destutts"
+exit 0
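This is the script prepare.sh uses below to carve out a held-out dev set: the
first 1000 utterances become dev and the remaining ones become train (the
paths here are illustrative):

    ./local/subset_data_dir.sh --first download/bn-en/train 1000 download/bn-en/dev
    total=$(wc -l <download/bn-en/train/text)
    ./local/subset_data_dir.sh --last download/bn-en/train $(expr $total - 1000) download/bn-en/train_reduced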
egs/mucs/ASR/local/subset_scp.pl (new executable file, 105 lines)
@@ -0,0 +1,105 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# This program selects a subset of N elements in the scp.
+
+# By default, it selects them evenly from throughout the scp, in order to avoid
+# selecting too many from the same speaker.  It prints them on the standard
+# output.
+# With the option --first, it just selects the N first utterances.
+# With the option --last, it just selects the N last utterances.
+
+# Last modified by JHU & HKUST @2013
+
+
+$quiet = 0;
+$first = 0;
+$last = 0;
+
+if (@ARGV > 0 && $ARGV[0] eq "--quiet") {
+  shift;
+  $quiet = 1;
+}
+if (@ARGV > 0 && $ARGV[0] eq "--first") {
+  shift;
+  $first = 1;
+}
+if (@ARGV > 0 && $ARGV[0] eq "--last") {
+  shift;
+  $last = 1;
+}
+
+if (@ARGV < 2) {
+  die "Usage: subset_scp.pl [--quiet][--first|--last] N in.scp\n" .
+      " --quiet  causes it to not die if N < num lines in scp.\n" .
+      " --first and --last make it equivalent to head or tail.\n" .
+      "See also: filter_scp.pl\n";
+}
+
+$N = shift @ARGV;
+if ($N == 0) {
+  die "First command-line parameter to subset_scp.pl must be an integer, got \"$N\"";
+}
+$inscp = shift @ARGV;
+open(I, "<$inscp") || die "Opening input scp file $inscp";
+
+@F = ();
+while(<I>) {
+  push @F, $_;
+}
+$numlines = @F;
+if ($N > $numlines) {
+  if ($quiet) {
+    $N = $numlines;
+  } else {
+    die "You requested from subset_scp.pl more elements than available: $N > $numlines";
+  }
+}
+
+sub select_n {
+  my ($start, $end, $num_needed) = @_;
+  my $diff = $end - $start;
+  if ($num_needed > $diff) {
+    die "select_n: code error";
+  }
+  if ($diff == 1) {
+    if ($num_needed > 0) {
+      print $F[$start];
+    }
+  } else {
+    my $halfdiff = int($diff / 2);
+    my $halfneeded = int($num_needed / 2);
+    select_n($start, $start + $halfdiff, $halfneeded);
+    select_n($start + $halfdiff, $end, $num_needed - $halfneeded);
+  }
+}
+
+if (!$first && !$last) {
+  if ($N > 0) {
+    select_n(0, $numlines, $N);
+  }
+} else {
+  if ($first) { # --first option: same as head.
+    for ($n = 0; $n < $N; $n++) {
+      print $F[$n];
+    }
+  } else { # --last option: same as tail.
+    for ($n = @F - $N; $n < @F; $n++) {
+      print $F[$n];
+    }
+  }
+}
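A quick sanity check of the default even-selection mode (a sketch; the exact
lines picked follow the halving recursion in select_n above):

    # Build a 10-line scp, then take a 3-line subset spread across it.
    seq 1 10 | awk '{print "utt"$1, "path"$1}' > tmp.scp
    local/subset_scp.pl 3 tmp.scp
    # --first 3 and --last 3 behave like head -3 and tail -3.
    local/subset_scp.pl --first 3 tmp.scp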
egs/mucs/ASR/local/utt2spk_to_spk2utt.pl (new executable file, 38 lines)
@@ -0,0 +1,38 @@
+#!/usr/bin/env perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# converts an utt2spk file to a spk2utt file.
+# Takes input from the stdin or from a file argument;
+# output goes to the standard output.
+
+if (@ARGV > 1) {
+  die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
+}
+
+while(<>){
+  @A = split(" ", $_);
+  @A == 2 || die "Invalid line in utt2spk file: $_";
+  ($u, $s) = @A;
+  if (!$seen_spk{$s}) {
+    $seen_spk{$s} = 1;
+    push @spklist, $s;
+  }
+  push(@{$spk_hash{$s}}, "$u");
+}
+foreach $s (@spklist) {
+  $l = join(' ', @{$spk_hash{$s}});
+  print "$s $l\n";
+}
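Example round trip (speakers are printed in first-seen order):

    printf 'utt1 spkA\nutt2 spkA\nutt3 spkB\n' | local/utt2spk_to_spk2utt.pl
    # prints:
    #   spkA utt1 utt2
    #   spkB utt3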
@@ -6,28 +6,30 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail

 nj=60
-stage=6
+stage=9
 stop_stage=9

 # We assume dl_dir (download dir) contains the following
-# directories and files. If not, they will be downloaded
-# by this script automatically.
+# directories and files. Download them from https://www.openslr.org/resources/104/
 #
 # - $dl_dir/hi-en

 dl_dir=$PWD/download
-espnet_path=/home/wtc7/espnet/egs2/MUCS/asr1/data/hi-en/
 mkdir -p $dl_dir

+raw_data_path="/data/Database/MUCS/"
+dataset="bn-en" # hi-en or bn-en
+datadir="data_"$dataset
+raw_kaldi_files_path=$dl_dir/$dataset/
+
+
 . shared/parse_options.sh || exit 1

 # vocab size for sentence piece models.
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy
 vocab_size=400

 # All files generated by this script are saved in "data".
 # You can safely remove "data" and rerun this script to regenerate it.
-mkdir -p data
+mkdir -p $datadir

 log() {
   # This function is from espnet
@@ -38,43 +40,73 @@ log() {
 log "dl_dir: $dl_dir"

 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
-  log "Stage -1: prepare LM files"
-  mkdir -p $dl_dir/lm
-  if [ ! -e $dl_dir/lm/.done ]; then
-    ./local/prepare_lm_files.py --out-dir=$dl_dir/lm --data-path=$espnet_path --mode="train"
-    touch $dl_dir/lm/.done
-  fi
+  log "Stage -1: prepare data files"
+
+  mkdir -p $dl_dir/$dataset
+  for x in train dev test train_all; do
+    if [ -d "$dl_dir/$dataset/$x" ]; then rm -Rf $dl_dir/$dataset/$x; fi
+  done
+  mkdir -p $dl_dir/$dataset/{train,test,dev}
+
+  # Copy the kaldi-style transcript files, then rewrite wav.scp so that
+  # every entry points to an absolute path under $raw_data_path.
+  cp -r $raw_data_path/$dataset/train/transcripts/* $dl_dir/$dataset/train
+  cp -r $raw_data_path/$dataset/test/transcripts/* $dl_dir/$dataset/test
+
+  for x in train test; do
+    cp $dl_dir/$dataset/$x/wav.scp $dl_dir/$dataset/$x/wav.scp_old
+    cut -d' ' -f1 $dl_dir/$dataset/$x/wav.scp > $dl_dir/$dataset/$x/wav_ids
+    cut -d' ' -f2 $dl_dir/$dataset/$x/wav.scp | awk -v var="$raw_data_path/$dataset/$x/" '{print var$1}' > $dl_dir/$dataset/$x/wav_ids_with_fullpath
+    paste -d' ' $dl_dir/$dataset/$x/wav_ids $dl_dir/$dataset/$x/wav_ids_with_fullpath > $dl_dir/$dataset/$x/wav.scp
+    rm $dl_dir/$dataset/$x/wav_ids
+    rm $dl_dir/$dataset/$x/wav_ids_with_fullpath
+  done
+
+  # Hold out the first 1000 utterances as the dev set, and keep the *last*
+  # $count utterances as train, so that dev and train stay disjoint.
+  ./local/subset_data_dir.sh --first $dl_dir/$dataset/train 1000 $dl_dir/$dataset/dev
+  total=$(wc -l $dl_dir/$dataset/train/text | cut -d' ' -f1)
+  count=$(expr $total - 1000)
+
+  ./local/subset_data_dir.sh --last $dl_dir/$dataset/train $count $dl_dir/$dataset/train_reduced
+  mv $dl_dir/$dataset/train $dl_dir/$dataset/train_all
+  mv $dl_dir/$dataset/train_reduced $dl_dir/$dataset/train
+
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "Stage 0: Download data"
+  log "Stage 0: prepare LM files"
+  mkdir -p $raw_kaldi_files_path/lm
+  if [ ! -e $raw_kaldi_files_path/lm/.done ]; then
+    ./local/prepare_lm_files.py --out-dir=$dl_dir/lm --data-path=$raw_kaldi_files_path --mode="train"
+    touch $raw_kaldi_files_path/lm/.done
+  fi
 fi

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Prepare MUCS manifest"
   # We assume that you have downloaded the MUCS corpus
   # to $dl_dir/
-  mkdir -p data/manifests
-  if [ ! -e data/manifests/.mucs.done ]; then
+  mkdir -p $datadir/manifests
+  if [ ! -e $datadir/manifests/.mucs.done ]; then
     # generate lhotse manifests from kaldi style files
-    ./local/prepare_manifest.py "$espnet_path" $nj data/manifests
+    ./local/prepare_manifest.py "$raw_kaldi_files_path" $nj $datadir/manifests

-    touch data/manifests/.mucs.done
+    touch $datadir/manifests/.mucs.done
   fi
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Compute fbank for mucs"
-  mkdir -p data/fbank
-  if [ ! -e data/fbank/.mucs.done ]; then
-    ./local/compute_fbank_mucs.py
-    touch data/fbank/.mucs.done
+  mkdir -p $datadir/fbank
+  if [ ! -e $datadir/fbank/.mucs.done ]; then
+    ./local/compute_fbank_mucs.py --manifestpath $datadir/manifests/ --fbankpath $datadir/fbank
+    touch $datadir/fbank/.mucs.done
   fi

   # exit

-  if [ ! -e data/fbank/.mucs-validated.done ]; then
-    log "Validating data/fbank for mucs"
+  if [ ! -e $datadir/fbank/.mucs-validated.done ]; then
+    log "Validating $datadir/fbank for mucs"
     parts=(
       train
       test
@@ -82,9 +114,9 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     )
     for part in ${parts[@]}; do
       python3 ./local/validate_manifest.py \
-        data/fbank/mucs_cuts_${part}.jsonl.gz
+        $datadir/fbank/mucs_cuts_${part}.jsonl.gz
     done
-    touch data/fbank/.mucs-validated.done
+    touch $datadir/fbank/.mucs-validated.done
   fi
 fi

@@ -92,7 +124,7 @@ fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
-  lang_dir=data/lang_phone
+  lang_dir=$datadir/lang_phone
   mkdir -p $lang_dir

   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
@@ -124,11 +156,11 @@ fi
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "Stage 6: Prepare BPE based lang"

-  lang_dir=data/lang_bpe_${vocab_size}
+  lang_dir=$datadir/lang_bpe_${vocab_size}
   mkdir -p $lang_dir
   # We reuse words.txt from phone based lexicon
   # so that the two can share G.pt later.
-  cp data/lang_phone/words.txt $lang_dir
+  cp $datadir/lang_phone/words.txt $lang_dir

   if [ ! -f $lang_dir/transcript_words.txt ]; then
     log "Generate data for BPE training"
@@ -172,7 +204,7 @@ fi
 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   log "Stage 7: Train LM from training data"

-  lang_dir=data/lang_bpe_${vocab_size}
+  lang_dir=$datadir/lang_bpe_${vocab_size}

   if [ ! -f $lang_dir/lm_3.arpa ]; then
     ./shared/make_kn_lm.py \
@@ -195,37 +227,31 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
   # We assume you have installed kaldilm; if not, please install
   # it using: pip install kaldilm

-  mkdir -p data/lm
-  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+  mkdir -p $datadir/lm
+  if [ ! -f $datadir/lm/G_3_gram.fst.txt ]; then
     # It is used in building HLG
     python3 -m kaldilm \
-      --read-symbol-table="data/lang_phone/words.txt" \
+      --read-symbol-table="$datadir/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=3 \
-      data/lang_bpe_${vocab_size}/lm_3.arpa > data/lm/G_3_gram.fst.txt
+      $datadir/lang_bpe_${vocab_size}/lm_3.arpa > $datadir/lm/G_3_gram.fst.txt
   fi

-  if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+  if [ ! -f $datadir/lm/G_4_gram.fst.txt ]; then
     # It is used for LM rescoring
     python3 -m kaldilm \
-      --read-symbol-table="data/lang_phone/words.txt" \
+      --read-symbol-table="$datadir/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=4 \
-      data/lang_bpe_${vocab_size}/lm_4.arpa > data/lm/G_4_gram.fst.txt
+      $datadir/lang_bpe_${vocab_size}/lm_4.arpa > $datadir/lm/G_4_gram.fst.txt
   fi

 fi

 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   log "Stage 9: Compile HLG"
   # ./local/compile_hlg.py --lang-dir $datadir/lang_phone

   # Note: if ./local/compile_hlg.py throws OOM,
   # please switch to the following command
   #
   # ./local/compile_hlg_using_openfst.py --lang-dir $datadir/lang_phone

-  lang_dir=data/lang_bpe_${vocab_size}
+  lang_dir=$datadir/lang_bpe_${vocab_size}
   ./local/compile_hlg.py --lang-dir $lang_dir

 fi
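After these stages, everything for one dataset lives under a single
per-dataset directory, so several MUCS subsets can coexist side by side. A
sketch of the expected layout for dataset="bn-en", based on the paths above:

    data_bn-en/
    ├── manifests/      # lhotse manifests (stage 1)
    ├── fbank/          # mucs_cuts_*.jsonl.gz (stage 3)
    ├── lang_phone/     # phone lexicon, words.txt (stage 5)
    ├── lang_bpe_400/   # BPE lang dir, lm_3.arpa, HLG.pt (stages 6-9)
    └── lm/             # G_3_gram.fst.txt, G_4_gram.fst.txt (stage 8)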
@@ -1,17 +1,24 @@
 #!/bin/bash
 export CUDA_VISIBLE_DEVICES="0"

 set -e
+dataset='bn-en'
+datadir=data_"$dataset"
+bpe=400

 ./conformer_ctc/train.py \
   --num-epochs 60 \
   --max-duration 300 \
-  --exp-dir ./conformer_ctc/exp_with_devset_split_bpe400 \
-  --lang-dir data/lang_bpe_400 \
+  --exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
+  --manifest-dir $datadir/fbank \
+  --lang-dir $datadir/lang_bpe_"$bpe" \
   --enable-musan False

 ./conformer_ctc/decode.py \
-  --epoch 59 \
+  --epoch 60 \
   --avg 10 \
-  --exp-dir ./conformer_ctc/exp_with_devset_split_bpe400 \
+  --manifest-dir $datadir/fbank \
+  --exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
   --max-duration 100 \
-  --lang-dir ./data/lang_bpe_400
+  --lang-dir $datadir/lang_bpe_"$bpe"
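To train on the Hindi-English subset instead, only the dataset variable needs
to change, assuming prepare.sh was run with the same value:

    dataset='hi-en'
    datadir=data_"$dataset"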