change data -> data_{dataset}

sathvik udupa 2023-05-02 14:02:23 +05:30
parent 99487d3000
commit 92c3f61f21
10 changed files with 696 additions and 62 deletions

View File

@@ -52,12 +52,11 @@ class _SeedWorkers:
fix_random_seed(self.seed + worker_id)
class LibriSpeechAsrDataModule:
class MUCSAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
DataModule for k2 ASR experiments.
It assumes there is always one train dataloader, one valid dataloader, and one test dataloader.
This is modified from the LibriSpeech asr_datamodule.
It contains all the common data pipeline modules used in ASR
experiments, e.g.:

View File

@@ -38,7 +38,7 @@ import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import MUCSAsrDataModule
from conformer import Conformer
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
@@ -687,7 +687,7 @@ def run(rank, world_size, args):
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
librispeech = MUCSAsrDataModule(args)
# params.full_libri = False
# if params.full_libri:
# train_cuts = librispeech.train_all_shuf_cuts()
@@ -800,7 +800,7 @@ def scan_pessimistic_batches_for_oom(
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
MUCSAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)

View File

@@ -1 +0,0 @@
../../../librispeech/ASR/local/compile_hlg.py

egs/mucs/ASR/local/compile_hlg.py Executable file
View File

@@ -0,0 +1,167 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from <datadir>/lm/G_n_gram.fst.txt, where <datadir> is the first path component of lang_dir (e.g. data_bn-en)
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data_bn-en/lang_phone or data_bn-en/lang_bpe_400.
lm:
The LM stem name, e.g., G_3_gram.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
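# lang_dir looks like "data_bn-en/lang_bpe_400"; its first path component is the
# dataset-specific data directory, which also holds the lm/ subdirectory.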
datapath = str(lang_dir).split('/')[0]
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path(f"{datapath}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{datapath}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{datapath}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{datapath}/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set LG.properties to None
LG.__dict__["_properties"] = None
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

egs/mucs/ASR/local/compute_fbank_mucs.py Normal file → Executable file
View File

@@ -59,6 +59,16 @@ def get_args():
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
parser.add_argument(
"--manifestpath",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
parser.add_argument(
"--fbankpath",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
return parser.parse_args()
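# prepare.sh (below) invokes this script as:
#   ./local/compute_fbank_mucs.py --manifestpath $datadir/manifests/ --fbankpath $datadir/fbank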
@@ -67,8 +77,8 @@ def compute_fbank_mucs(
bpe_model: Optional[str] = None,
dataset: Optional[str] = None,
):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
src_dir = Path(args.manifestpath)
output_dir = Path(args.fbankpath)
num_jobs = min(48, os.cpu_count())
num_mel_bins = 80

View File

@@ -0,0 +1,87 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
# Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch
$exclude = 0;
$field = 1;
$shifted = 0;
do {
$shifted=0;
if ($ARGV[0] eq "--exclude") {
$exclude = 1;
shift @ARGV;
$shifted=1;
}
if ($ARGV[0] eq "-f") {
$field = $ARGV[1];
shift @ARGV; shift @ARGV;
$shifted=1
}
} while ($shifted);
if(@ARGV < 1 || @ARGV > 2) {
die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
"Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
"Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
"only the lines that were *not* in id_list.\n" .
"Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
"If your older scripts (written before Oct 2014) stopped working and you used the\n" .
"-f option, add 1 to the argument.\n" .
"See also: utils/filter_scp.pl .\n";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
@A = split;
@A>=1 || die "Invalid id-list file line $_";
$seen{$A[0]} = 1;
}
if ($field == 1) { # Treat this as special case, since it is common.
while(<>) {
$_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
# $1 is what we filter on.
if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
print $_;
}
}
} else {
while(<>) {
@A = split;
@A > 0 || die "Invalid scp file line $_";
@A >= $field || die "Invalid scp file line $_";
if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
print $_;
}
}
}
# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
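For readers less familiar with Perl, the core filtering logic translates to Python roughly as follows (a sketch, not part of the commit; the function name is illustrative):

import sys

def filter_scp(id_list_path, scp_lines, field=1, exclude=False):
    # Collect the first token of each line in the id list.
    seen = set()
    with open(id_list_path) as f:
        for line in f:
            toks = line.split()
            if toks:
                seen.add(toks[0])
    # Print lines whose field-th token is (or, with exclude=True, is not) in the id list.
    for line in scp_lines:
        toks = line.split()
        if len(toks) < field:
            raise ValueError(f"Invalid scp line: {line!r}")
        if (toks[field - 1] in seen) != exclude:
            sys.stdout.write(line)

# filter_scp("dev/utt2spk", open("train/text")) mirrors:
#   local/filter_scp.pl dev/utt2spk <train/text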

View File

@@ -0,0 +1,196 @@
#!/usr/bin/env bash
# Copyright 2010-2011 Microsoft Corporation
# 2012-2013 Johns Hopkins University (Author: Daniel Povey)
# Apache 2.0
# This script operates on a data directory, such as in data/train/.
# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data
# for what these directories contain.
# This script creates a subset of that data, consisting of some specified
# number of utterances. (The selected utterances are distributed evenly
# throughout the file, by the program ./subset_scp.pl).
# The options below are mutually exclusive; use at most one of them.
# If you give the --per-spk option, it will attempt to select the supplied
# number of utterances for each speaker (typically you would supply a much
# smaller number in this case).
# If you give the --speakers option, it selects a subset of n randomly
# selected speakers.
# If you give the --shortest option, it will give you the n shortest utterances.
# If you give the --first option, it will just give you the n first utterances.
# If you give the --last option, it will just give you the n last utterances.
# If you give the --spk-list or --utt-list option, it reads the
# speakers/utterances to keep from <speaker-list-file>/<utt-list-file> (note,
# in this case there is no <num-utt> positional parameter; see usage message.)
shortest=false
perspk=false
speakers=false
first_opt=
spk_list=
utt_list=
expect_args=3
case $1 in
--first|--last) first_opt=$1; shift ;;
--per-spk) perspk=true; shift ;;
--shortest) shortest=true; shift ;;
--speakers) speakers=true; shift ;;
--spk-list) shift; spk_list=$1; shift; expect_args=2 ;;
--utt-list) shift; utt_list=$1; shift; expect_args=2 ;;
--*) echo "$0: invalid option '$1'"; exit 1 ;;
esac
if [ $# != $expect_args ]; then
echo "Usage:"
echo " subset_data_dir.sh [--speakers|--shortest|--first|--last|--per-spk] <srcdir> <num-utt> <destdir>"
echo " subset_data_dir.sh [--spk-list <speaker-list-file>] <srcdir> <destdir>"
echo " subset_data_dir.sh [--utt-list <utt-list-file>] <srcdir> <destdir>"
echo "By default, randomly selects <num-utt> utterances from the data directory."
echo "With --speakers, randomly selects enough speakers that we have <num-utt> utterances"
echo "With --per-spk, selects <num-utt> utterances per speaker, if available."
echo "With --first, selects the first <num-utt> utterances"
echo "With --last, selects the last <num-utt> utterances"
echo "With --shortest, selects the shortest <num-utt> utterances."
echo "With --spk-list, reads the speakers to keep from <speaker-list-file>"
echo "With --utt-list, reads the utterances to keep from <utt-list-file>"
exit 1;
fi
srcdir=$1
if [[ $spk_list || $utt_list ]]; then
numutt=
destdir=$2
else
numutt=$2
destdir=$3
fi
export LC_ALL=C
if [ ! -f $srcdir/utt2spk ]; then
echo "$0: no such file $srcdir/utt2spk"
exit 1
fi
if [[ $numutt && $numutt -gt $(wc -l <$srcdir/utt2spk) ]]; then
echo "$0: cannot subset to more utterances than you originally had."
exit 1
fi
if $shortest && [ ! -f $srcdir/feats.scp ]; then
echo "$0: you selected --shortest but no feats.scp exist."
exit 1
fi
mkdir -p $destdir || exit 1
if [[ $spk_list ]]; then
local/filter_scp.pl "$spk_list" $srcdir/spk2utt > $destdir/spk2utt || exit 1;
utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk || exit 1;
elif [[ $utt_list ]]; then
local/filter_scp.pl "$utt_list" $srcdir/utt2spk > $destdir/utt2spk || exit 1;
local/utt2spk_to_spk2utt.pl < $destdir/utt2spk > $destdir/spk2utt || exit 1;
elif $speakers; then
utils/shuffle_list.pl < $srcdir/spk2utt |
awk -v numutt=$numutt '{ if (tot < numutt){ print; } tot += (NF-1); }' |
sort > $destdir/spk2utt
utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk
elif $perspk; then
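# For each speaker, print the speaker-id followed by $numutt of its utterances,
# chosen evenly spaced across that speaker's utterance list.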
awk '{ n='$numutt'; printf("%s ",$1);
skip=1; while(n*(skip+1) <= NF-1) { skip++; }
for(x=2; x<=NF && x <= (n*skip+1); x += skip) { printf("%s ", $x); }
printf("\n"); }' <$srcdir/spk2utt >$destdir/spk2utt
utils/spk2utt_to_utt2spk.pl < $destdir/spk2utt > $destdir/utt2spk
else
if $shortest; then
# Select $numutt shortest utterances.
. ./path.sh
if [ -f $srcdir/utt2num_frames ]; then
ln -sf $(utils/make_absolute.sh $srcdir)/utt2num_frames $destdir/tmp.len
else
feat-to-len scp:$srcdir/feats.scp ark,t:$destdir/tmp.len || exit 1;
fi
sort -n -k2 $destdir/tmp.len |
awk '{print $1}' |
head -$numutt >$destdir/tmp.uttlist
local/filter_scp.pl $destdir/tmp.uttlist $srcdir/utt2spk >$destdir/utt2spk
rm $destdir/tmp.uttlist $destdir/tmp.len
else
# Select $numutt random utterances.
local/subset_scp.pl $first_opt $numutt $srcdir/utt2spk > $destdir/utt2spk || exit 1;
fi
local/utt2spk_to_spk2utt.pl < $destdir/utt2spk > $destdir/spk2utt
fi
# Perform filtering. utt2spk and spk2utt files already exist by this point.
# Filter by utterance.
[ -f $srcdir/feats.scp ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/feats.scp >$destdir/feats.scp
[ -f $srcdir/vad.scp ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/vad.scp >$destdir/vad.scp
[ -f $srcdir/utt2lang ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2lang >$destdir/utt2lang
[ -f $srcdir/utt2dur ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2dur >$destdir/utt2dur
[ -f $srcdir/utt2num_frames ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2num_frames >$destdir/utt2num_frames
[ -f $srcdir/utt2uniq ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2uniq >$destdir/utt2uniq
[ -f $srcdir/wav.scp ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/wav.scp >$destdir/wav.scp
[ -f $srcdir/utt2warp ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/utt2warp >$destdir/utt2warp
[ -f $srcdir/text ] &&
local/filter_scp.pl $destdir/utt2spk <$srcdir/text >$destdir/text
# Filter by speaker.
[ -f $srcdir/spk2warp ] &&
local/filter_scp.pl $destdir/spk2utt <$srcdir/spk2warp >$destdir/spk2warp
[ -f $srcdir/spk2gender ] &&
local/filter_scp.pl $destdir/spk2utt <$srcdir/spk2gender >$destdir/spk2gender
[ -f $srcdir/cmvn.scp ] &&
local/filter_scp.pl $destdir/spk2utt <$srcdir/cmvn.scp >$destdir/cmvn.scp
# Filter by recording-id.
if [ -f $srcdir/segments ]; then
local/filter_scp.pl $destdir/utt2spk <$srcdir/segments >$destdir/segments
# Recording-ids are in segments.
awk '{print $2}' $destdir/segments | sort | uniq >$destdir/reco
# The next line overrides the wav.scp filtering done above, which would be incorrect
# here: when a segments file exists, wav.scp is indexed by recording-id, not utterance-id.
[ -f $srcdir/wav.scp ] &&
local/filter_scp.pl $destdir/reco <$srcdir/wav.scp >$destdir/wav.scp
else
# No segments; recording-ids are in wav.scp.
awk '{print $1}' $destdir/wav.scp | sort | uniq >$destdir/reco
fi
[ -f $srcdir/reco2file_and_channel ] &&
local/filter_scp.pl $destdir/reco <$srcdir/reco2file_and_channel >$destdir/reco2file_and_channel
[ -f $srcdir/reco2dur ] &&
local/filter_scp.pl $destdir/reco <$srcdir/reco2dur >$destdir/reco2dur
# Filter the STM file for proper sclite scoring.
# Copy over the comments from STM file.
[ -f $srcdir/stm ] &&
(grep "^;;" $srcdir/stm
local/filter_scp.pl $destdir/reco $srcdir/stm) >$destdir/stm
rm $destdir/reco
# Copy frame_shift if present.
[ -f $srcdir/frame_shift ] && cp $srcdir/frame_shift $destdir
srcutts=$(wc -l <$srcdir/utt2spk)
destutts=$(wc -l <$destdir/utt2spk)
echo "$0: reducing #utt from $srcutts to $destutts"
exit 0
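For orientation, the --first branch of the above boils down to the following Python (a sketch, not part of the commit; it mirrors only the per-utterance filtering, while the real script also rebuilds spk2utt and filters speaker- and recording-level files):

import os

def subset_first(srcdir, destdir, num_utt):
    # Keep the first num_utt lines of utt2spk, like `subset_data_dir.sh --first`.
    os.makedirs(destdir, exist_ok=True)
    with open(os.path.join(srcdir, "utt2spk")) as f:
        kept = [line for _, line in zip(range(num_utt), f)]
    with open(os.path.join(destdir, "utt2spk"), "w") as f:
        f.writelines(kept)
    keep_ids = {line.split()[0] for line in kept}
    # Filter each per-utterance file down to the kept utterance-ids.
    for name in ("text", "wav.scp", "feats.scp", "utt2dur"):
        src = os.path.join(srcdir, name)
        if not os.path.exists(src):
            continue
        with open(src) as fin, open(os.path.join(destdir, name), "w") as fout:
            for line in fin:
                toks = line.split()
                if toks and toks[0] in keep_ids:
                    fout.write(line)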

egs/mucs/ASR/local/subset_scp.pl Executable file
View File

@@ -0,0 +1,105 @@
#!/usr/bin/env perl
use warnings; #sed replacement for -w perl parameter
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program selects a subset of N elements in the scp.
# By default, it selects them evenly from throughout the scp, in order to avoid
# selecting too many from the same speaker. It prints them on the standard
# output.
# With the option --first, it just selects the N first utterances.
# With the option --last, it just selects the N last utterances.
# Last modified by JHU & HKUST @2013
$quiet = 0;
$first = 0;
$last = 0;
if (@ARGV > 0 && $ARGV[0] eq "--quiet") {
shift;
$quiet = 1;
}
if (@ARGV > 0 && $ARGV[0] eq "--first") {
shift;
$first = 1;
}
if (@ARGV > 0 && $ARGV[0] eq "--last") {
shift;
$last = 1;
}
if(@ARGV < 2 ) {
die "Usage: subset_scp.pl [--quiet][--first|--last] N in.scp\n" .
" --quiet causes it to not die if N < num lines in scp.\n" .
" --first and --last make it equivalent to head or tail.\n" .
"See also: filter_scp.pl\n";
}
$N = shift @ARGV;
if($N == 0) {
die "First command-line parameter to subset_scp.pl must be an integer, got \"$N\"";
}
$inscp = shift @ARGV;
open(I, "<$inscp") || die "Opening input scp file $inscp";
@F = ();
while(<I>) {
push @F, $_;
}
$numlines = @F;
if($N > $numlines) {
if ($quiet) {
$N = $numlines;
} else {
die "You requested from subset_scp.pl more elements than available: $N > $numlines";
}
}
sub select_n {
my ($start,$end,$num_needed) = @_;
my $diff = $end - $start;
if ($num_needed > $diff) {
die "select_n: code error";
}
if ($diff == 1 ) {
if ($num_needed > 0) {
print $F[$start];
}
} else {
my $halfdiff = int($diff/2);
my $halfneeded = int($num_needed/2);
select_n($start, $start+$halfdiff, $halfneeded);
select_n($start+$halfdiff, $end, $num_needed - $halfneeded);
}
}
if ( ! $first && ! $last) {
if ($N > 0) {
select_n(0, $numlines, $N);
}
} else {
if ($first) { # --first option: same as head.
for ($n = 0; $n < $N; $n++) {
print $F[$n];
}
} else { # --last option: same as tail.
for ($n = @F - $N; $n < @F; $n++) {
print $F[$n];
}
}
}
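The core of this script is select_n, which picks N lines spread evenly through the input by recursive halving. The same algorithm in Python (a sketch, not part of the commit):

def select_evenly(lines, n):
    # Return n items spread evenly across lines, mirroring select_n above.
    def rec(start, end, needed):
        diff = end - start
        if needed > diff:
            raise ValueError("select_evenly: internal error")
        if diff == 1:
            return lines[start:end] if needed > 0 else []
        half = diff // 2
        half_needed = needed // 2
        return rec(start, start + half, half_needed) + rec(start + half, end, needed - half_needed)
    return rec(0, len(lines), n) if n > 0 else []

# e.g. select_evenly(list("abcde"), 2) returns ['b', 'e']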

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# converts an utt2spk file to a spk2utt file.
# Takes input from the stdin or from a file argument;
# output goes to the standard out.
if ( @ARGV > 1 ) {
die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
}
while(<>){
@A = split(" ", $_);
@A == 2 || die "Invalid line in utt2spk file: $_";
($u,$s) = @A;
if(!$seen_spk{$s}) {
$seen_spk{$s} = 1;
push @spklist, $s;
}
push (@{$spk_hash{$s}}, "$u");
}
foreach $s (@spklist) {
$l = join(' ',@{$spk_hash{$s}});
print "$s $l\n";
}
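The same conversion in Python (a sketch, not part of the commit): group utterance-ids by speaker, preserving the order in which speakers first appear:

import sys

spk2utt = {}  # dicts preserve insertion order in Python 3.7+
for line in sys.stdin:
    utt, spk = line.split()  # raises, like the Perl die, if a line has != 2 fields
    spk2utt.setdefault(spk, []).append(utt)
for spk, utts in spk2utt.items():
    print(spk, " ".join(utts))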

View File

@@ -6,28 +6,30 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=60
stage=6
stage=9
stop_stage=9
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
# directories and files. Download them from https://www.openslr.org/resources/104/
#
# - $dl_dir/hi-en
dl_dir=$PWD/download
espnet_path=/home/wtc7/espnet/egs2/MUCS/asr1/data/hi-en/
mkdir -p $dl_dir
raw_data_path="/data/Database/MUCS/"
dataset="bn-en" #hin-en or bn-en
datadir="data_"$dataset
raw_kaldi_files_path=$dl_dir/$dataset/
. shared/parse_options.sh || exit 1
# vocab size for sentence piece models.
# It will generate $datadir/lang_bpe_xxx,
# $datadir/lang_bpe_yyy
vocab_size=400
# All files generated by this script are saved in "$datadir".
# You can safely remove "$datadir" and rerun this script to regenerate it.
mkdir -p data
mkdir -p $datadir
log() {
# This function is from espnet
@@ -38,43 +40,73 @@ log() {
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: prepare LM files"
mkdir -p $dl_dir/lm
if [ ! -e $dl_dir/lm/.done ]; then
./local/prepare_lm_files.py --out-dir=$dl_dir/lm --data-path=$espnet_path --mode="train"
touch $dl_dir/lm/.done
fi
log "Stage -1: prepare data files"
mkdir -p $dl_dir/$dataset
for x in train dev test train_all; do
if [ -d "$dl_dir/$dataset/$x" ]; then rm -Rf $dl_dir/$dataset/$x; fi
done
mkdir -p $dl_dir/$dataset/{train,test,dev}
cp -r $raw_data_path/$dataset/"train"/"transcripts"/* $dl_dir/$dataset/"train"
cp -r $raw_data_path/$dataset/"test"/"transcripts"/* $dl_dir/$dataset/"test"
for x in train test
do
cp $dl_dir/$dataset/$x/"wav.scp" $dl_dir/$dataset/$x/"wav.scp_old"
cat $dl_dir/$dataset/$x/"wav.scp" | cut -d' ' -f1 > $dl_dir/$dataset/$x/wav_ids
cat $dl_dir/$dataset/$x/"wav.scp" | cut -d' ' -f2 | awk -v var="$raw_data_path/$dataset/$x/" '{print var$1}' > $dl_dir/$dataset/$x/wav_ids_with_fullpath
paste -d' ' $dl_dir/$dataset/$x/wav_ids $dl_dir/$dataset/$x/wav_ids_with_fullpath > $dl_dir/$dataset/$x/"wav.scp"
rm $dl_dir/$dataset/$x/wav_ids
rm $dl_dir/$dataset/$x/wav_ids_with_fullpath
done
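# Hold out the first 1000 utterances as the dev set and keep the rest as the new train set;
# the untouched original is preserved as train_all.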
./local/subset_data_dir.sh --first $dl_dir/$dataset/"train" 1000 $dl_dir/$dataset/"dev"
total=$(wc -l $dl_dir/$dataset/"train"/"text" | cut -d' ' -f1)
count=$(expr $total - 1000)
./local/subset_data_dir.sh --last $dl_dir/$dataset/"train" $count $dl_dir/$dataset/"train_reduced"
mv $dl_dir/$dataset/"train" $dl_dir/$dataset/"train_all"
mv $dl_dir/$dataset/"train_reduced" $dl_dir/$dataset/"train"
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
log "Stage 0: prepare LM files"
mkdir -p $raw_kaldi_files_path/lm
if [ ! -e $raw_kaldi_files_path/lm/.done ]; then
./local/prepare_lm_files.py --out-dir=$dl_dir/lm --data-path=$raw_kaldi_files_path --mode="train"
touch $raw_kaldi_files_path/lm/.done
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare MUCS manifest"
# We assume that you have downloaded the MUCS corpus
# to $dl_dir/
mkdir -p data/manifests
if [ ! -e data/manifests/.mucs.done ]; then
mkdir -p $datadir/manifests
if [ ! -e $datadir/manifests/.mucs.done ]; then
# generate lhotse manifests from kaldi style files
./local/prepare_manifest.py "$espnet_path" $nj data/manifests
./local/prepare_manifest.py "$raw_kaldi_files_path" $nj $datadir/manifests
touch data/manifests/.mucs.done
touch $datadir/manifests/.mucs.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for mucs"
mkdir -p data/fbank
if [ ! -e data/fbank/.mucs.done ]; then
./local/compute_fbank_mucs.py
touch data/fbank/.mucs.done
mkdir -p $datadir/fbank
if [ ! -e $datadir/fbank/.mucs.done ]; then
./local/compute_fbank_mucs.py --manifestpath $datadir/manifests/ --fbankpath $datadir/fbank
touch $datadir/fbank/.mucs.done
fi
# exit
if [ ! -e data/fbank/.mucs-validated.done ]; then
log "Validating data/fbank for mucs"
if [ ! -e $datadir/fbank/.mucs-validated.done ]; then
log "Validating $datadir/fbank for mucs"
parts=(
train
test
@@ -82,9 +114,9 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
)
for part in ${parts[@]}; do
python3 ./local/validate_manifest.py \
data/fbank/mucs_cuts_${part}.jsonl.gz
$datadir/fbank/mucs_cuts_${part}.jsonl.gz
done
touch data/fbank/.mucs-validated.done
touch $datadir/fbank/.mucs-validated.done
fi
fi
@@ -92,7 +124,7 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
lang_dir=data/lang_phone
lang_dir=$datadir/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
@@ -124,11 +156,11 @@ fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare BPE based lang"
lang_dir=data/lang_bpe_${vocab_size}
lang_dir=$datadir/lang_bpe_${vocab_size}
mkdir -p $lang_dir
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp data/lang_phone/words.txt $lang_dir
cp $datadir/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
@@ -172,7 +204,7 @@ fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Train LM from training data"
lang_dir=data/lang_bpe_${vocab_size}
lang_dir=$datadir/lang_bpe_${vocab_size}
if [ ! -f $lang_dir/lm_3.arpa ]; then
./shared/make_kn_lm.py \
@@ -195,37 +227,31 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
# We assume you have installed kaldilm. If not, please install
# it using: pip install kaldilm
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
mkdir -p $datadir/lm
if [ ! -f $datadir/lm/G_3_gram.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--read-symbol-table="$datadir/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
data/lang_bpe_${vocab_size}/lm_3.arpa > data/lm/G_3_gram.fst.txt
$datadir/lang_bpe_${vocab_size}/lm_3.arpa > $datadir/lm/G_3_gram.fst.txt
fi
if [ ! -f data/lm/G_4_gram.fst.txt ]; then
if [ ! -f $datadir/lm/G_4_gram.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="data/lang_phone/words.txt" \
--read-symbol-table="$datadir/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
data/lang_bpe_${vocab_size}/lm_4.arpa > data/lm/G_4_gram.fst.txt
$datadir/lang_bpe_${vocab_size}/lm_4.arpa > $datadir/lm/G_4_gram.fst.txt
fi
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compile HLG"
# ./local/compile_hlg.py --lang-dir data/lang_phone
# Note: If ./local/compile_hlg.py throws OOM,
# please switch to the following command
#
# ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
lang_dir=data/lang_bpe_${vocab_size}
lang_dir=$datadir/lang_bpe_${vocab_size}
./local/compile_hlg.py --lang-dir $lang_dir
fi

View File

@@ -1,17 +1,24 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES="0"
set -e
dataset='bn-en'
datadir=data_"$dataset"
bpe=400
./conformer_ctc/train.py \
--num-epochs 60 \
--max-duration 300 \
--exp-dir ./conformer_ctc/exp_with_devset_split_bpe400 \
--lang-dir data/lang_bpe_400 \
--exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
--manifest-dir $datadir/fbank \
--lang-dir $datadir/lang_bpe_"$bpe" \
--enable-musan False

./conformer_ctc/decode.py \
--epoch 59 \
--epoch 60 \
--avg 10 \
--exp-dir ./conformer_ctc/exp_with_devset_split_bpe400 \
--manifest-dir $datadir/fbank \
--exp-dir ./conformer_ctc/exp_"$dataset"_bpe"$bpe" \
--max-duration 100 \
--lang-dir ./data/lang_bpe_400
--lang-dir $datadir/lang_bpe_"$bpe"