mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 00:24:19 +00:00
add long audio transcription scripts
This commit is contained in:
parent
07e27348dd
commit
bbf1577818
475
egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py
Normal file
475
egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py
Normal file
@ -0,0 +1,475 @@
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from lhotse import load_manifest, Fbank
|
||||
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from text_normalization import (
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
upper_only_alpha,
|
||||
upper_all_char,
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
train_text_normalization,
|
||||
)
|
||||
from train_bert_encoder_with_style import (
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_tokenizer,
|
||||
get_transducer_model,
|
||||
_encode_texts_as_bytes_with_tokenizer,
|
||||
)
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=str,
|
||||
default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
|
||||
help="""This is the manfiest for long audio transcription.
|
||||
It is intended to be sored, i.e first sort by recording ID and then sort by
|
||||
start timestamp"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--segment-length",
|
||||
type=float,
|
||||
default=30.0,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether use pre-text when decoding the current chunk"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-history",
|
||||
type=int,
|
||||
default=2,
|
||||
help="How many previous chunks to look if using pre-text for decoding"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-gt-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether use gt pre text when using content prompt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
text (List[str]): Input text string
|
||||
transform (str): Transform to be applied
|
||||
|
||||
Returns:
|
||||
List[str]: _description_
|
||||
"""
|
||||
if transform == "mixed-punc":
|
||||
return text
|
||||
elif transform == "upper-no-punc":
|
||||
return [upper_only_alpha(s) for s in text]
|
||||
elif transform == "lower-no-punc":
|
||||
return [lower_only_alpha(s) for s in text]
|
||||
elif transform == "lower-punc":
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
params.res_dir = params.exp_dir / "long_audio_transcribe"
|
||||
params.res_dir.mkdir(exist_ok=True)
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "beam_search" in params.method:
|
||||
params.suffix += (
|
||||
f"-{params.method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
|
||||
if params.use_pre_text:
|
||||
if params.use_gt_pre_text:
|
||||
params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
||||
else:
|
||||
params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
||||
|
||||
|
||||
book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
|
||||
setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
tokenizer = get_tokenizer(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
# load manifest
|
||||
manifest = load_manifest(params.manifest_dir)
|
||||
|
||||
results = []
|
||||
count = 0
|
||||
|
||||
last_recording = ""
|
||||
last_end = -1
|
||||
history = []
|
||||
num_pre_texts = []
|
||||
|
||||
for cut in manifest:
|
||||
if cut.has_features:
|
||||
feat = cut.load_features()
|
||||
feat_lens = cut.num_frames
|
||||
else:
|
||||
feat = cut.compute_features(extractor=Fbank())
|
||||
feat_lens = feat.shape[0]
|
||||
|
||||
|
||||
cur_recording = cut.recording.id
|
||||
|
||||
if cur_recording != last_recording:
|
||||
last_recording = cur_recording
|
||||
history = [] # clean history
|
||||
last_end = -1
|
||||
logging.info(f"Moving on to the next recording")
|
||||
else:
|
||||
if cut.start < last_end - 0.2: # overlap exits
|
||||
logging.warning(f"An overlap exists between current cut and last cut")
|
||||
logging.warning("Skipping this cut!")
|
||||
continue
|
||||
if cut.start > last_end + 10:
|
||||
logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
|
||||
|
||||
# prepare input
|
||||
x = torch.tensor(feat, device=device).unsqueeze(0)
|
||||
x_lens = torch.tensor([feat_lens,], device=device)
|
||||
|
||||
if params.use_pre_text:
|
||||
if params.num_history > 0:
|
||||
pre_texts = history[-params.num_history:]
|
||||
else:
|
||||
pre_texts = []
|
||||
num_pre_texts.append(len(pre_texts))
|
||||
pre_texts = [train_text_normalization(" ".join(pre_texts))]
|
||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
||||
style_texts = [fixed_sentence]
|
||||
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
if params.use_style_prompt:
|
||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
||||
|
||||
# encode pre_text
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
|
||||
pre_texts=pre_texts,
|
||||
style_texts=style_texts,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
no_limit=True,
|
||||
)
|
||||
if params.num_history > 5:
|
||||
logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
encoded_inputs=encoded_inputs,
|
||||
style_lens=style_lens,
|
||||
) # (T,B,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
encoder_out, encoder_out_lens = model.encode_audio(
|
||||
feature=x,
|
||||
feature_lens=x_lens,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
|
||||
if params.method == "greedy_search":
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
hyp = sp.decode(hyp_tokens)[0] # in string format
|
||||
ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
|
||||
|
||||
# extend the history, the history here is in original format
|
||||
if params.use_gt_pre_text:
|
||||
history.append(ref_text)
|
||||
else:
|
||||
history.append(hyp)
|
||||
last_end = cut.end # update the last end timestamp
|
||||
|
||||
# append the current decoding result
|
||||
hyp = hyp.split()
|
||||
ref = ref_text.split()
|
||||
results.append((cut.id, ref, hyp))
|
||||
|
||||
count += 1
|
||||
if count % 100 == 0:
|
||||
logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
|
||||
logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
|
||||
|
||||
logging.info(f"A total of {count} cuts")
|
||||
logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
|
||||
|
||||
results = sorted(results)
|
||||
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
|
||||
)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
if params.post_normalization:
|
||||
params.suffix += "-post-normalization"
|
||||
|
||||
new_res = []
|
||||
for item in results:
|
||||
id, ref, hyp = item
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
new_res.append((id,ref,hyp))
|
||||
|
||||
new_res = sorted(new_res)
|
||||
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
||||
store_transcripts(filename=recog_path, texts=new_res)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
|
||||
)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
if __name__=="__main__":
|
||||
main()
|
477
egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_subformer.py
Normal file
477
egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_subformer.py
Normal file
@ -0,0 +1,477 @@
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from lhotse import load_manifest, Fbank
|
||||
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from text_normalization import (
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
upper_only_alpha,
|
||||
upper_all_char,
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
train_text_normalization,
|
||||
)
|
||||
from train_subformer_with_style import (
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_tokenizer,
|
||||
get_transducer_model,
|
||||
_encode_text_as_tokens,
|
||||
)
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=str,
|
||||
default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
|
||||
help="""This is the manfiest for long audio transcription.
|
||||
It is intended to be sored, i.e first sort by recording ID and then sort by
|
||||
start timestamp"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--segment-length",
|
||||
type=float,
|
||||
default=30.0,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether use pre-text when decoding the current chunk"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-history",
|
||||
type=int,
|
||||
default=2,
|
||||
help="How many previous chunks to look if using pre-text for decoding"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-gt-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether use gt pre text when using content prompt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
text (List[str]): Input text string
|
||||
transform (str): Transform to be applied
|
||||
|
||||
Returns:
|
||||
List[str]: _description_
|
||||
"""
|
||||
if transform == "mixed-punc":
|
||||
return text
|
||||
elif transform == "upper-no-punc":
|
||||
return [upper_only_alpha(s) for s in text]
|
||||
elif transform == "lower-no-punc":
|
||||
return [lower_only_alpha(s) for s in text]
|
||||
elif transform == "lower-punc":
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
params.res_dir = params.exp_dir / "long_audio_transcribe"
|
||||
params.res_dir.mkdir(exist_ok=True)
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "beam_search" in params.method:
|
||||
params.suffix += (
|
||||
f"-{params.method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
|
||||
if params.use_pre_text:
|
||||
if params.use_gt_pre_text:
|
||||
params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
||||
else:
|
||||
params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
||||
|
||||
|
||||
book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
|
||||
setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
text_sp = spm.SentencePieceProcessor()
|
||||
text_sp.load(params.text_encoder_bpe_model)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
# load manifest
|
||||
manifest = load_manifest(params.manifest_dir)
|
||||
|
||||
results = []
|
||||
count = 0
|
||||
|
||||
last_recording = ""
|
||||
last_end = -1
|
||||
history = []
|
||||
num_pre_texts = []
|
||||
|
||||
for cut in manifest:
|
||||
if cut.has_features:
|
||||
feat = cut.load_features()
|
||||
feat_lens = cut.num_frames
|
||||
else:
|
||||
feat = cut.compute_features(extractor=Fbank())
|
||||
feat_lens = feat.shape[0]
|
||||
|
||||
|
||||
cur_recording = cut.recording.id
|
||||
|
||||
if cur_recording != last_recording:
|
||||
last_recording = cur_recording
|
||||
history = [] # clean history
|
||||
last_end = -1
|
||||
logging.info(f"Moving on to the next recording")
|
||||
else:
|
||||
if cut.start < last_end - 0.2: # overlap exits
|
||||
logging.warning(f"An overlap exists between current cut and last cut")
|
||||
logging.warning("Skipping this cut!")
|
||||
continue
|
||||
if cut.start > last_end + 10:
|
||||
logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
|
||||
|
||||
# prepare input
|
||||
x = torch.tensor(feat, device=device).unsqueeze(0)
|
||||
x_lens = torch.tensor([feat_lens,], device=device)
|
||||
|
||||
if params.use_pre_text:
|
||||
if params.num_history > 0:
|
||||
pre_texts = history[-params.num_history:]
|
||||
else:
|
||||
pre_texts = []
|
||||
assert len(pre_texts) <= params.num_history
|
||||
num_pre_texts.append(len(pre_texts))
|
||||
pre_texts = [train_text_normalization(" ".join(pre_texts))]
|
||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
||||
style_texts = [fixed_sentence]
|
||||
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
if params.use_style_prompt:
|
||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
||||
|
||||
# encode pre_text
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
pre_texts, pre_texts_lens, style_text_lens = _encode_text_as_tokens(
|
||||
pre_texts=pre_texts,
|
||||
style_texts=style_texts,
|
||||
bpe_model=text_sp,
|
||||
device=device,
|
||||
max_tokens=1500,
|
||||
)
|
||||
if params.num_history > 5:
|
||||
logging.info(f"Shape of encoded texts: {pre_texts.shape} ")
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
text=pre_texts,
|
||||
style_lens=style_text_lens,
|
||||
text_lens=pre_texts_lens,
|
||||
) # (T,B,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
encoder_out, encoder_out_lens = model.encode_audio(
|
||||
feature=x,
|
||||
feature_lens=x_lens,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
|
||||
if params.method == "greedy_search":
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
hyp = sp.decode(hyp_tokens)[0] # in string format
|
||||
ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
|
||||
|
||||
# extend the history, the history here is in original format
|
||||
if params.use_gt_pre_text:
|
||||
history.append(ref_text)
|
||||
else:
|
||||
history.append(hyp)
|
||||
last_end = cut.end # update the last end timestamp
|
||||
|
||||
# append the current decoding result
|
||||
hyp = hyp.split()
|
||||
ref = ref_text.split()
|
||||
results.append((cut.id, ref, hyp))
|
||||
|
||||
count += 1
|
||||
if count % 100 == 0:
|
||||
logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
|
||||
logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
|
||||
logging.info(f"A total of {count} cuts")
|
||||
logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
|
||||
|
||||
results = sorted(results)
|
||||
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
|
||||
)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
if params.post_normalization:
|
||||
params.suffix += "-post-normalization"
|
||||
|
||||
new_res = []
|
||||
for item in results:
|
||||
id, ref, hyp = item
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
new_res.append((id,ref,hyp))
|
||||
|
||||
new_res = sorted(new_res)
|
||||
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
||||
store_transcripts(filename=recog_path, texts=new_res)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
|
||||
)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
if __name__=="__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user