This commit is contained in:
marcoyang1998 2023-09-20 10:35:37 +08:00
parent 93461fb77e
commit cda6e06a85
6 changed files with 257 additions and 201 deletions

View File

@ -450,6 +450,7 @@ def decode_one_batch(
else:
pre_texts = ["" for _ in range(batch_size)]
# get the librispeech biasing data
if params.use_ls_context_list and params.use_ls_test_set:
if params.biasing_level == "utterance":
pre_texts = [biasing_dict[id] for id in cut_ids]
@ -476,7 +477,6 @@ def decode_one_batch(
# Get the text embedding
if params.use_pre_text or params.use_style_prompt:
# apply style transform to the pre_text and style_text
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
if not params.use_ls_context_list:

View File

@ -15,17 +15,18 @@
# limitations under the License.
import random
import warnings
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
import random
import warnings
from encoder_interface import EncoderInterface
from scaling import ScaledLinear, penalize_abs_values_gt
from torch import Tensor
from icefall.utils import add_sos, make_pad_mask
from scaling import penalize_abs_values_gt, ScaledLinear
from torch import Tensor
from typing import Optional, Tuple
class Transducer(nn.Module):
@ -185,11 +186,6 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
@ -264,4 +260,3 @@ class Transducer(nn.Module):
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens

View File

@ -15,17 +15,18 @@
# limitations under the License.
import random
import warnings
from typing import Dict, Optional, Tuple
import k2
import torch
import torch.nn as nn
import random
import warnings
from encoder_interface import EncoderInterface
from scaling import ScaledLinear, penalize_abs_values_gt
from torch import Tensor
from icefall.utils import add_sos, make_pad_mask
from scaling import penalize_abs_values_gt, ScaledLinear
from torch import Tensor
from typing import Optional, Tuple, Dict
class PromptedTransducer(nn.Module):
@ -101,8 +102,16 @@ class PromptedTransducer(nn.Module):
self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT
self.context_fuser = context_fuser
assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim
assert text_encoder_type in (
"BERT",
"DistilBERT",
"BERT-UNCASED",
), f"Unseen text_encoder type {text_encoder_type}"
self.text_encoder_dim = (
self.text_encoder.config.hidden_size
if text_encoder_type in ("BERT", "BERT-UNCASED")
else self.text_encoder.config.dim
)
if text_encoder_adapter:
self.text_encoder_adapter = nn.Sequential(
@ -112,7 +121,9 @@ class PromptedTransducer(nn.Module):
else:
self.text_encoder_adapter = None
self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
self.style_prompt_embedding = nn.Parameter(
torch.full((self.text_encoder_dim,), 0.5)
)
def forward(
self,
@ -184,8 +195,7 @@ class PromptedTransducer(nn.Module):
if use_pre_text:
memory, memory_key_padding_mask = self.encode_text(
encoded_inputs,
style_lens=style_lens
encoded_inputs, style_lens=style_lens
)
else:
memory = None
@ -231,11 +241,6 @@ class PromptedTransducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
@ -270,7 +275,7 @@ class PromptedTransducer(nn.Module):
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
if self.context_fuser is not None and memory is not None:
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
memory = memory.permute(1, 0, 2) # (T,N,C) -> (N,T,C)
context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
context = self.joiner.context_proj(context)
else:
@ -304,13 +309,14 @@ class PromptedTransducer(nn.Module):
(memory_len, batch_size, embed_dim) = memory.shape
indicator = (
torch.arange(memory_len, device=memory.device).unsqueeze(-1)
< style_lens
torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
)
indicator = indicator.to(memory.dtype)
extra_term = torch.zeros_like(memory)
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(
memory_len, batch_size, self.text_encoder_dim
)
return memory + extra_term
@ -333,7 +339,7 @@ class PromptedTransducer(nn.Module):
# Freeze the pre-trained text encoder
with torch.no_grad():
memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C)
memory = memory.permute(1,0,2)
memory = memory.permute(1, 0, 2)
# Text encoder adapter
if self.text_encoder_adapter is not None:

View File

@ -1,12 +1,29 @@
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def train_text_normalization(s: str) -> str:
    """Normalize quotation marks in a transcript.

    Typographic double and single quotes are replaced by their plain ASCII
    counterparts, then a leading ``'" '`` (a double quote followed by a
    space) is dropped.

    Args:
        s: The raw transcript text.

    Returns:
        The text with normalized quotes.
    """
    # NOTE(review): the search characters were missing from the visible
    # source (every call read ``s.replace("", ...)``, which inserts a quote
    # between all characters). They are restored here as explicit escapes,
    # assuming the curly quotes U+201C / U+201D / U+2018 / U+2019 were
    # intended -- confirm against the upstream file.
    s = s.replace("\u201c", '"')
    s = s.replace("\u201d", '"')
    s = s.replace("\u2018", "'")
    s = s.replace("\u2019", "'")

    if s[:2] == '" ':  # remove the starting double quote
        s = s[2:]

    return s
@ -17,8 +34,6 @@ def ref_text_normalization(ref_text: str) -> str:
p = r"[FN#[0-9]*]"
pattern = re.compile(p)
# ref_text = ref_text.replace("”", "\"")
# ref_text = ref_text.replace("", "'")
res = pattern.findall(ref_text)
ref_text = re.sub(p, "", ref_text)
@ -27,32 +42,34 @@ def ref_text_normalization(ref_text: str) -> str:
return ref_text
def remove_non_alphabetic(text: str, strict: bool = True) -> str:
    """Remove characters outside a small allowed set.

    Args:
        text: The input text.
        strict: If True (the default), keep only ASCII letters and
            whitespace. If False (the recommended setting), hyphens and
            em dashes are first turned into spaces, then only ASCII
            letters, digits, whitespace and the single quote (') are kept.

    Returns:
        The filtered text.
    """
    if strict:
        # only keeps letters and whitespace
        return re.sub(r"[^a-zA-Z\s]+", "", text)

    # Dashes usually join two words; map them to a space so the words are
    # not glued together once punctuation is removed.
    text = text.replace("-", " ")
    # NOTE(review): the search character was missing from the visible source
    # (``text.replace("", " ")`` would put a space between all characters);
    # an em dash (U+2014) is assumed -- confirm against the upstream file.
    text = text.replace("\u2014", " ")
    return re.sub(r"[^a-zA-Z0-9\s']+", "", text)
def recog_text_normalization(recog_text: str) -> str:
    """Placeholder, presumably for normalizing recognition output.

    NOTE(review): not implemented. The body is empty, so the call returns
    ``None`` for any ``recog_text`` despite the ``-> str`` annotation.
    """
    pass
def upper_only_alpha(text: str) -> str:
    """Upper-case ``text`` and filter it with the non-strict character filter.

    Args:
        text: The input text.

    Returns:
        ``remove_non_alphabetic(text.upper(), strict=False)``.
    """
    return remove_non_alphabetic(text.upper(), strict=False)
def lower_only_alpha(text: str) -> str:
    """Lower-case ``text`` and filter it with the non-strict character filter.

    Args:
        text: The input text.

    Returns:
        ``remove_non_alphabetic(text.lower(), strict=False)``.
    """
    return remove_non_alphabetic(text.lower(), strict=False)
def lower_all_char(text: str) -> str:
    """Return ``text`` lower-cased; punctuation and digits are left as-is.

    Args:
        text: The input text.

    Returns:
        The lower-cased text.
    """
    return text.lower()
def upper_all_char(text: str) -> str:
    """Return ``text`` upper-cased; punctuation and digits are left as-is.

    Args:
        text: The input text.

    Returns:
        The upper-cased text.
    """
    return text.upper()
if __name__ == "__main__":
ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
print(ref_text)

View File

@ -1,8 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,)
# Zengwei Yao)
# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang,
#
#
# See ../../../../LICENSE for clarification regarding multiple authors
#

View File

@ -1,18 +1,47 @@
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
python ./zipformer_prompt_asr/transcribe_bert.py \
--epoch 50 \
--avg 10 \
--exp-dir ./zipformer_prompt_asr/exp \
--manifest-dir data/long_audios/long_audio.jsonl.gz \
--pre-text-transform mixed-punc \
--style-text-transform mixed-punc \
--num-history 5 \
--use-pre-text True \
--use-gt-pre-text False
"""
import argparse
import logging
import math
import warnings
from pathlib import Path
from typing import List
from tqdm import tqdm
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from lhotse import load_manifest, Fbank
from beam_search import (
beam_search,
fast_beam_search_one_best,
@ -20,21 +49,24 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from decode_bert import _apply_style_transform
from lhotse import Fbank, load_manifest
from text_normalization import (
ref_text_normalization,
remove_non_alphabetic,
upper_only_alpha,
upper_all_char,
lower_all_char,
lower_only_alpha,
ref_text_normalization,
remove_non_alphabetic,
train_text_normalization,
upper_all_char,
upper_only_alpha,
)
from train_bert_encoder_with_style import (
from tqdm import tqdm
from train_bert_encoder import (
_encode_texts_as_bytes_with_tokenizer,
add_model_arguments,
get_params,
get_tokenizer,
get_transducer_model,
_encode_texts_as_bytes_with_tokenizer,
)
from icefall.checkpoint import (
@ -51,6 +83,7 @@ from icefall.utils import (
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -91,7 +124,6 @@ def get_parser():
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
@ -120,53 +152,47 @@ def get_parser():
parser.add_argument(
"--manifest-dir",
type=str,
default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
default="data/long_audios/long_audio.jsonl.gz",
help="""This is the manfiest for long audio transcription.
It is intended to be sored, i.e first sort by recording ID and then sort by
start timestamp"""
)
parser.add_argument(
"--segment-length",
type=float,
default=30.0,
The cust are intended to be sorted, i.e first sort by recording ID and
then sort by start timestamp""",
)
parser.add_argument(
"--use-pre-text",
type=str2bool,
default=False,
help="Whether use pre-text when decoding the current chunk"
help="Whether use pre-text when decoding the current chunk",
)
parser.add_argument(
"--use-style-prompt",
type=str2bool,
default=True,
help="Use style prompt when evaluation"
help="Use style prompt when evaluation",
)
parser.add_argument(
"--pre-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
default="mixed-punc",
help="The style of content prompt, i.e pre_text"
help="The style of content prompt, i.e pre_text",
)
parser.add_argument(
"--style-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
default="mixed-punc",
help="The style of style prompt, i.e style_text"
help="The style of style prompt, i.e style_text",
)
parser.add_argument(
"--num-history",
type=int,
default=2,
help="How many previous chunks to look if using pre-text for decoding"
help="How many previous chunks to look if using pre-text for decoding",
)
parser.add_argument(
@ -186,28 +212,6 @@ def get_parser():
return parser
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
"""Apply transform to a list of text. By default, the text are in
ground truth format, i.e mixed-punc.
Args:
text (List[str]): Input text string
transform (str): Transform to be applied
Returns:
List[str]: _description_
"""
if transform == "mixed-punc":
return text
elif transform == "upper-no-punc":
return [upper_only_alpha(s) for s in text]
elif transform == "lower-no-punc":
return [lower_only_alpha(s) for s in text]
elif transform == "lower-punc":
return [lower_all_char(s) for s in text]
else:
raise NotImplementedError(f"Unseen transform: {transform}")
@torch.no_grad()
def main():
@ -236,19 +240,20 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.method:
params.suffix += (
f"-{params.method}-beam-size-{params.beam_size}"
)
params.suffix += f"-{params.method}-beam-size-{params.beam_size}"
if params.use_pre_text:
if params.use_gt_pre_text:
params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
else:
params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
params.suffix += (
f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
)
book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
book_name = params.manifest_dir.split("/")[-1].replace(".jsonl.gz", "")
setup_logger(
f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info"
)
logging.info("Decoding started")
device = torch.device("cpu")
@ -265,13 +270,12 @@ def main():
logging.info(f"Number of model parameters: {num_param}")
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
@ -334,29 +338,35 @@ def main():
feat = cut.compute_features(extractor=Fbank())
feat_lens = feat.shape[0]
cur_recording = cut.recording.id
if cur_recording != last_recording:
last_recording = cur_recording
history = [] # clean history
history = [] # clean up the history
last_end = -1
logging.info(f"Moving on to the next recording")
logging.info("Moving on to the next recording")
else:
if cut.start < last_end - 0.2: # overlap exits
logging.warning(f"An overlap exists between current cut and last cut")
if cut.start < last_end - 0.2: # overlap with the previous cuts
logging.warning("An overlap exists between current cut and last cut")
logging.warning("Skipping this cut!")
continue
if cut.start > last_end + 10:
logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
logging.warning(
f"Large time gap between the current and previous utterance: {cut.start - last_end}."
)
# prepare input
x = torch.tensor(feat, device=device).unsqueeze(0)
x_lens = torch.tensor([feat_lens,], device=device)
x_lens = torch.tensor(
[
feat_lens,
],
device=device,
)
if params.use_pre_text:
if params.num_history > 0:
pre_texts = history[-params.num_history:]
pre_texts = history[-params.num_history :]
else:
pre_texts = []
num_pre_texts.append(len(pre_texts))
@ -366,9 +376,11 @@ def main():
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
if params.use_style_prompt:
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
style_texts = _apply_style_transform(
style_texts, params.style_text_transform
)
# encode pre_text
# encode prompts
with warnings.catch_warnings():
warnings.simplefilter("ignore")
@ -380,7 +392,9 @@ def main():
no_limit=True,
)
if params.num_history > 5:
logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
logging.info(
f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} "
)
memory, memory_key_padding_mask = model.encode_text(
encoded_inputs=encoded_inputs,
@ -414,9 +428,11 @@ def main():
)
hyp = sp.decode(hyp_tokens)[0] # in string format
ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
ref_text = ref_text_normalization(
cut.supervisions[0].texts[0]
) # required to match the training
# extend the history, the history here is in original format
# extend the history
if params.use_gt_pre_text:
history.append(ref_text)
else:
@ -431,20 +447,32 @@ def main():
count += 1
if count % 100 == 0:
logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
logging.info(
f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}"
)
logging.info(f"A total of {count} cuts")
logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
logging.info(
f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}"
)
results = sorted(results)
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
recog_path = (
params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
errs_filename = (
params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
f,
f"long-audio-{params.method}",
results,
enable_log=True,
compute_CER=False,
)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
@ -457,19 +485,31 @@ def main():
id, ref, hyp = item
hyp = upper_only_alpha(" ".join(hyp)).split()
ref = upper_only_alpha(" ".join(ref)).split()
new_res.append((id,ref,hyp))
new_res.append((id, ref, hyp))
new_res = sorted(new_res)
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
recog_path = (
params.res_dir
/ f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
)
store_transcripts(filename=recog_path, texts=new_res)
logging.info(f"The transcripts are stored in {recog_path}")
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
errs_filename = (
params.res_dir
/ f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
f,
f"long-audio-{params.method}",
new_res,
enable_log=True,
compute_CER=False,
)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
if __name__=="__main__":
if __name__ == "__main__":
main()