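"""
Transcribe long audio recordings chunk by chunk with a prompt-aware
transducer model. Each chunk can optionally be conditioned on a text
prompt (pre-text) built from the decoding history of the previous chunks.

Usage sketch (the script file name is illustrative; all flags are defined
in get_parser() below):

    python ./transcribe_long_audio.py \
        --epoch 30 \
        --avg 9 \
        --exp-dir pruned_transducer_stateless7/exp \
        --manifest-dir data/long_audios/long_audio_pomonastravels_combined.jsonl.gz \
        --use-pre-text True \
        --pre-text-transform mixed-punc \
        --num-history 2
"""
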
import argparse
import logging
import math
import warnings
from pathlib import Path
from typing import List
from tqdm import tqdm
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from lhotse import load_manifest
from beam_search import (
    beam_search,
    fast_beam_search_one_best,
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
)
from text_normalization import (
    ref_text_normalization,
    remove_non_alphabetic,
    upper_only_alpha,
    upper_all_char,
    lower_all_char,
    lower_only_alpha,
    train_text_normalization,
)
from train_bert_encoder import (
    add_model_arguments,
    get_params,
    get_tokenizer,
    get_transducer_model,
)

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically selects "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' or '--iter'.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless7/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
        - greedy_search
        - beam_search
        - modified_beam_search
        - fast_beam_search
        """,
    )

    parser.add_argument(
        "--manifest-dir",
        type=str,
        default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
        help="""The manifest for long-audio transcription.
        It is expected to be sorted, i.e., first by recording ID and then
        by start timestamp.""",
    )

    parser.add_argument(
        "--segment-length",
        type=float,
        default=30.0,
        help="The length (in seconds) of each audio segment.",
    )

    parser.add_argument(
        "--use-pre-text",
        type=str2bool,
        default=False,
        help="Whether to use pre-text when decoding the current chunk.",
    )

    parser.add_argument(
        "--pre-text-transform",
        type=str,
        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
        default="mixed-punc",
        help="The style of the content prompt, i.e., the pre_text.",
    )

    parser.add_argument(
        "--num-history",
        type=int,
        default=2,
        help="How many previous chunks to look at when using pre-text for decoding.",
    )

    parser.add_argument(
        "--use-gt-pre-text",
        type=str2bool,
        default=False,
        help="Whether to use the ground-truth pre-text as the content prompt.",
    )

    add_model_arguments(parser)

    return parser
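

# Illustrative note on checkpoint averaging (it mirrors the logic in main()
# below): with the defaults `--epoch 30 --avg 9`, the averaged model is
# computed over the range from epoch-21.pt (excluded) to epoch-30.pt.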


def _apply_style_transform(text: List[str], transform: str) -> List[str]:
    """Apply a transform to a list of texts. By default, the texts are in
    the ground-truth format, i.e., mixed-case with punctuation (mixed-punc).

    Args:
        text (List[str]): Input texts
        transform (str): The transform to be applied

    Returns:
        List[str]: The transformed texts
    """
    if transform == "mixed-punc":
        return text
    elif transform == "upper-no-punc":
        return [upper_only_alpha(s) for s in text]
    elif transform == "lower-no-punc":
        return [lower_only_alpha(s) for s in text]
    elif transform == "lower-punc":
        return [lower_all_char(s) for s in text]
    else:
        raise NotImplementedError(f"Unknown transform: {transform}")
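

# Illustrative examples (hypothetical strings, assuming the helpers in
# text_normalization behave as their names suggest):
#   _apply_style_transform(["Hello, World!"], "upper-no-punc") -> ["HELLO WORLD"]
#   _apply_style_transform(["Hello, World!"], "lower-punc")    -> ["hello, world!"]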


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    params.res_dir = params.exp_dir / "long_audio_transcribe"
    params.res_dir.mkdir(exist_ok=True)

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "beam_search" in params.method:
        params.suffix += f"-{params.method}-beam-size-{params.beam_size}"

    if params.use_pre_text:
        if params.use_gt_pre_text:
            params.suffix += (
                f"-use-gt-pre-text-{params.pre_text_transform}"
                f"-history-{params.num_history}"
            )
        else:
            params.suffix += (
                f"-pre-text-{params.pre_text_transform}"
                f"-history-{params.num_history}"
            )

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}", log_level="info")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_transducer_model(params)
    tokenizer = get_tokenizer(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")
    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg + 1
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg + 1:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        filename_start = filenames[-1]
        filename_end = filenames[0]
        logging.info(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
    else:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )

    model.to(device)
    model.eval()
    model.device = device
    # load manifest
    manifest = load_manifest(params.manifest_dir)

    results = []
    count = 0

    last_recording = ""
    last_end = -1
    history = []

    for cut in manifest:
        feat = cut.load_features()
        feat_lens = cut.num_frames

        cur_recording = cut.recording.id

        if cur_recording != last_recording:
            last_recording = cur_recording
            history = []  # clean up the history
            last_end = -1
        else:
            if cut.start < last_end:  # an overlap exists
                logging.warning(
                    "An overlap exists between the current cut and the last cut"
                )

        # prepare the input
        x = torch.tensor(feat, device=device).unsqueeze(0)
        x_lens = torch.tensor([feat_lens], device=device)

        if params.use_pre_text:
            # note: history[-0:] would return the full history, so guard
            # against num_history == 0
            if params.num_history > 0:
                pre_texts = history[-params.num_history :]
            else:
                pre_texts = []
            pre_texts = [train_text_normalization(" ".join(pre_texts))]
            # keep only the most recent part of an overly long pre-text;
            # the tokenizer below also truncates to max_length tokens
            if len(pre_texts[0]) > 1000:
                pre_texts = [pre_texts[0][-1000:]]
            pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)

            # encode pre_text
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                encoded_inputs = tokenizer(
                    pre_texts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=500,
                ).to(device)

                memory, memory_key_padding_mask = model.encode_text(
                    encoded_inputs=encoded_inputs,
                )  # (T,B,C)
        else:
            memory = None
            memory_key_padding_mask = None

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            encoder_out, encoder_out_lens = model.encode_audio(
                feature=x,
                feature_lens=x_lens,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
            )

        # note: this script currently decodes with greedy search only,
        # regardless of --method
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        hyp = sp.decode(hyp_tokens)[0]  # in string format
        ref_text = ref_text_normalization(
            cut.supervisions[0].texts[0]
        )  # normalized to match the training text format

        # extend the history; the history is kept in the original format
        if params.use_gt_pre_text:
            history.append(ref_text)
        else:
            history.append(hyp)
        last_end = cut.end  # update the last end timestamp

        # append the current decoding result
        ref = remove_non_alphabetic(ref_text.upper(), strict=True).split()
        ref = [w for w in ref if w != ""]
        hyp = remove_non_alphabetic(hyp.upper(), strict=True).split()
        hyp = [w for w in hyp if w != ""]
        results.append((cut.id, ref, hyp))

        count += 1
        if count % 100 == 0:
            logging.info(f"Cuts processed so far: {count}/{len(manifest)}")
    results = sorted(results)
    recog_path = (
        params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
    )
    store_transcripts(filename=recog_path, texts=results)
    logging.info(f"The transcripts are stored in {recog_path}")

    errs_filename = (
        params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
    )
    with open(errs_filename, "w") as f:
        wer = write_error_stats(
            f,
            f"long-audio-{params.method}",
            results,
            enable_log=True,
            compute_CER=False,
        )
    logging.info(f"Wrote detailed error stats to {errs_filename}")


if __name__ == "__main__":
    main()