diff --git a/egs/gigaspeech/ASR/conformer_ctc/pretrained.py b/egs/gigaspeech/ASR/conformer_ctc/pretrained.py
index 4638a14a7..266f5506d 100755
--- a/egs/gigaspeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/pretrained.py
@@ -210,8 +210,7 @@ def get_parser():
         default="output",
         help="""
         Output directory name, we store output hypothesis
-        """
-
+        """,
     )
 
     parser.add_argument(
@@ -220,7 +219,7 @@ def get_parser():
         default=10,
         help="""
         Number of input files in one batch, defaulted to 10
-        """
+        """,
     )
     parser.add_argument(
         "sound_files",
@@ -302,7 +301,7 @@ def decode_one_batch(
     feature: torch.Tensor,
     bpe_model: Optional[spm.SentencePieceProcessor] = None,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[str]]]:
+) -> List[List[str]]:
     device = decoding_graph.device
     assert feature.ndim == 3
 
@@ -404,11 +403,14 @@ def main():
     params.update(vars(args))
 
     su = ShortUUID(alphabet=string.ascii_lowercase + string.digits)
-    output_dir = Path(args.output_dir)/(su.random(length=5) + '-' + datetime.datetime.now().strftime(
-        "%Y-%m-%d_%H-%M"))
+    output_dir = Path(args.output_dir) / (
+        su.random(length=5)
+        + "-"
+        + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
+    )
     output_dir.mkdir(exist_ok=True, parents=True)
-    with (output_dir/"command").open("w") as f_cmd:
+    with (output_dir / "command").open("w") as f_cmd:
         f_cmd.write(f"{args}\n")
         f_cmd.write(f"{params}\n")
     logging.info(f"{params}")
 
@@ -502,23 +504,29 @@ def main():
     else:
         raise ValueError(f"Unsupported decoding method: {params.method}")
 
-    testdata = TestDataset(features)
-    tl = DataLoader(testdata, batch_size=args.batch_size)
+    inputdata = TestDataset(features)
+    tl = DataLoader(inputdata, batch_size=args.batch_size)
     hyps = []
     num_batches = len(tl)
     for batch_idx, batch in enumerate(tl):
-        hyps.extend(decode_one_batch(
-            params, model, decoding_graph, batch[0], bpe_model, G))
+        hyps.extend(
+            decode_one_batch(
+                params, model, decoding_graph, batch[0], bpe_model, G
+            )
+        )
         logging.info(
-            f"batch {batch_idx + 1}/{num_batches}, cuts processed until now is {len(hyps)}")
+            f"batch {batch_idx + 1}/{num_batches}, cuts processed until now is {len(hyps)}"
+        )
 
     logging.info(f"Writing hypothesis to output dir {output_dir}")
     s = "\n"
     for filename, hyp in zip(wave_names, hyps):
         words = " ".join(hyp)
         s += f"{filename}:\n{words}\n\n"
-        with (output_dir/os.path.basename(filename.replace(".wav", ".txt"))).open("w") as f_hyp:
+        with (
+            output_dir / os.path.basename(filename.replace(".wav", ".txt"))
+        ).open("w") as f_hyp:
             f_hyp.write(words)
     logging.info(s)
 
     logging.info("Decoding Done")
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
index 9dd3c046d..bc9d28314 100755
--- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
@@ -59,9 +59,15 @@ def get_parser():
         "--num-splits",
         type=int,
         required=True,
-        help="The number of splits of the XL subset",
+        help="The number of splits of the subset",
     )
 
+    parser.add_argument(
+        "--subset",
+        type=str,
+        default="XL",
+        help="Subset name: XL, L, S, XS",
+    )
     parser.add_argument(
         "--start",
         type=int,
@@ -80,7 +86,7 @@ def get_parser():
 
 
 def compute_fbank_gigaspeech_splits(args):
     num_splits = args.num_splits
-    output_dir = "data/fbank/XL_split"
+    output_dir = f"data/fbank/{args.subset}_split"
     output_dir = Path(output_dir)
     assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -103,12 +109,12 @@ def compute_fbank_gigaspeech_splits(args):
         idx = f"{i + 1}".zfill(num_digits)
         logging.info(f"Processing {idx}/{num_splits}")
 
-        cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
+        cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue
 
-        raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
+        raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz"
         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
 
@@ -117,7 +123,7 @@ def compute_fbank_gigaspeech_splits(args):
 
         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{output_dir}/feats_XL_{idx}",
+            storage_path=f"{output_dir}/feats_{args.subset}_{idx}",
             num_workers=args.num_workers,
             batch_duration=args.batch_duration,
         )
diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
index 0cec82ad5..c227f0d49 100755
--- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
+++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
@@ -47,11 +47,7 @@ def preprocess_giga_speech():
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)
 
-    dataset_parts = (
-        "DEV",
-        "TEST",
-        "XL",
-    )
+    dataset_parts = ("S",)
 
     logging.info("Loading manifest (may take 4 minutes)")
     manifests = read_manifests_if_cached(
@@ -86,16 +82,16 @@ def preprocess_giga_speech():
         )
         # Run data augmentation that needs to be done in the
         # time domain.
-        if partition not in ["DEV", "TEST"]:
-            logging.info(
-                f"Speed perturb for {partition} with factors 0.9 and 1.1 "
-                "(Perturbing may take 8 minutes and saving may take 20 minutes)"
-            )
-            cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
-            )
+        # if partition not in ["DEV", "TEST"]:
+        #     logging.info(
+        #         f"Speed perturb for {partition} with factors 0.9 and 1.1 "
+        #         "(Perturbing may take 8 minutes and saving may take 20 minutes)"
+        #     )
+        #     cut_set = (
+        #         cut_set
+        #         + cut_set.perturb_speed(0.9)
+        #         + cut_set.perturb_speed(1.1)
+        #     )
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
index ce5116336..1edf86cb6 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -503,34 +503,37 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.iter > 0:
-        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-            : params.avg
-        ]
-        if len(filenames) == 0:
-            raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        elif len(filenames) < params.avg:
-            raise ValueError(
-                f"Not enough checkpoints ({len(filenames)}) found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
-    elif params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-    else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
load_checkpoint(f"{params.exp_dir}/pretrained-iter-3488000-avg-20.pt", model) + + # we don't average on models + # if params.iter > 0: + # filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + # : params.avg + # ] + # if len(filenames) == 0: + # raise ValueError( + # f"No checkpoints found for" + # f" --iter {params.iter}, --avg {params.avg}" + # ) + # elif len(filenames) < params.avg: + # raise ValueError( + # f"Not enough checkpoints ({len(filenames)}) found for" + # f" --iter {params.iter}, --avg {params.avg}" + # ) + # logging.info(f"averaging {filenames}") + # model.to(device) + # model.load_state_dict(average_checkpoints(filenames, device=device)) + # elif params.avg == 1: + # load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + # else: + # start = params.epoch - params.avg + 1 + # filenames = [] + # for i in range(start, params.epoch + 1): + # if start >= 0: + # filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + # logging.info(f"averaging {filenames}") + # model.to(device) + # model.load_state_dict(average_checkpoints(filenames, device=device)) model.to(device) model.eval() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/pretrained.py new file mode 100755 index 000000000..4e3587177 --- /dev/null +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) greedy search +./pruned_transducer_stateless2/pretrained.py \ + --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless2/pretrained.py \ + --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless2/pretrained.py \ + --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless2/pretrained.py \ + --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`. 
+
+Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by
+./pruned_transducer_stateless2/export.py
+"""
+
+import argparse
+import datetime
+import logging
+import math
+import os
+import string
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader
+
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model
+from shortuuid import ShortUUID
+from icefall.utils import AttributeDict, setup_logger
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) or directories to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="birch/output",
+        help="Output directory name",
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=10,
+        help="Number of input files in one batch",
+    )
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> Tuple[List[torch.Tensor], List[str]]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames or directories containing sound files.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors and a list of the
+      corresponding file names.
+ """ + ans = [] + wave_names = [] + + def loadfile(filename): + wave, sample_rate = torchaudio.load(filename) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + wave_names.append(str(filename)) + + for f in filenames: + file_path = Path(f) + if file_path.is_file(): + loadfile(file_path) + elif file_path.is_dir(): + for filename in file_path.iterdir(): + loadfile(filename) + else: + logging.error(f"{f} must be a filename or a dirname") + return ans, wave_names + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + features: torch.tensor, + sp: spm.SentencePieceProcessor, +) -> List[List[str]]: + + device = features.device + feature_lengths = [f.size(0) for f in features] + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.debug(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + return hyps + + +class TestDataset(Dataset): + def __init__(self, features: torch.Tensor): + self.features = features + + def __len__(self): + return len(self.features) + + def __getitem__(self, idx): + return (self.features[idx], 0) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + su = ShortUUID(alphabet=string.ascii_lowercase + string.digits) + + params.suffix = f"-pruned-transducer-stateless2-{params.method}" + if "fast_beam_search" in params.method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += 
f"-max-states-{params.max_states}" + elif "beam_search" in params.method: + params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + output_dir = Path(params.output) / ( + su.random(length=5) + + "-" + + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M") + + params.suffix + ) + output_dir.mkdir(exist_ok=True, parents=True) + setup_logger(f"{output_dir}/log-decode") + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(params.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves, wavnames = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + inputdata = TestDataset(features) + tl = DataLoader(inputdata, batch_size=params.batch_size) + + num_batches = len(tl) + hyps = [] + for batch_idx, batch in enumerate(tl): + hyps.extend(decode_one_batch(params, model, batch[0], sp)) + logging.info( + f"batch {batch_idx + 1}/{num_batches}, cuts processed until now is {len(hyps)}" + ) + + s = "\n" + assert len(wavnames) == len(hyps) + for filename, hyp in zip(wavnames, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + with ( + output_dir / os.path.basename(filename.replace(".wav", ".txt")) + ).open("w") as fhyp: + fhyp.write(words) + + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main()