mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 18:12:19 +00:00)

fix tts stage decode

commit 49256fa917 (parent 5a7c72cb47)
@@ -2,8 +2,8 @@
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-
-
+export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
+export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
 set -eou pipefail

 stage=$1
@@ -121,7 +121,6 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
     $train_cmd_args
 fi

-export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
 if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
   log "stage 19: Training TTS Model"
   exp_dir=./qwen_omni/exp_tts
@@ -218,16 +217,17 @@ if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
     $train_cmd_args
 fi


 if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
   log "stage 21: TTS Decoding Test Set"
   exp_dir=./qwen_omni/exp_tts
-  torchrun --nproc_per_node=4 python3 ./qwen_omni/decode_tts.py \
+  torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/large-v2.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
     --use-flash-attn True \
     --enable-speech-output True \
-    --token2wav-path /lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice2-0.5B \
+    --token2wav-path /workspace/CosyVoice2-0.5B \
     --use-lora True
 fi
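Note (reviewer aside, not part of the diff): torchrun spawns the Python interpreter itself and treats its first positional argument as the script to run, so the stray python3 in the old command made torchrun try to execute a file named python3. Dropping it (and going from 4 to 2 processes per node) is the actual decode fix here. A minimal sketch of how a script launched this way sees its rank, assuming only the standard torchrun environment variables (it mirrors the get_rank/get_world_size helpers in utils.py further down):

    # rank_probe.py -- launch with: torchrun --nproc_per_node=2 rank_probe.py
    import os

    def get_world_size() -> int:
        # torchrun exports WORLD_SIZE to every process it spawns
        return int(os.environ.get("WORLD_SIZE", "1"))

    def get_rank() -> int:
        # torchrun exports RANK; default to 0 for a plain `python rank_probe.py`
        return int(os.environ.get("RANK", "0"))

    print(f"rank {get_rank()} of {get_world_size()}")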
@@ -63,16 +63,9 @@ from model import SPEECH_LLM, EncoderProjector
 from peft import LoraConfig, get_peft_model
 from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
 from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
+from utils import AttributeDict, setup_logger, store_transcripts, write_error_stats
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

-from icefall.env import get_env_info
-from icefall.utils import (
-    AttributeDict,
-    setup_logger,
-    store_transcripts,
-    write_error_stats,
-)
-
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")

@@ -418,11 +411,7 @@ def get_parser():


 def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "env_info": get_env_info(),
-        }
-    )
+    params = AttributeDict({})
     return params

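Note (mine, not from the commit): together with the import hunk above, dropping the env_info entry removes the last icefall.* dependency (get_env_info) from this script; AttributeDict and the logging/WER helpers now come from the recipe-local utils module instead. A minimal restatement of the AttributeDict behavior this relies on, matching the class shown later in utils.py:

    class AttributeDict(dict):
        # dict whose keys can also be read as attributes (as in utils.py)
        def __getattr__(self, key):
            if key in self:
                return self[key]
            raise AttributeError(f"No such attribute '{key}'")

    params = AttributeDict({})
    params.update({"exp_dir": "./qwen_omni/exp_tts", "seed": 42})
    print(params.exp_dir)  # -> ./qwen_omni/exp_tts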
@@ -40,36 +40,34 @@ import copy
 import logging
 import os
 import random
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union

+import soundfile as sf
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
-from datasets import load_dataset
-from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from datasets import Audio, load_dataset
+from decode import audio_decode_cosyvoice2
 from label_smoothing import LabelSmoothingLoss

 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM
 from peft import LoraConfig, get_peft_model
 from torch import Tensor
+from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
+from train import add_model_arguments, add_training_arguments, get_model, get_params
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     Qwen2Config,
     Qwen2ForCausalLM,
 )
-from torchdata.stateful_dataloader import StatefulDataLoader
-from torch.utils.data import DistributedSampler, DataLoader
-
-from train import add_model_arguments, add_training_arguments, get_params, get_model
-from decode import audio_decode_cosyvoice2
 from utils import (  # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
@@ -79,9 +77,9 @@ from utils import (  # filter_uneven_sized_batch,
     setup_logger,
     str2bool,
 )
-from cosyvoice.cli.cosyvoice import CosyVoice2
-sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")

+# sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 DEFAULT_SPEECH_TOKEN = "<speech>"
 try:
     torch.multiprocessing.set_start_method("spawn")
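Note (my reading, the commit does not say this): after this hunk the `from cosyvoice.cli.cosyvoice import CosyVoice2` import sits near the top of the file, above this `sys.path.append(...)` call, so it has to resolve through the `export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice` line added to the shell script; `sys.path.append` only affects imports executed after it runs. A toy illustration of that ordering constraint:

    import sys

    # Appending to sys.path only helps imports that run AFTER this call;
    # modules imported earlier must already be reachable, e.g. via PYTHONPATH.
    sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
    print(sys.path[-1])  # the new search-path entry, visible to later imports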
@@ -116,9 +114,11 @@ def get_parser():
     )

     add_model_arguments(parser)
+    add_training_arguments(parser)
+
     return parser


 def preprocess(
     messages,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -177,30 +177,41 @@ def preprocess(
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     return input_ids, attention_mask, target_ids


 def data_collator(batch):
-    prompt_texts, prompt_speech_16k, messages, ids = [], [], [], []
+    prompt_texts, prompt_speech_16k, messages, ids, target_texts = [], [], [], [], []
     for i, item in enumerate(batch):
         # speech_tokens.append(item["prompt_audio_cosy2_tokens"])
         message_list_item = []
         message_list_item += [
-            {"role": "user", "content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}"},
+            {
+                "role": "user",
+                "content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}",
+            },
             {"role": "assistant", "content": ""},
         ]
         messages.append(message_list_item)
+        target_texts.append(item["target_text"])
+
         ids.append(item["id"])
         prompt_texts.append(item["prompt_text"])
-        prompt_speech_16k.append(item["prompt_audio"])
-        print(item["prompt_audio"], 233333333333333333)
+        speech_org = item["prompt_audio"]
+
+        speech_org = torch.tensor(speech_org["array"], dtype=torch.float32).unsqueeze(0)
+        speech_org = speech_org.mean(dim=0, keepdim=True)
+        prompt_speech_16k.append(speech_org)
+
+        # resample to 16k
+
     return {
         "prompt_texts": prompt_texts,
+        "target_texts": target_texts,
         "prompt_speech_16k": prompt_speech_16k,
         "messages": messages,
         "ids": ids,
     }


 def run(rank, world_size, args):
     """
     Args:
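Aside (not part of the diff): once the dataset column is cast with Audio(sampling_rate=16000) (see the load_dataset hunk below), each item["prompt_audio"] is a dict holding a decoded float "array" plus its "sampling_rate", which is what the rewritten collator converts to a mono float32 tensor. A standalone sketch of that conversion with a dummy one-second clip:

    import numpy as np
    import torch

    # stand-in for item["prompt_audio"] after the Audio(sampling_rate=16000) cast
    prompt_audio = {"array": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000}

    speech = torch.tensor(prompt_audio["array"], dtype=torch.float32).unsqueeze(0)
    speech = speech.mean(dim=0, keepdim=True)  # fold to mono; shape stays (1, 16000)
    print(speech.shape)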
@@ -215,7 +226,7 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
-    params.log_dir = Path(params.exp_dir) / f"log-results-wav"
+    params.log_dir = Path(params.exp_dir) / "log-results-wav"
     params.log_dir.mkdir(parents=True, exist_ok=True)

     fix_random_seed(params.seed)
@@ -232,11 +243,9 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     model.to(device)

-    assert params.deepspeed and world_size > 1
-    logging.info("Using DeepSpeed")
-
     dataset = load_dataset("yuekai/seed_tts_cosy2", split=params.split_name)
+    dataset = dataset.cast_column("prompt_audio", Audio(sampling_rate=16000))
+
     sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
     data_loader = DataLoader(
         dataset,
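For illustration (mine): DistributedSampler with explicit num_replicas/rank shards the test set so that each torchrun process decodes a disjoint subset, and no process group needs to be initialized for that. A toy sketch:

    from torch.utils.data import DataLoader, DistributedSampler

    dataset = list(range(10))  # toy stand-in for the HF test set
    world_size, rank = 2, 0    # torchrun provides these via environment variables

    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    loader = DataLoader(dataset, batch_size=1, sampler=sampler)
    print([int(b[0]) for b in loader])  # rank 0 -> [0, 2, 4, 6, 8]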
@@ -245,7 +254,7 @@ def run(rank, world_size, args):
         shuffle=False,
         num_workers=1,
         prefetch_factor=1,
-        collate_fn=data_collator
+        collate_fn=data_collator,
     )
     token2wav_model = CosyVoice2(
         params.token2wav_path, load_jit=False, load_trt=False, fp16=False
@@ -254,6 +263,7 @@ def run(rank, world_size, args):
         messages = batch["messages"]
         prompt_texts = batch["prompt_texts"]
         prompt_speech_16k = batch["prompt_speech_16k"]
+        target_texts = batch["target_texts"]
         ids = batch["ids"]
         input_ids, attention_mask, _ = preprocess(messages, tokenizer)
         generated_ids, generated_speech_output = model.decode_with_speech_output(
@@ -262,8 +272,13 @@ def run(rank, world_size, args):
         generated_speech_output = [
             generated_speech_output
         ]  # WAR: only support batch = 1 for now
-        for cut_id, audio_tokens, prompt_text, prompt_speech in zip(ids, generated_speech_output, prompt_texts, prompt_speech_16k):
+        for cut_id, audio_tokens, prompt_text, prompt_speech, target_text in zip(
+            ids, generated_speech_output, prompt_texts, prompt_speech_16k, target_texts
+        ):
             speech_file_name = params.log_dir / f"{cut_id}.wav"
+            # save target_text to file
+            with open(params.log_dir / f"{cut_id}.txt", "w") as f:
+                f.write(f"{target_text}\n")
             audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
             if "CosyVoice2" in params.token2wav_path:
                 audio_hat = audio_decode_cosyvoice2(
@@ -276,6 +291,7 @@ def run(rank, world_size, args):

     logging.info("Done!")


 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -285,7 +301,7 @@ def main():
     rank = get_rank()

     torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)
+    # torch.set_num_interop_threads(1)
     warnings.filterwarnings("ignore", category=FutureWarning)
     run(rank=rank, world_size=world_size, args=args)

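Note (my guess at the motivation; the commit message does not explain it): torch.set_num_interop_threads may only be called before any inter-op parallel work has started, otherwise PyTorch raises a RuntimeError, so commenting it out avoids a crash when an earlier import has already exercised the inter-op pool. Sketch of the constraint:

    import torch

    torch.set_num_threads(1)           # intra-op pool; may be called repeatedly
    torch.set_num_interop_threads(1)   # inter-op pool; one-shot, must run before
                                       # any parallel work or it raises RuntimeError
    print(torch.get_num_interop_threads())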
@@ -11,15 +11,16 @@ from collections import defaultdict
 from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

+import kaldialign
 import torch
 import torch.distributed as dist
 from torch.utils.tensorboard import SummaryWriter

 Pathlike = Union[str, Path]


 def get_world_size():
     if "WORLD_SIZE" in os.environ:
         return int(os.environ["WORLD_SIZE"])
@@ -37,6 +38,7 @@ def get_rank():
     else:
         return 0

+
 def get_local_rank():
     if "LOCAL_RANK" in os.environ:
         return int(os.environ["LOCAL_RANK"])
@@ -45,6 +47,7 @@ def get_local_rank():
     else:
         return 0

+
 def str2bool(v):
     """Used in argparse.ArgumentParser.add_argument to indicate
     that a type is a bool type and user can enter
@@ -63,6 +66,7 @@ def str2bool(v):
     else:
         raise argparse.ArgumentTypeError("Boolean value expected.")

+
 class AttributeDict(dict):
     def __getattr__(self, key):
         if key in self:
@@ -87,6 +91,7 @@ class AttributeDict(dict):
             tmp[k] = v
         return json.dumps(tmp, indent=indent, sort_keys=True)

+
 def setup_logger(
     log_filename: Pathlike,
     log_level: str = "info",
@@ -139,6 +144,7 @@ def setup_logger(
     console.setFormatter(logging.Formatter(formatter))
     logging.getLogger("").addHandler(console)

+
 class MetricsTracker(collections.defaultdict):
     def __init__(self):
         # Passing the type 'int' to the base-class constructor
@@ -228,4 +234,200 @@ class MetricsTracker(collections.defaultdict):
             batch_idx: The current batch index, used as the x-axis of the plot.
         """
         for k, v in self.norm_items():
             tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
+def store_transcripts(
+    filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
+) -> None:
+    """Save predicted results and reference transcripts to a file.
+
+    Args:
+      filename:
+        File to save the results to.
+      texts:
+        An iterable of tuples. The first element is the cur_id, the second is
+        the reference transcript and the third element is the predicted result.
+        If it is a multi-talker ASR system, the ref and hyp may also be lists of
+        strings.
+    Returns:
+      Return None.
+    """
+    with open(filename, "w", encoding="utf8") as f:
+        for cut_id, ref, hyp in texts:
+            if char_level:
+                ref = list("".join(ref))
+                hyp = list("".join(hyp))
+            print(f"{cut_id}:\tref={ref}", file=f)
+            print(f"{cut_id}:\thyp={hyp}", file=f)
+
+
+def write_error_stats(
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, str]],
+    enable_log: bool = True,
+    compute_CER: bool = False,
+    sclite_mode: bool = False,
+) -> float:
+    """Write statistics based on predicted results and reference transcripts.
+
+    It will write the following to the given file:
+
+        - WER
+        - number of insertions, deletions, substitutions, corrects and total
+          reference words. For example::
+
+              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
+              reference words (2337 correct)
+
+        - The difference between the reference transcript and predicted result.
+          An instance is given below::
+
+            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
+
+          The above example shows that the reference word is `EDISON`,
+          but it is predicted to `ADDISON` (a substitution error).
+
+          Another example is::
+
+            FOR THE FIRST DAY (SIR->*) I THINK
+
+          The reference word `SIR` is missing in the predicted
+          results (a deletion error).
+      results:
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript and the third element is the predicted result.
+      enable_log:
+        If True, also print detailed WER to the console.
+        Otherwise, it is written only to the given file.
+    Returns:
+      Return None.
+    """
+    subs: Dict[Tuple[str, str], int] = defaultdict(int)
+    ins: Dict[str, int] = defaultdict(int)
+    dels: Dict[str, int] = defaultdict(int)
+
+    # `words` stores counts per word, as follows:
+    #   corr, ref_sub, hyp_sub, ins, dels
+    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+    num_corr = 0
+    ERR = "*"
+
+    if compute_CER:
+        for i, res in enumerate(results):
+            cut_id, ref, hyp = res
+            ref = list("".join(ref))
+            hyp = list("".join(hyp))
+            results[i] = (cut_id, ref, hyp)
+
+    for cut_id, ref, hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
+        for ref_word, hyp_word in ali:
+            if ref_word == ERR:
+                ins[hyp_word] += 1
+                words[hyp_word][3] += 1
+            elif hyp_word == ERR:
+                dels[ref_word] += 1
+                words[ref_word][4] += 1
+            elif hyp_word != ref_word:
+                subs[(ref_word, hyp_word)] += 1
+                words[ref_word][1] += 1
+                words[hyp_word][2] += 1
+            else:
+                words[ref_word][0] += 1
+                num_corr += 1
+    ref_len = sum([len(r) for _, r, _ in results])
+    sub_errs = sum(subs.values())
+    ins_errs = sum(ins.values())
+    del_errs = sum(dels.values())
+    tot_errs = sub_errs + ins_errs + del_errs
+    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
+    if enable_log:
+        logging.info(
+            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+            f"{del_errs} del, {sub_errs} sub ]"
+        )
+
+    print(f"%WER = {tot_err_rate}", file=f)
+    print(
+        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+        f"{sub_errs} substitutions, over {ref_len} reference "
+        f"words ({num_corr} correct)",
+        file=f,
+    )
+    print(
+        "Search below for sections starting with PER-UTT DETAILS:, "
+        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+        file=f,
+    )
+
+    print("", file=f)
+    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
+    for cut_id, ref, hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR)
+        combine_successive_errors = True
+        if combine_successive_errors:
+            ali = [[[x], [y]] for x, y in ali]
+            for i in range(len(ali) - 1):
+                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+                    ali[i] = [[], []]
+            ali = [
+                [
+                    list(filter(lambda a: a != ERR, x)),
+                    list(filter(lambda a: a != ERR, y)),
+                ]
+                for x, y in ali
+            ]
+            ali = list(filter(lambda x: x != [[], []], ali))
+            ali = [
+                [
+                    ERR if x == [] else " ".join(x),
+                    ERR if y == [] else " ".join(y),
+                ]
+                for x, y in ali
+            ]
+
+        print(
+            f"{cut_id}:\t"
+            + " ".join(
+                (
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    for ref_word, hyp_word in ali
+                )
+            ),
+            file=f,
+        )
+
+    print("", file=f)
+    print("SUBSTITUTIONS: count ref -> hyp", file=f)
+
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+        print(f"{count} {ref} -> {hyp}", file=f)
+
+    print("", file=f)
+    print("DELETIONS: count ref", file=f)
+    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+        print(f"{count} {ref}", file=f)
+
+    print("", file=f)
+    print("INSERTIONS: count hyp", file=f)
+    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+        print(f"{count} {hyp}", file=f)
+
+    print("", file=f)
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
+    for _, word, counts in sorted(
+        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+    ):
+        (corr, ref_sub, hyp_sub, ins, dels) = counts
+        tot_errs = ref_sub + hyp_sub + ins + dels
+        ref_count = corr + ref_sub + dels
+        hyp_count = corr + hyp_sub + ins
+
+        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
+    return float(tot_err_rate)
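Usage sketch (mine, not part of the commit): with store_transcripts and write_error_stats now local to the recipe, a decode script can dump per-utterance transcripts and WER statistics without importing icefall. Assuming kaldialign is installed and the two functions above are in scope:

    results = [
        ("utt1", ["THE", "CAT"], ["THE", "CAT"]),
        ("utt2", ["A", "DOG", "RAN"], ["A", "LOG"]),
    ]
    store_transcripts("recogs-test.txt", results)
    with open("errs-test.txt", "w") as f:
        wer = write_error_stats(f, "test", results, enable_log=False)
    print(wer)  # percent WER as a float: here 40.0 (1 sub + 1 del over 5 ref words)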