mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)

Commit cda6e06a85 ("updates")
Parent: 93461fb77e
@@ -450,6 +450,7 @@ def decode_one_batch(
     else:
         pre_texts = ["" for _ in range(batch_size)]
 
+    # get the librispeech biasing data
     if params.use_ls_context_list and params.use_ls_test_set:
         if params.biasing_level == "utterance":
             pre_texts = [biasing_dict[id] for id in cut_ids]
@@ -476,7 +477,6 @@ def decode_one_batch(
 
     # Get the text embedding
     if params.use_pre_text or params.use_style_prompt:
-
         # apply style transform to the pre_text and style_text
         pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
         if not params.use_ls_context_list:
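Note: a minimal sketch, with hypothetical data, of the utterance-level biasing lookup added above; biasing_dict is assumed to map each cut id to its biasing text.

    biasing_dict = {"cut-001": "MARGUERITE BLAKENEY", "cut-002": ""}
    cut_ids = ["cut-001", "cut-002"]
    pre_texts = [biasing_dict[id] for id in cut_ids]  # one pre_text per cut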
@@ -15,17 +15,18 @@
 # limitations under the License.
 
 
+import random
+import warnings
+from typing import Optional, Tuple
+
 import k2
 import torch
 import torch.nn as nn
-import random
-import warnings
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear, penalize_abs_values_gt
+from torch import Tensor
 
 from icefall.utils import add_sos, make_pad_mask
-from scaling import penalize_abs_values_gt, ScaledLinear
-from torch import Tensor
-from typing import Optional, Tuple
 
 
 class Transducer(nn.Module):
@@ -185,11 +186,6 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -264,4 +260,3 @@ class Transducer(nn.Module):
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
 
         return encoder_out, encoder_out_lens
-
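Note: a hedged sketch of the simple-loss call that follows the deleted lines, in the style of the usual icefall pruned-transducer recipes; the exact keywords and variable names here are assumptions, not this file's verbatim code.

    with torch.cuda.amp.autocast(enabled=False):
        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
            lm=lm.float(),      # (B, S + 1, vocab), from simple_lm_proj(decoder_out)
            am=am.float(),      # (B, T, vocab), from simple_am_proj(encoder_out)
            symbols=y_padded,   # (B, S) padded label ids
            termination_symbol=blank_id,
            lm_only_scale=lm_scale,
            am_only_scale=am_scale,
            boundary=boundary,  # (B, 4): [0, 0, S_i, T_i] per utterance
            reduction="sum",
            return_grad=True,   # px/py grads later drive the pruning bounds
        )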
@@ -15,17 +15,18 @@
 # limitations under the License.
 
 
+import random
+import warnings
+from typing import Dict, Optional, Tuple
+
 import k2
 import torch
 import torch.nn as nn
-import random
-import warnings
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear, penalize_abs_values_gt
+from torch import Tensor
 
 from icefall.utils import add_sos, make_pad_mask
-from scaling import penalize_abs_values_gt, ScaledLinear
-from torch import Tensor
-from typing import Optional, Tuple, Dict
 
 
 class PromptedTransducer(nn.Module):
@@ -98,11 +99,19 @@ class PromptedTransducer(nn.Module):
             initial_scale=0.25,
         )
 
-        self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT
+        self.use_BERT = use_BERT  # if the text encoder is a pre-trained BERT
         self.context_fuser = context_fuser
 
-        assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
-        self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim
+        assert text_encoder_type in (
+            "BERT",
+            "DistilBERT",
+            "BERT-UNCASED",
+        ), f"Unseen text_encoder type {text_encoder_type}"
+        self.text_encoder_dim = (
+            self.text_encoder.config.hidden_size
+            if text_encoder_type in ("BERT", "BERT-UNCASED")
+            else self.text_encoder.config.dim
+        )
 
         if text_encoder_adapter:
             self.text_encoder_adapter = nn.Sequential(
@@ -112,7 +121,9 @@ class PromptedTransducer(nn.Module):
         else:
             self.text_encoder_adapter = None
 
-        self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
+        self.style_prompt_embedding = nn.Parameter(
+            torch.full((self.text_encoder_dim,), 0.5)
+        )
 
     def forward(
         self,
@@ -184,8 +195,7 @@ class PromptedTransducer(nn.Module):
 
         if use_pre_text:
             memory, memory_key_padding_mask = self.encode_text(
-                encoded_inputs,
-                style_lens=style_lens
+                encoded_inputs, style_lens=style_lens
             )
         else:
             memory = None
@@ -231,11 +241,6 @@ class PromptedTransducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -270,9 +275,9 @@ class PromptedTransducer(nn.Module):
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
         if self.context_fuser is not None and memory is not None:
-            memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
+            memory = memory.permute(1, 0, 2)  # (T,N,C) -> (N,T,C)
             context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
             context = self.joiner.context_proj(context)
         else:
             context = None
@@ -304,13 +309,14 @@ class PromptedTransducer(nn.Module):
         (memory_len, batch_size, embed_dim) = memory.shape
 
         indicator = (
-            torch.arange(memory_len, device=memory.device).unsqueeze(-1)
-            < style_lens
+            torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
         )
         indicator = indicator.to(memory.dtype)
 
         extra_term = torch.zeros_like(memory)
-        extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
+        extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(
+            memory_len, batch_size, self.text_encoder_dim
+        )
 
         return memory + extra_term
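Note: a standalone sketch, with assumed toy shapes, of the style-indicator logic reformatted above; the learned style embedding is added only to the first style_lens[b] frames of each sequence.

    import torch

    T, B, C = 6, 2, 4
    memory = torch.zeros(T, B, C)              # (T, B, C) text-encoder output
    style_lens = torch.tensor([2, 4])          # (B,) style-prompt lengths
    style_prompt_embedding = torch.full((C,), 0.5)

    indicator = (
        torch.arange(T).unsqueeze(-1) < style_lens
    ).to(memory.dtype)                         # (T, B): 1.0 inside the style prompt
    out = memory + indicator.unsqueeze(-1) * style_prompt_embedding.expand(T, B, C)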
@@ -328,12 +334,12 @@ class PromptedTransducer(nn.Module):
         Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
             text_encoder and the attention mask
         """
-        text_lens = encoded_inputs.pop("length") # need to use pop to remove this item
+        text_lens = encoded_inputs.pop("length")  # need to use pop to remove this item
 
         # Freeze the pre-trained text encoder
         with torch.no_grad():
-            memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C)
-            memory = memory.permute(1,0,2)
+            memory = self.text_encoder(**encoded_inputs)["last_hidden_state"]  # (B,T,C)
+            memory = memory.permute(1, 0, 2)
 
         # Text encoder adapter
         if self.text_encoder_adapter is not None:
@ -1,12 +1,29 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def train_text_normalization(s: str) -> str:
|
||||
# replace full-width with half-width
|
||||
s = s.replace("“", '"')
|
||||
s = s.replace("”", '"')
|
||||
s = s.replace("‘", "'")
|
||||
s = s.replace("’", "'")
|
||||
if s[:2] == "\" ": # remove the starting double quote
|
||||
if s[:2] == '" ': # remove the starting double quote
|
||||
s = s[2:]
|
||||
|
||||
return s
|
||||
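Note: an illustrative input for train_text_normalization as defined above; only the quote characters change, and the branch that strips a leading double quote plus space does not trigger here.

    train_text_normalization("“Hello,” she said. ‘Fine.’")
    # -> '"Hello," she said. \'Fine.\''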
@@ -17,8 +34,6 @@ def ref_text_normalization(ref_text: str) -> str:
     p = r"[FN#[0-9]*]"
     pattern = re.compile(p)
 
-    # ref_text = ref_text.replace("”", "\"")
-    # ref_text = ref_text.replace("’", "'")
     res = pattern.findall(ref_text)
     ref_text = re.sub(p, "", ref_text)
 
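Note: a hypothetical footnote marker being stripped by the pattern above (the character class plus a literal closing bracket matches tags such as [FN#3]); this assumes the rest of the function leaves plain text untouched.

    ref_text_normalization("He left.[FN#3] She stayed.")
    # -> "He left. She stayed."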
@@ -27,32 +42,34 @@ def ref_text_normalization(ref_text: str) -> str:
     return ref_text
 
+
-def remove_non_alphabetic(text: str, strict: bool=True) -> str:
+def remove_non_alphabetic(text: str, strict: bool = True) -> str:
     # Recommend to set strict to False
     if not strict:
         # Note, this also keeps space, single quote (') and hyphen (-)
         text = text.replace("-", " ")
         text = text.replace("—", " ")
-        return re.sub("[^a-zA-Z0-9\s']+", "", text)
+        return re.sub(r"[^a-zA-Z0-9\s']+", "", text)
     else:
         # only keeps space
-        return re.sub("[^a-zA-Z\s]+", "", text)
+        return re.sub(r"[^a-zA-Z\s]+", "", text)
 
 
-def recog_text_normalization(recog_text: str) -> str:
-    pass
-
 def upper_only_alpha(text: str) -> str:
     return remove_non_alphabetic(text.upper(), strict=False)
 
+
 def lower_only_alpha(text: str) -> str:
     return remove_non_alphabetic(text.lower(), strict=False)
 
+
 def lower_all_char(text: str) -> str:
     return text.lower()
 
+
 def upper_all_char(text: str) -> str:
     return text.upper()
 
+
 if __name__ == "__main__":
     ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
     print(ref_text)
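Note: what the helpers above produce on a short input (illustrative, easy to verify by running this file).

    s = "Mixed-case English transcription, with punctuation."
    upper_only_alpha(s)  # 'MIXED CASE ENGLISH TRANSCRIPTION WITH PUNCTUATION'
    lower_only_alpha(s)  # 'mixed case english transcription with punctuation'
    lower_all_char(s)    # 'mixed-case english transcription, with punctuation.'
    upper_all_char(s)    # 'MIXED-CASE ENGLISH TRANSCRIPTION, WITH PUNCTUATION.'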
@@ -1,8 +1,6 @@
 #!/usr/bin/env python3
-# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
-#                                                       Wei Kang,
-#                                                       Mingshuang Luo,)
-#                                                       Zengwei Yao)
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Xiaoyu Yang,
+#
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -1,18 +1,47 @@
 # Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+python ./zipformer_prompt_asr/transcribe_bert.py \
+    --epoch 50 \
+    --avg 10 \
+    --exp-dir ./zipformer_prompt_asr/exp \
+    --manifest-dir data/long_audios/long_audio.jsonl.gz \
+    --pre-text-transform mixed-punc \
+    --style-text-transform mixed-punc \
+    --num-history 5 \
+    --use-pre-text True \
+    --use-gt-pre-text False
+
+"""
 
 import argparse
 import logging
 import math
 import warnings
 from pathlib import Path
 from typing import List
-from tqdm import tqdm
 
 import k2
 import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
-from lhotse import load_manifest, Fbank
 
 from beam_search import (
     beam_search,
     fast_beam_search_one_best,
@@ -20,21 +49,24 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from decode_bert import _apply_style_transform
+from lhotse import Fbank, load_manifest
 from text_normalization import (
-    ref_text_normalization,
-    remove_non_alphabetic,
-    upper_only_alpha,
-    upper_all_char,
     lower_all_char,
     lower_only_alpha,
+    ref_text_normalization,
+    remove_non_alphabetic,
+    train_text_normalization,
+    upper_all_char,
+    upper_only_alpha,
 )
-from train_bert_encoder_with_style import (
+from tqdm import tqdm
+from train_bert_encoder import (
+    _encode_texts_as_bytes_with_tokenizer,
     add_model_arguments,
     get_params,
     get_tokenizer,
     get_transducer_model,
-    _encode_texts_as_bytes_with_tokenizer,
 )
 
 from icefall.checkpoint import (
@@ -51,6 +83,7 @@ from icefall.utils import (
     write_error_stats,
 )
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -91,7 +124,6 @@ def get_parser():
         help="The experiment dir",
     )
 
-
     parser.add_argument(
         "--bpe-model",
         type=str,
@@ -120,53 +152,47 @@ def get_parser():
     parser.add_argument(
         "--manifest-dir",
         type=str,
-        default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
+        default="data/long_audios/long_audio.jsonl.gz",
         help="""This is the manifest for long audio transcription.
-        It is intended to be sorted, i.e. first sort by recording ID and then sort by
-        start timestamp"""
-    )
-
-    parser.add_argument(
-        "--segment-length",
-        type=float,
-        default=30.0,
+        The cuts are intended to be sorted, i.e. first sort by recording ID and
+        then sort by start timestamp""",
     )
 
     parser.add_argument(
         "--use-pre-text",
         type=str2bool,
         default=False,
-        help="Whether use pre-text when decoding the current chunk"
+        help="Whether to use pre-text when decoding the current chunk",
     )
 
     parser.add_argument(
         "--use-style-prompt",
         type=str2bool,
         default=True,
-        help="Use style prompt when evaluation"
+        help="Use style prompt during evaluation",
     )
 
     parser.add_argument(
         "--pre-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of content prompt, i.e pre_text"
+        help="The style of the content prompt, i.e. pre_text",
     )
 
     parser.add_argument(
         "--style-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of style prompt, i.e style_text"
+        help="The style of the style prompt, i.e. style_text",
     )
 
     parser.add_argument(
         "--num-history",
         type=int,
         default=2,
-        help="How many previous chunks to look if using pre-text for decoding"
+        help="How many previous chunks to look at if using pre-text for decoding",
     )
 
     parser.add_argument(
@@ -186,28 +212,6 @@ def get_parser():
 
     return parser
 
-def _apply_style_transform(text: List[str], transform: str) -> List[str]:
-    """Apply transform to a list of text. By default, the text are in
-    ground truth format, i.e mixed-punc.
-
-    Args:
-        text (List[str]): Input text string
-        transform (str): Transform to be applied
-
-    Returns:
-        List[str]: _description_
-    """
-    if transform == "mixed-punc":
-        return text
-    elif transform == "upper-no-punc":
-        return [upper_only_alpha(s) for s in text]
-    elif transform == "lower-no-punc":
-        return [lower_only_alpha(s) for s in text]
-    elif transform == "lower-punc":
-        return [lower_all_char(s) for s in text]
-    else:
-        raise NotImplementedError(f"Unseen transform: {transform}")
-
 
 @torch.no_grad()
 def main():
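Note: illustrative outputs of _apply_style_transform (now imported from decode_bert rather than defined locally) on a hypothetical hypothesis list.

    texts = ["Hello, World! It's here."]
    _apply_style_transform(texts, "mixed-punc")     # unchanged
    _apply_style_transform(texts, "upper-no-punc")  # ["HELLO WORLD IT'S HERE"]
    _apply_style_transform(texts, "lower-no-punc")  # ["hello world it's here"]
    _apply_style_transform(texts, "lower-punc")     # ["hello, world! it's here."]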
@@ -236,19 +240,20 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
     if "beam_search" in params.method:
-        params.suffix += (
-            f"-{params.method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.method}-beam-size-{params.beam_size}"
 
     if params.use_pre_text:
         if params.use_gt_pre_text:
             params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
         else:
-            params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
+            params.suffix += (
+                f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
+            )
 
-    book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
-    setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
+    book_name = params.manifest_dir.split("/")[-1].replace(".jsonl.gz", "")
+    setup_logger(
+        f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info"
+    )
     logging.info("Decoding started")
 
     device = torch.device("cpu")
@@ -265,13 +270,12 @@ def main():
     logging.info(f"Number of model parameters: {num_param}")
 
     if params.iter > 0:
-        filenames = find_checkpoints(
-            params.exp_dir, iteration=-params.iter
-        )[: params.avg + 1]
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg + 1
+        ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg + 1:
             raise ValueError(
@@ -334,29 +338,35 @@ def main():
         feat = cut.compute_features(extractor=Fbank())
         feat_lens = feat.shape[0]
 
         cur_recording = cut.recording.id
 
         if cur_recording != last_recording:
             last_recording = cur_recording
-            history = [] # clean history
+            history = []  # clean up the history
             last_end = -1
-            logging.info(f"Moving on to the next recording")
+            logging.info("Moving on to the next recording")
         else:
-            if cut.start < last_end - 0.2: # overlap exits
-                logging.warning(f"An overlap exists between current cut and last cut")
+            if cut.start < last_end - 0.2:  # overlap with the previous cuts
+                logging.warning("An overlap exists between current cut and last cut")
                 logging.warning("Skipping this cut!")
                 continue
             if cut.start > last_end + 10:
-                logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
+                logging.warning(
+                    f"Large time gap between the current and previous utterance: {cut.start - last_end}."
+                )
 
         # prepare input
         x = torch.tensor(feat, device=device).unsqueeze(0)
-        x_lens = torch.tensor([feat_lens,], device=device)
+        x_lens = torch.tensor(
+            [
+                feat_lens,
+            ],
+            device=device,
+        )
 
         if params.use_pre_text:
             if params.num_history > 0:
-                pre_texts = history[-params.num_history:]
+                pre_texts = history[-params.num_history :]
             else:
                 pre_texts = []
             num_pre_texts.append(len(pre_texts))
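Note: a toy illustration, with hypothetical values, of the sliding pre-text window kept above; only the last num_history decoded (or ground-truth) chunks are used as context.

    history = ["chunk one", "chunk two", "chunk three"]
    num_history = 2
    pre_texts = history[-num_history:]  # ["chunk two", "chunk three"]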
@@ -366,9 +376,11 @@ def main():
 
             pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
             if params.use_style_prompt:
-                style_texts = _apply_style_transform(style_texts, params.style_text_transform)
+                style_texts = _apply_style_transform(
+                    style_texts, params.style_text_transform
+                )
 
-            # encode pre_text
+            # encode prompts
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
 
@@ -380,12 +392,14 @@ def main():
                     no_limit=True,
                 )
                 if params.num_history > 5:
-                    logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
+                    logging.info(
+                        f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} "
+                    )
 
                 memory, memory_key_padding_mask = model.encode_text(
                     encoded_inputs=encoded_inputs,
                     style_lens=style_lens,
-                ) # (T,B,C)
+                )  # (T,B,C)
         else:
             memory = None
             memory_key_padding_mask = None
@@ -413,15 +427,17 @@ def main():
                 beam=params.beam_size,
             )
 
-        hyp = sp.decode(hyp_tokens)[0] # in string format
-        ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
+        hyp = sp.decode(hyp_tokens)[0]  # in string format
+        ref_text = ref_text_normalization(
+            cut.supervisions[0].texts[0]
+        )  # required to match the training
 
-        # extend the history, the history here is in original format
+        # extend the history
         if params.use_gt_pre_text:
             history.append(ref_text)
        else:
             history.append(hyp)
-        last_end = cut.end # update the last end timestamp
+        last_end = cut.end  # update the last end timestamp
 
         # append the current decoding result
         hyp = hyp.split()
@@ -431,20 +447,32 @@ def main():
         count += 1
         if count % 100 == 0:
             logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
-            logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
+            logging.info(
+                f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}"
+            )
 
     logging.info(f"A total of {count} cuts")
-    logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
+    logging.info(
+        f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}"
+    )
 
     results = sorted(results)
-    recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
+    recog_path = (
+        params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
+    )
     store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")
 
-    errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
+    errs_filename = (
+        params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
+    )
     with open(errs_filename, "w") as f:
         wer = write_error_stats(
-            f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
+            f,
+            f"long-audio-{params.method}",
+            results,
+            enable_log=True,
+            compute_CER=False,
         )
 
     logging.info("Wrote detailed error stats to {}".format(errs_filename))
@@ -457,19 +485,31 @@ def main():
         id, ref, hyp = item
         hyp = upper_only_alpha(" ".join(hyp)).split()
         ref = upper_only_alpha(" ".join(ref)).split()
-        new_res.append((id,ref,hyp))
+        new_res.append((id, ref, hyp))
 
     new_res = sorted(new_res)
-    recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+    recog_path = (
+        params.res_dir
+        / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+    )
     store_transcripts(filename=recog_path, texts=new_res)
     logging.info(f"The transcripts are stored in {recog_path}")
 
-    errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+    errs_filename = (
+        params.res_dir
+        / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+    )
     with open(errs_filename, "w") as f:
         wer = write_error_stats(
-            f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
+            f,
+            f"long-audio-{params.method}",
+            new_res,
+            enable_log=True,
+            compute_CER=False,
         )
 
     logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
-if __name__=="__main__":
+
+if __name__ == "__main__":
     main()