marcoyang1998 2023-09-20 10:35:37 +08:00
parent 93461fb77e
commit cda6e06a85
6 changed files with 257 additions and 201 deletions

View File

@@ -450,6 +450,7 @@ def decode_one_batch(
     else:
         pre_texts = ["" for _ in range(batch_size)]
 
+    # get the librispeech biasing data
     if params.use_ls_context_list and params.use_ls_test_set:
        if params.biasing_level == "utterance":
            pre_texts = [biasing_dict[id] for id in cut_ids]
@@ -476,7 +477,6 @@ def decode_one_batch(
 
     # Get the text embedding
     if params.use_pre_text or params.use_style_prompt:
         # apply style transform to the pre_text and style_text
-
         pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
         if not params.use_ls_context_list:

View File

@@ -15,17 +15,18 @@
 # limitations under the License.
 
+import random
+import warnings
+from typing import Optional, Tuple
+
 import k2
 import torch
 import torch.nn as nn
-import random
-import warnings
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear, penalize_abs_values_gt
+from torch import Tensor
+
 from icefall.utils import add_sos, make_pad_mask
-from scaling import penalize_abs_values_gt, ScaledLinear
-from torch import Tensor
-from typing import Optional, Tuple
 
 
 class Transducer(nn.Module):
@@ -185,11 +186,6 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -257,11 +253,10 @@ class Transducer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         encoder_out, encoder_out_lens = self.encoder(
             x=x,
             x_lens=x_lens,
             src_key_padding_mask=src_key_padding_mask,
         )
 
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
-
         return encoder_out, encoder_out_lens
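
The permute comments in this file mark the layout contract around the encoder:
callers pass batch-first (N, T, C), the encoder runs time-first (T, N, C), and
the output is permuted back. A minimal standalone sketch of that round trip
(toy shapes, made up for illustration):

    import torch

    # Hypothetical sizes: batch N=2, time T=5, channels C=8.
    x = torch.randn(2, 5, 8)      # (N, T, C), batch-first
    x_tnc = x.permute(1, 0, 2)    # (T, N, C), time-first for the encoder
    assert x_tnc.shape == (5, 2, 8)
    # ... the encoder would consume x_tnc here ...
    out = x_tnc.permute(1, 0, 2)  # back to (N, T, C)
    assert out.shape == (2, 5, 8)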

View File

@@ -15,17 +15,18 @@
 # limitations under the License.
 
+import random
+import warnings
+from typing import Dict, Optional, Tuple
+
 import k2
 import torch
 import torch.nn as nn
-import random
-import warnings
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear, penalize_abs_values_gt
+from torch import Tensor
+
 from icefall.utils import add_sos, make_pad_mask
-from scaling import penalize_abs_values_gt, ScaledLinear
-from torch import Tensor
-from typing import Optional, Tuple, Dict
 
 
 class PromptedTransducer(nn.Module):
@@ -97,13 +98,21 @@ class PromptedTransducer(nn.Module):
             vocab_size,
             initial_scale=0.25,
         )
         self.use_BERT = use_BERT  # if the text encoder is a pre-trained BERT
         self.context_fuser = context_fuser
 
-        assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
-        self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim
+        assert text_encoder_type in (
+            "BERT",
+            "DistilBERT",
+            "BERT-UNCASED",
+        ), f"Unseen text_encoder type {text_encoder_type}"
+        self.text_encoder_dim = (
+            self.text_encoder.config.hidden_size
+            if text_encoder_type in ("BERT", "BERT-UNCASED")
+            else self.text_encoder.config.dim
+        )
 
         if text_encoder_adapter:
             self.text_encoder_adapter = nn.Sequential(
                 nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False),
@@ -111,8 +120,10 @@ class PromptedTransducer(nn.Module):
             )
         else:
             self.text_encoder_adapter = None
 
-        self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
+        self.style_prompt_embedding = nn.Parameter(
+            torch.full((self.text_encoder_dim,), 0.5)
+        )
 
     def forward(
         self,
@@ -181,11 +192,10 @@ class PromptedTransducer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # freeze the BERT text encoder
         if use_pre_text:
             memory, memory_key_padding_mask = self.encode_text(
-                encoded_inputs,
-                style_lens=style_lens
+                encoded_inputs, style_lens=style_lens
             )
         else:
             memory = None
@@ -231,11 +241,6 @@ class PromptedTransducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -270,12 +275,12 @@ class PromptedTransducer(nn.Module):
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
         if self.context_fuser is not None and memory is not None:
-            memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
+            memory = memory.permute(1, 0, 2)  # (T,N,C) -> (N,T,C)
             context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
             context = self.joiner.context_proj(context)
         else:
             context = None
 
         logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)
 
         with torch.cuda.amp.autocast(enabled=False):
@@ -304,16 +309,17 @@ class PromptedTransducer(nn.Module):
         (memory_len, batch_size, embed_dim) = memory.shape
         indicator = (
-            torch.arange(memory_len, device=memory.device).unsqueeze(-1)
-            < style_lens
+            torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
         )
         indicator = indicator.to(memory.dtype)
 
         extra_term = torch.zeros_like(memory)
-        extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
+        extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(
+            memory_len, batch_size, self.text_encoder_dim
+        )
 
         return memory + extra_term
 
     def encode_text(
         self,
         encoded_inputs: Dict,
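
The _add_style_indicator hunk above is pure broadcasting: every memory position
whose index is below style_lens gets the learned style_prompt_embedding added,
and the rest are left untouched. A standalone sketch of the same masking logic,
with toy sizes in place of the real model dimensions:

    import torch

    memory_len, batch_size, dim = 6, 2, 4      # toy sizes, not the real config
    memory = torch.zeros(memory_len, batch_size, dim)
    style_lens = torch.tensor([3, 5])          # first 3 / 5 frames are style prompt
    style_prompt_embedding = torch.full((dim,), 0.5)

    # (T, N) mask: 1.0 where position t belongs to the style prompt
    indicator = (torch.arange(memory_len).unsqueeze(-1) < style_lens).to(memory.dtype)
    extra_term = indicator.unsqueeze(-1) * style_prompt_embedding.expand(
        memory_len, batch_size, dim
    )
    out = memory + extra_term
    assert out[:3, 0].eq(0.5).all() and out[3:, 0].eq(0.0).all()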
@@ -326,25 +332,25 @@ class PromptedTransducer(nn.Module):
         Returns:
             Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
             text_encoder and the attention mask
         """
         text_lens = encoded_inputs.pop("length")  # need to use pop to remove this item
 
         # Freeze the pre-trained text encoder
         with torch.no_grad():
             memory = self.text_encoder(**encoded_inputs)["last_hidden_state"]  # (B,T,C)
-        memory = memory.permute(1,0,2)
+        memory = memory.permute(1, 0, 2)
 
         # Text encoder adapter
         if self.text_encoder_adapter is not None:
             memory = self.text_encoder_adapter(memory)
 
         memory = self._add_style_indicator(memory, style_lens)
 
         memory_key_padding_mask = make_pad_mask(text_lens)
 
         return memory, memory_key_padding_mask
 
     def encode_audio(
         self,
         feature: Tensor,
@@ -368,14 +374,14 @@ class PromptedTransducer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         encoder_out, encoder_out_lens = self.encoder(
             x=x,
             x_lens=x_lens,
             src_key_padding_mask=src_key_padding_mask,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
         )
 
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
         return encoder_out, encoder_out_lens

View File

@@ -1,12 +1,29 @@
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 
 
 def train_text_normalization(s: str) -> str:
+    # replace full-width with half-width
     s = s.replace("“", '"')
     s = s.replace("”", '"')
     s = s.replace("‘", "'")
     s = s.replace("’", "'")
-    if s[:2] == "\" ":  # remove the starting double quote
+    if s[:2] == '" ':  # remove the starting double quote
         s = s[2:]
 
     return s
@@ -17,42 +34,42 @@ def ref_text_normalization(ref_text: str) -> str:
     p = r"[FN#[0-9]*]"
     pattern = re.compile(p)
 
-    # ref_text = ref_text.replace("”", "\"")
-    # ref_text = ref_text.replace("’", "'")
     res = pattern.findall(ref_text)
     ref_text = re.sub(p, "", ref_text)
 
     ref_text = train_text_normalization(ref_text)
 
     return ref_text
 
 
-def remove_non_alphabetic(text: str, strict: bool=True) -> str:
+def remove_non_alphabetic(text: str, strict: bool = True) -> str:
+    # Recommend to set strict to False
     if not strict:
         # Note, this also keeps space, single quote(') and hypen (-)
         text = text.replace("-", " ")
         text = text.replace("—", " ")
-        return re.sub("[^a-zA-Z0-9\s']+", "", text)
+        return re.sub(r"[^a-zA-Z0-9\s']+", "", text)
     else:
         # only keeps space
-        return re.sub("[^a-zA-Z\s]+", "", text)
+        return re.sub(r"[^a-zA-Z\s]+", "", text)
 
 
-def recog_text_normalization(recog_text: str) -> str:
-    pass
-
-
 def upper_only_alpha(text: str) -> str:
     return remove_non_alphabetic(text.upper(), strict=False)
 
 
 def lower_only_alpha(text: str) -> str:
     return remove_non_alphabetic(text.lower(), strict=False)
 
 
 def lower_all_char(text: str) -> str:
     return text.lower()
 
 
 def upper_all_char(text: str) -> str:
     return text.upper()
 
 
 if __name__ == "__main__":
     ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
     print(ref_text)
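
For a quick sanity check of the helpers above (a sketch, assuming this file is
importable as text_normalization; the sample string is made up):

    from text_normalization import lower_all_char, upper_only_alpha

    s = "Hello, World! It's a mixed-case sentence."
    # strict=False keeps letters, digits, space and the single quote,
    # and maps hyphens to spaces before stripping everything else.
    print(upper_only_alpha(s))  # HELLO WORLD IT'S A MIXED CASE SENTENCE
    print(lower_all_char(s))    # hello, world! it's a mixed-case sentence.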

View File

@@ -1,8 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                            Wei Kang,
-#                                            Mingshuang Luo,)
-#                                            Zengwei Yao)
+# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang,
+#
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

View File

@@ -1,18 +1,47 @@
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+
+python ./zipformer_prompt_asr/transcribe_bert.py \
+    --epoch 50 \
+    --avg 10 \
+    --exp-dir ./zipformer_prompt_asr/exp \
+    --manifest-dir data/long_audios/long_audio.jsonl.gz \
+    --pre-text-transform mixed-punc \
+    --style-text-transform mixed-punc \
+    --num-history 5 \
+    --use-pre-text True \
+    --use-gt-pre-text False
+"""
+
 import argparse
 import logging
 import math
 import warnings
 from pathlib import Path
 from typing import List
 
-from tqdm import tqdm
-
 import k2
 import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
-from lhotse import load_manifest, Fbank
 from beam_search import (
     beam_search,
     fast_beam_search_one_best,
@@ -20,21 +49,24 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from decode_bert import _apply_style_transform
+from lhotse import Fbank, load_manifest
 from text_normalization import (
-    ref_text_normalization,
-    remove_non_alphabetic,
-    upper_only_alpha,
-    upper_all_char,
     lower_all_char,
     lower_only_alpha,
+    ref_text_normalization,
+    remove_non_alphabetic,
     train_text_normalization,
+    upper_all_char,
+    upper_only_alpha,
 )
-from train_bert_encoder_with_style import (
+from tqdm import tqdm
+from train_bert_encoder import (
+    _encode_texts_as_bytes_with_tokenizer,
     add_model_arguments,
     get_params,
     get_tokenizer,
     get_transducer_model,
-    _encode_texts_as_bytes_with_tokenizer,
 )
 
 from icefall.checkpoint import (
@@ -51,11 +83,12 @@ from icefall.utils import (
     write_error_stats,
 )
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
     parser.add_argument(
         "--epoch",
         type=int,
@@ -74,7 +107,7 @@ def get_parser():
         You can specify --avg to use more checkpoints for model averaging.
         """,
     )
 
     parser.add_argument(
         "--avg",
         type=int,
@@ -83,22 +116,21 @@ def get_parser():
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--exp-dir",
         type=str,
         default="pruned_transducer_stateless7/exp",
         help="The experiment dir",
     )
 
     parser.add_argument(
         "--bpe-model",
         type=str,
         default="data/lang_bpe_500/bpe.model",
         help="""Path to bpe.model.""",
     )
 
     parser.add_argument(
         "--method",
         type=str,
@@ -110,104 +142,76 @@ def get_parser():
         - fast_beam_search
         """,
     )
 
     parser.add_argument(
         "--beam-size",
         type=int,
         default=4,
     )
 
     parser.add_argument(
         "--manifest-dir",
         type=str,
-        default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
+        default="data/long_audios/long_audio.jsonl.gz",
         help="""This is the manfiest for long audio transcription.
-        It is intended to be sored, i.e first sort by recording ID and then sort by
-        start timestamp"""
+        The cust are intended to be sorted, i.e first sort by recording ID and
+        then sort by start timestamp""",
     )
 
-    parser.add_argument(
-        "--segment-length",
-        type=float,
-        default=30.0,
-    )
-
     parser.add_argument(
         "--use-pre-text",
         type=str2bool,
         default=False,
-        help="Whether use pre-text when decoding the current chunk"
+        help="Whether use pre-text when decoding the current chunk",
     )
 
     parser.add_argument(
         "--use-style-prompt",
         type=str2bool,
         default=True,
-        help="Use style prompt when evaluation"
+        help="Use style prompt when evaluation",
     )
 
     parser.add_argument(
         "--pre-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of content prompt, i.e pre_text"
+        help="The style of content prompt, i.e pre_text",
     )
 
     parser.add_argument(
         "--style-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of style prompt, i.e style_text"
+        help="The style of style prompt, i.e style_text",
     )
 
     parser.add_argument(
         "--num-history",
         type=int,
         default=2,
-        help="How many previous chunks to look if using pre-text for decoding"
+        help="How many previous chunks to look if using pre-text for decoding",
     )
 
     parser.add_argument(
         "--use-gt-pre-text",
         type=str2bool,
         default=False,
         help="Whether use gt pre text when using content prompt",
     )
 
     parser.add_argument(
         "--post-normalization",
         type=str2bool,
         default=True,
     )
 
     add_model_arguments(parser)
 
     return parser
 
 
-def _apply_style_transform(text: List[str], transform: str) -> List[str]:
-    """Apply transform to a list of text. By default, the text are in
-    ground truth format, i.e mixed-punc.
-
-    Args:
-        text (List[str]): Input text string
-        transform (str): Transform to be applied
-
-    Returns:
-        List[str]: _description_
-    """
-    if transform == "mixed-punc":
-        return text
-    elif transform == "upper-no-punc":
-        return [upper_only_alpha(s) for s in text]
-    elif transform == "lower-no-punc":
-        return [lower_only_alpha(s) for s in text]
-    elif transform == "lower-punc":
-        return [lower_all_char(s) for s in text]
-    else:
-        raise NotImplementedError(f"Unseen transform: {transform}")
-
-
 @torch.no_grad()
 def main():
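
The _apply_style_transform helper removed here is the same function that is now
imported from decode_bert at the top of this file; only its home changed. A
short sketch of what each transform name does to a made-up prompt:

    from decode_bert import _apply_style_transform

    texts = ["Hello, this is a mixed-case prompt."]
    _apply_style_transform(texts, "mixed-punc")     # unchanged
    _apply_style_transform(texts, "upper-no-punc")  # ["HELLO THIS IS A MIXED CASE PROMPT"]
    _apply_style_transform(texts, "lower-no-punc")  # ["hello this is a mixed case prompt"]
    _apply_style_transform(texts, "lower-punc")     # ["hello, this is a mixed-case prompt."]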
@@ -216,7 +220,7 @@ def main():
     args.exp_dir = Path(args.exp_dir)
 
     params = get_params()
     params.update(vars(args))
 
     sp = spm.SentencePieceProcessor()
@@ -226,7 +230,7 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     params.res_dir = params.exp_dir / "long_audio_transcribe"
     params.res_dir.mkdir(exist_ok=True)
@@ -234,21 +238,22 @@ def main():
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
     if "beam_search" in params.method:
-        params.suffix += (
-            f"-{params.method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.method}-beam-size-{params.beam_size}"
 
     if params.use_pre_text:
         if params.use_gt_pre_text:
             params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
         else:
-            params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
+            params.suffix += (
+                f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
+            )
 
-    book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
-    setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
+    book_name = params.manifest_dir.split("/")[-1].replace(".jsonl.gz", "")
+    setup_logger(
+        f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info"
+    )
     logging.info("Decoding started")
 
     device = torch.device("cpu")
@@ -265,13 +270,12 @@ def main():
     logging.info(f"Number of model parameters: {num_param}")
 
     if params.iter > 0:
-        filenames = find_checkpoints(
-            params.exp_dir, iteration=-params.iter
-        )[: params.avg + 1]
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg + 1
+        ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg + 1:
             raise ValueError(
@@ -310,22 +314,22 @@ def main():
                 device=device,
             )
         )
 
     model.to(device)
     model.eval()
     model.device = device
 
     # load manifest
     manifest = load_manifest(params.manifest_dir)
 
     results = []
     count = 0
 
     last_recording = ""
     last_end = -1
     history = []
     num_pre_texts = []
 
     for cut in manifest:
         if cut.has_features:
             feat = cut.load_features()
@@ -333,45 +337,53 @@ def main():
         else:
             feat = cut.compute_features(extractor=Fbank())
             feat_lens = feat.shape[0]
 
         cur_recording = cut.recording.id
 
         if cur_recording != last_recording:
             last_recording = cur_recording
-            history = [] # clean history
+            history = []  # clean up the history
             last_end = -1
-            logging.info(f"Moving on to the next recording")
+            logging.info("Moving on to the next recording")
         else:
-            if cut.start < last_end - 0.2: # overlap exits
-                logging.warning(f"An overlap exists between current cut and last cut")
+            if cut.start < last_end - 0.2:  # overlap with the previous cuts
+                logging.warning("An overlap exists between current cut and last cut")
                 logging.warning("Skipping this cut!")
                 continue
             if cut.start > last_end + 10:
-                logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
+                logging.warning(
+                    f"Large time gap between the current and previous utterance: {cut.start - last_end}."
+                )
 
         # prepare input
         x = torch.tensor(feat, device=device).unsqueeze(0)
-        x_lens = torch.tensor([feat_lens,], device=device)
+        x_lens = torch.tensor(
+            [
+                feat_lens,
+            ],
+            device=device,
+        )
 
         if params.use_pre_text:
             if params.num_history > 0:
-                pre_texts = history[-params.num_history:]
+                pre_texts = history[-params.num_history :]
            else:
                pre_texts = []
            num_pre_texts.append(len(pre_texts))
            pre_texts = [train_text_normalization(" ".join(pre_texts))]
            fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
            style_texts = [fixed_sentence]
 
            pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
            if params.use_style_prompt:
-                style_texts = _apply_style_transform(style_texts, params.style_text_transform)
+                style_texts = _apply_style_transform(
+                    style_texts, params.style_text_transform
+                )
 
-            # encode pre_text
+            # encode prompts
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
                 encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
                     pre_texts=pre_texts,
                     style_texts=style_texts,
@@ -380,16 +392,18 @@ def main():
                     no_limit=True,
                 )
                 if params.num_history > 5:
-                    logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
+                    logging.info(
+                        f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} "
+                    )
 
                 memory, memory_key_padding_mask = model.encode_text(
                     encoded_inputs=encoded_inputs,
                     style_lens=style_lens,
                 )  # (T,B,C)
         else:
             memory = None
             memory_key_padding_mask = None
 
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             encoder_out, encoder_out_lens = model.encode_audio(
@@ -398,7 +412,7 @@ def main():
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )
 
         if params.method == "greedy_search":
             hyp_tokens = greedy_search_batch(
                 model=model,
@@ -412,17 +426,19 @@ def main():
                 encoder_out_lens=encoder_out_lens,
                 beam=params.beam_size,
             )
 
         hyp = sp.decode(hyp_tokens)[0]  # in string format
-        ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
+        ref_text = ref_text_normalization(
+            cut.supervisions[0].texts[0]
+        )  # required to match the training
 
-        # extend the history, the history here is in original format
+        # extend the history
         if params.use_gt_pre_text:
             history.append(ref_text)
         else:
             history.append(hyp)
         last_end = cut.end  # update the last end timestamp
 
         # append the current decoding result
         hyp = hyp.split()
         ref = ref_text.split()
@@ -431,45 +447,69 @@ def main():
         count += 1
 
         if count % 100 == 0:
             logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
-            logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
+            logging.info(
+                f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}"
+            )
 
     logging.info(f"A total of {count} cuts")
-    logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
+    logging.info(
+        f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}"
+    )
 
     results = sorted(results)
-    recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
+    recog_path = (
+        params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
+    )
     store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")
 
-    errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
+    errs_filename = (
+        params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
+    )
     with open(errs_filename, "w") as f:
         wer = write_error_stats(
-            f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
+            f,
+            f"long-audio-{params.method}",
+            results,
+            enable_log=True,
+            compute_CER=False,
         )
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     if params.post_normalization:
         params.suffix += "-post-normalization"
 
         new_res = []
         for item in results:
             id, ref, hyp = item
             hyp = upper_only_alpha(" ".join(hyp)).split()
             ref = upper_only_alpha(" ".join(ref)).split()
-            new_res.append((id,ref,hyp))
+            new_res.append((id, ref, hyp))
 
         new_res = sorted(new_res)
-        recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        recog_path = (
+            params.res_dir
+            / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        )
         store_transcripts(filename=recog_path, texts=new_res)
         logging.info(f"The transcripts are stored in {recog_path}")
 
-        errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        errs_filename = (
+            params.res_dir
+            / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
+                f,
+                f"long-audio-{params.method}",
+                new_res,
+                enable_log=True,
+                compute_CER=False,
             )
             logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
 
-if __name__=="__main__":
-    main()
+if __name__ == "__main__":
+    main()