mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00

updates

This commit is contained in:
  parent 93461fb77e
  commit cda6e06a85
@@ -450,6 +450,7 @@ def decode_one_batch(
     else:
         pre_texts = ["" for _ in range(batch_size)]

+    # get the librispeech biasing data
     if params.use_ls_context_list and params.use_ls_test_set:
         if params.biasing_level == "utterance":
             pre_texts = [biasing_dict[id] for id in cut_ids]
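For context on the line added above: biasing_dict maps a LibriSpeech cut id to that utterance's biasing words, which become its content prompt (pre_text). A minimal standalone sketch, with invented ids and words (the real dict is built elsewhere in the recipe):

biasing_dict = {"cut-001": "NEBUCHADNEZZAR BABYLON", "cut-002": ""}
cut_ids = ["cut-001", "cut-002"]
pre_texts = [biasing_dict[id] for id in cut_ids]
print(pre_texts)  # ['NEBUCHADNEZZAR BABYLON', '']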
@@ -476,7 +477,6 @@ def decode_one_batch(

     # Get the text embedding
     if params.use_pre_text or params.use_style_prompt:
-
         # apply style transform to the pre_text and style_text
         pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
         if not params.use_ls_context_list:
@@ -15,17 +15,18 @@
 # limitations under the License.


+import random
+import warnings
+from typing import Optional, Tuple
+
 import k2
 import torch
 import torch.nn as nn
-import random
-import warnings
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear, penalize_abs_values_gt
+from torch import Tensor

 from icefall.utils import add_sos, make_pad_mask
-from scaling import penalize_abs_values_gt, ScaledLinear
-from torch import Tensor
-from typing import Optional, Tuple


 class Transducer(nn.Module):
@@ -185,11 +186,6 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
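The deleted comments referenced penalize_abs_values_gt from scaling.py, which, when enabled, adds a small gradient penalty to activations whose magnitude exceeds a limit. The surviving context also shows the usual icefall pattern of computing the RNN-T loss in full precision; a minimal sketch of that pattern:

import torch

x = torch.randn(4, 8)
# Even inside a mixed-precision training step, disabling autocast and
# upcasting with .float() keeps the loss arithmetic in fp32.
with torch.cuda.amp.autocast(enabled=False):
    loss = (x.float() ** 2).mean()
print(loss.dtype)  # torch.float32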
@@ -257,11 +253,10 @@ class Transducer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

         encoder_out, encoder_out_lens = self.encoder(
             x=x,
             x_lens=x_lens,
             src_key_padding_mask=src_key_padding_mask,
         )
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

-        return encoder_out, encoder_out_lens
-
+        return encoder_out, encoder_out_lens
@@ -15,17 +15,18 @@
 # limitations under the License.


+import random
+import warnings
+from typing import Dict, Optional, Tuple
+
 import k2
 import torch
 import torch.nn as nn
-import random
-import warnings
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear, penalize_abs_values_gt
+from torch import Tensor

 from icefall.utils import add_sos, make_pad_mask
-from scaling import penalize_abs_values_gt, ScaledLinear
-from torch import Tensor
-from typing import Optional, Tuple, Dict


 class PromptedTransducer(nn.Module):
@@ -97,13 +98,21 @@ class PromptedTransducer(nn.Module):
             vocab_size,
             initial_scale=0.25,
         )

         self.use_BERT = use_BERT  # if the text encoder is a pre-trained BERT
         self.context_fuser = context_fuser

-        assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
-        self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim
+        assert text_encoder_type in (
+            "BERT",
+            "DistilBERT",
+            "BERT-UNCASED",
+        ), f"Unseen text_encoder type {text_encoder_type}"
+        self.text_encoder_dim = (
+            self.text_encoder.config.hidden_size
+            if text_encoder_type in ("BERT", "BERT-UNCASED")
+            else self.text_encoder.config.dim
+        )

         if text_encoder_adapter:
             self.text_encoder_adapter = nn.Sequential(
                 nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False),
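The reformatted expression above picks the text-encoder width from the Hugging Face config: BERT-style configs call it hidden_size, while DistilBERT's own field is dim. A standalone sketch using transformers directly (the checkpoint names are illustrative, not taken from this diff):

from transformers import AutoConfig

for name in ("bert-base-cased", "distilbert-base-uncased"):
    cfg = AutoConfig.from_pretrained(name)
    # BERT exposes hidden_size; DistilBERT's native attribute is dim.
    text_encoder_dim = getattr(cfg, "hidden_size", None) or cfg.dim
    print(name, text_encoder_dim)  # 768 for both of these checkpoints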
@@ -111,8 +120,10 @@ class PromptedTransducer(nn.Module):
             )
         else:
             self.text_encoder_adapter = None

-        self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
+        self.style_prompt_embedding = nn.Parameter(
+            torch.full((self.text_encoder_dim,), 0.5)
+        )

     def forward(
         self,
@@ -181,11 +192,10 @@ class PromptedTransducer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

         # freeze the BERT text encoder

         if use_pre_text:
             memory, memory_key_padding_mask = self.encode_text(
-                encoded_inputs,
-                style_lens=style_lens
+                encoded_inputs, style_lens=style_lens
             )
         else:
             memory = None
@@ -231,11 +241,6 @@ class PromptedTransducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
@@ -270,12 +275,12 @@ class PromptedTransducer(nn.Module):
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
         if self.context_fuser is not None and memory is not None:
-            memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
+            memory = memory.permute(1, 0, 2)  # (T,N,C) -> (N,T,C)
             context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
             context = self.joiner.context_proj(context)
         else:
             context = None

         logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)

         with torch.cuda.amp.autocast(enabled=False):
@@ -304,16 +309,17 @@ class PromptedTransducer(nn.Module):
         (memory_len, batch_size, embed_dim) = memory.shape

         indicator = (
-            torch.arange(memory_len, device=memory.device).unsqueeze(-1)
-            < style_lens
+            torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
         )
         indicator = indicator.to(memory.dtype)

         extra_term = torch.zeros_like(memory)
-        extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
+        extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(
+            memory_len, batch_size, self.text_encoder_dim
+        )

         return memory + extra_term

     def encode_text(
         self,
         encoded_inputs: Dict,
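To see what the reformatted _add_style_indicator computes: the first style_lens[i] positions of each utterance's prompt get the learnable style_prompt_embedding added, so downstream attention can tell style tokens from content tokens. A toy-sized sketch of the same broadcast:

import torch

memory_len, batch_size, dim = 6, 2, 4
memory = torch.zeros(memory_len, batch_size, dim)  # (T, N, C)
style_lens = torch.tensor([2, 4])
style_prompt_embedding = torch.full((dim,), 0.5)

indicator = (
    torch.arange(memory_len).unsqueeze(-1) < style_lens
).to(memory.dtype)  # (T, N): 1 inside the style prefix, 0 elsewhere

extra_term = indicator.unsqueeze(-1) * style_prompt_embedding.expand(
    memory_len, batch_size, dim
)
print((memory + extra_term)[:, 0, 0])  # tensor([0.5000, 0.5000, 0., 0., 0., 0.])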
@@ -326,25 +332,25 @@ class PromptedTransducer(nn.Module):

         Returns:
             Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
             text_encoder and the attention mask
         """
         text_lens = encoded_inputs.pop("length")  # need to use pop to remove this item

         # Freeze the pre-trained text encoder
         with torch.no_grad():
             memory = self.text_encoder(**encoded_inputs)["last_hidden_state"]  # (B,T,C)
-            memory = memory.permute(1,0,2)
+            memory = memory.permute(1, 0, 2)

         # Text encoder adapter
         if self.text_encoder_adapter is not None:
             memory = self.text_encoder_adapter(memory)

         memory = self._add_style_indicator(memory, style_lens)

         memory_key_padding_mask = make_pad_mask(text_lens)

         return memory, memory_key_padding_mask

     def encode_audio(
         self,
         feature: Tensor,
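encode_text above runs the pre-trained text encoder under torch.no_grad(), so prompting never updates the BERT weights. A rough standalone equivalent with Hugging Face transformers (checkpoint name illustrative, not taken from this diff):

import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-cased")
text_encoder = AutoModel.from_pretrained("bert-base-cased").eval()

encoded_inputs = tok(["Mixed-case English transcription."], return_tensors="pt")
with torch.no_grad():  # keep the pre-trained encoder frozen
    memory = text_encoder(**encoded_inputs)["last_hidden_state"]  # (B, T, C)
memory = memory.permute(1, 0, 2)  # (T, B, C), the layout the ASR encoder expects
print(memory.shape)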
@@ -368,14 +374,14 @@ class PromptedTransducer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

         encoder_out, encoder_out_lens = self.encoder(
             x=x,
             x_lens=x_lens,
             src_key_padding_mask=src_key_padding_mask,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
         )
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

         return encoder_out, encoder_out_lens
@@ -1,12 +1,29 @@
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re


 def train_text_normalization(s: str) -> str:
+    # replace full-width with half-width
     s = s.replace("“", '"')
     s = s.replace("”", '"')
     s = s.replace("‘", "'")
     s = s.replace("’", "'")
-    if s[:2] == "\" ":  # remove the starting double quote
+    if s[:2] == '" ':  # remove the starting double quote
         s = s[2:]

     return s
@@ -17,42 +34,42 @@ def ref_text_normalization(ref_text: str) -> str:
     p = r"[FN#[0-9]*]"
     pattern = re.compile(p)

-    # ref_text = ref_text.replace("”", "\"")
-    # ref_text = ref_text.replace("’", "'")
     res = pattern.findall(ref_text)
     ref_text = re.sub(p, "", ref_text)

     ref_text = train_text_normalization(ref_text)

     return ref_text


-def remove_non_alphabetic(text: str, strict: bool=True) -> str:
+def remove_non_alphabetic(text: str, strict: bool = True) -> str:
+    # Recommend to set strict to False
     if not strict:
         # Note, this also keeps space, single quote(') and hypen (-)
         text = text.replace("-", " ")
         text = text.replace("—", " ")
-        return re.sub("[^a-zA-Z0-9\s']+", "", text)
+        return re.sub(r"[^a-zA-Z0-9\s']+", "", text)
     else:
         # only keeps space
-        return re.sub("[^a-zA-Z\s]+", "", text)
+        return re.sub(r"[^a-zA-Z\s]+", "", text)


-def recog_text_normalization(recog_text: str) -> str:
-    pass
-
-
 def upper_only_alpha(text: str) -> str:
     return remove_non_alphabetic(text.upper(), strict=False)


 def lower_only_alpha(text: str) -> str:
     return remove_non_alphabetic(text.lower(), strict=False)


 def lower_all_char(text: str) -> str:
     return text.lower()


 def upper_all_char(text: str) -> str:
     return text.upper()


 if __name__ == "__main__":
     ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
     print(ref_text)
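The r prefixes added above matter: "\s" is an invalid escape in a plain Python string (a DeprecationWarning today, slated to become an error), while r"\s" passes the character class to the regex engine verbatim. A quick sketch of the helper's behaviour, restated standalone:

import re

def remove_non_alphabetic(text: str, strict: bool = True) -> str:
    # condensed restatement of the function above
    if not strict:
        text = text.replace("-", " ").replace("—", " ")
        return re.sub(r"[^a-zA-Z0-9\s']+", "", text)
    return re.sub(r"[^a-zA-Z\s]+", "", text)

print(remove_non_alphabetic("It's a state-of-the-art system!", strict=False))
# It's a state of the art system
print(remove_non_alphabetic("It's a state-of-the-art system!", strict=True))
# Its a stateoftheart system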
@@ -1,8 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                            Wei Kang,
-#                                            Mingshuang Luo,)
-#                                            Zengwei Yao)
+# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang,
+#
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -1,18 +1,47 @@
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+python ./zipformer_prompt_asr/transcribe_bert.py \
+    --epoch 50 \
+    --avg 10 \
+    --exp-dir ./zipformer_prompt_asr/exp \
+    --manifest-dir data/long_audios/long_audio.jsonl.gz \
+    --pre-text-transform mixed-punc \
+    --style-text-transform mixed-punc \
+    --num-history 5 \
+    --use-pre-text True \
+    --use-gt-pre-text False
+
+
+"""
+
 import argparse
 import logging
 import math
 import warnings
 from pathlib import Path
 from typing import List
-from tqdm import tqdm

 import k2
 import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
-from lhotse import load_manifest, Fbank

 from beam_search import (
     beam_search,
     fast_beam_search_one_best,
@@ -20,21 +49,24 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from decode_bert import _apply_style_transform
+from lhotse import Fbank, load_manifest
 from text_normalization import (
-    ref_text_normalization,
-    remove_non_alphabetic,
-    upper_only_alpha,
-    upper_all_char,
     lower_all_char,
     lower_only_alpha,
+    ref_text_normalization,
+    remove_non_alphabetic,
     train_text_normalization,
+    upper_all_char,
+    upper_only_alpha,
 )
-from train_bert_encoder_with_style import (
+from tqdm import tqdm
+from train_bert_encoder import (
+    _encode_texts_as_bytes_with_tokenizer,
     add_model_arguments,
     get_params,
     get_tokenizer,
     get_transducer_model,
-    _encode_texts_as_bytes_with_tokenizer,
 )

 from icefall.checkpoint import (
@@ -51,11 +83,12 @@ from icefall.utils import (
     write_error_stats,
 )


 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

     parser.add_argument(
         "--epoch",
         type=int,
|
|||||||
You can specify --avg to use more checkpoints for model averaging.
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -83,22 +116,21 @@ def get_parser():
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch' and '--iter'",
     )

     parser.add_argument(
         "--exp-dir",
         type=str,
         default="pruned_transducer_stateless7/exp",
         help="The experiment dir",
     )

-
     parser.add_argument(
         "--bpe-model",
         type=str,
         default="data/lang_bpe_500/bpe.model",
         help="""Path to bpe.model.""",
     )

     parser.add_argument(
         "--method",
         type=str,
@@ -110,104 +142,76 @@ def get_parser():
           - fast_beam_search
     """,
     )

     parser.add_argument(
         "--beam-size",
         type=int,
         default=4,
     )

     parser.add_argument(
         "--manifest-dir",
         type=str,
-        default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
+        default="data/long_audios/long_audio.jsonl.gz",
         help="""This is the manfiest for long audio transcription.
-        It is intended to be sored, i.e first sort by recording ID and then sort by
-        start timestamp"""
+        The cust are intended to be sorted, i.e first sort by recording ID and
+        then sort by start timestamp""",
     )

-    parser.add_argument(
-        "--segment-length",
-        type=float,
-        default=30.0,
-    )
-
     parser.add_argument(
         "--use-pre-text",
         type=str2bool,
         default=False,
-        help="Whether use pre-text when decoding the current chunk"
+        help="Whether use pre-text when decoding the current chunk",
     )

     parser.add_argument(
         "--use-style-prompt",
         type=str2bool,
         default=True,
-        help="Use style prompt when evaluation"
+        help="Use style prompt when evaluation",
     )

     parser.add_argument(
         "--pre-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of content prompt, i.e pre_text"
+        help="The style of content prompt, i.e pre_text",
     )

     parser.add_argument(
         "--style-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of style prompt, i.e style_text"
+        help="The style of style prompt, i.e style_text",
     )

     parser.add_argument(
         "--num-history",
         type=int,
         default=2,
-        help="How many previous chunks to look if using pre-text for decoding"
+        help="How many previous chunks to look if using pre-text for decoding",
     )

     parser.add_argument(
         "--use-gt-pre-text",
         type=str2bool,
         default=False,
         help="Whether use gt pre text when using content prompt",
     )

     parser.add_argument(
         "--post-normalization",
         type=str2bool,
         default=True,
     )

     add_model_arguments(parser)

     return parser


-def _apply_style_transform(text: List[str], transform: str) -> List[str]:
-    """Apply transform to a list of text. By default, the text are in
-    ground truth format, i.e mixed-punc.
-
-    Args:
-        text (List[str]): Input text string
-        transform (str): Transform to be applied
-
-    Returns:
-        List[str]: _description_
-    """
-    if transform == "mixed-punc":
-        return text
-    elif transform == "upper-no-punc":
-        return [upper_only_alpha(s) for s in text]
-    elif transform == "lower-no-punc":
-        return [lower_only_alpha(s) for s in text]
-    elif transform == "lower-punc":
-        return [lower_all_char(s) for s in text]
-    else:
-        raise NotImplementedError(f"Unseen transform: {transform}")
-
-
 @torch.no_grad()
 def main():
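_apply_style_transform, removed above and now imported from decode_bert, maps ground-truth-style text into whichever prompt style was selected on the command line. A hedged usage sketch, inlining a condensed version of the helpers it relies on (only two of the four branches shown):

import re

def upper_only_alpha(s: str) -> str:
    # as in text_normalization.py
    s = s.upper().replace("-", " ").replace("—", " ")
    return re.sub(r"[^a-zA-Z0-9\s']+", "", s)

def _apply_style_transform(text, transform):
    # condensed restatement of the removed function
    if transform == "mixed-punc":
        return text
    elif transform == "upper-no-punc":
        return [upper_only_alpha(s) for s in text]
    elif transform == "lower-punc":
        return [s.lower() for s in text]
    raise NotImplementedError(f"Unseen transform: {transform}")

print(_apply_style_transform(["Hello, World!"], "upper-no-punc"))  # ['HELLO WORLD']
print(_apply_style_transform(["Hello, World!"], "lower-punc"))     # ['hello, world!']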
@@ -216,7 +220,7 @@ def main():
     args.exp_dir = Path(args.exp_dir)

     params = get_params()

     params.update(vars(args))

     sp = spm.SentencePieceProcessor()
@@ -226,7 +230,7 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

     params.res_dir = params.exp_dir / "long_audio_transcribe"
     params.res_dir.mkdir(exist_ok=True)

@@ -234,21 +238,22 @@ def main():
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     if "beam_search" in params.method:
-        params.suffix += (
-            f"-{params.method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.method}-beam-size-{params.beam_size}"

     if params.use_pre_text:
         if params.use_gt_pre_text:
             params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
         else:
-            params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
+            params.suffix += (
+                f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
+            )

-    book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
-    setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
+    book_name = params.manifest_dir.split("/")[-1].replace(".jsonl.gz", "")
+    setup_logger(
+        f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info"
+    )
     logging.info("Decoding started")

     device = torch.device("cpu")
@@ -265,13 +270,12 @@ def main():
     logging.info(f"Number of model parameters: {num_param}")

     if params.iter > 0:
-        filenames = find_checkpoints(
-            params.exp_dir, iteration=-params.iter
-        )[: params.avg + 1]
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg + 1
+        ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg + 1:
             raise ValueError(
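On the reformatted find_checkpoints call: the slice keeps params.avg + 1 checkpoint files for averaging, assuming (as the surrounding code implies) that find_checkpoints returns them newest first. A toy illustration of the slicing, with invented file names:

filenames = ["ckpt-30000.pt", "ckpt-29500.pt", "ckpt-29000.pt", "ckpt-28500.pt"]
avg = 2
print(filenames[: avg + 1])  # ['ckpt-30000.pt', 'ckpt-29500.pt', 'ckpt-29000.pt']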
@@ -310,22 +314,22 @@ def main():
                 device=device,
             )
         )

     model.to(device)
     model.eval()
     model.device = device

     # load manifest
     manifest = load_manifest(params.manifest_dir)

     results = []
     count = 0

     last_recording = ""
     last_end = -1
     history = []
     num_pre_texts = []

     for cut in manifest:
         if cut.has_features:
             feat = cut.load_features()
@@ -333,45 +337,53 @@ def main():
         else:
             feat = cut.compute_features(extractor=Fbank())
             feat_lens = feat.shape[0]

         cur_recording = cut.recording.id

         if cur_recording != last_recording:
             last_recording = cur_recording
-            history = [] # clean history
+            history = []  # clean up the history
             last_end = -1
-            logging.info(f"Moving on to the next recording")
+            logging.info("Moving on to the next recording")
         else:
-            if cut.start < last_end - 0.2: # overlap exits
-                logging.warning(f"An overlap exists between current cut and last cut")
+            if cut.start < last_end - 0.2:  # overlap with the previous cuts
+                logging.warning("An overlap exists between current cut and last cut")
                 logging.warning("Skipping this cut!")
                 continue
             if cut.start > last_end + 10:
-                logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
+                logging.warning(
+                    f"Large time gap between the current and previous utterance: {cut.start - last_end}."
+                )

         # prepare input
         x = torch.tensor(feat, device=device).unsqueeze(0)
-        x_lens = torch.tensor([feat_lens,], device=device)
+        x_lens = torch.tensor(
+            [
+                feat_lens,
+            ],
+            device=device,
+        )

         if params.use_pre_text:
             if params.num_history > 0:
-                pre_texts = history[-params.num_history:]
+                pre_texts = history[-params.num_history :]
             else:
                 pre_texts = []
             num_pre_texts.append(len(pre_texts))
             pre_texts = [train_text_normalization(" ".join(pre_texts))]
             fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
             style_texts = [fixed_sentence]

             pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
             if params.use_style_prompt:
-                style_texts = _apply_style_transform(style_texts, params.style_text_transform)
+                style_texts = _apply_style_transform(
+                    style_texts, params.style_text_transform
+                )

-            # encode pre_text
+            # encode prompts
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")

                 encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
                     pre_texts=pre_texts,
                     style_texts=style_texts,
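The context logic above keeps a rolling window: the decoded (or ground-truth) text of the last --num-history chunks is joined into a single pre_text prompt for the current chunk. In miniature:

history = ["chunk one text", "chunk two text", "chunk three text"]
num_history = 2
pre_texts = history[-num_history:] if num_history > 0 else []
pre_text = " ".join(pre_texts)
print(pre_text)  # chunk two text chunk three text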
@@ -380,16 +392,18 @@ def main():
                     no_limit=True,
                 )
             if params.num_history > 5:
-                logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
+                logging.info(
+                    f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} "
+                )

             memory, memory_key_padding_mask = model.encode_text(
                 encoded_inputs=encoded_inputs,
                 style_lens=style_lens,
             )  # (T,B,C)
         else:
             memory = None
             memory_key_padding_mask = None

         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             encoder_out, encoder_out_lens = model.encode_audio(
@@ -398,7 +412,7 @@ def main():
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )

         if params.method == "greedy_search":
             hyp_tokens = greedy_search_batch(
                 model=model,
@@ -412,17 +426,19 @@ def main():
                 encoder_out_lens=encoder_out_lens,
                 beam=params.beam_size,
             )

         hyp = sp.decode(hyp_tokens)[0]  # in string format
-        ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
+        ref_text = ref_text_normalization(
+            cut.supervisions[0].texts[0]
+        )  # required to match the training

-        # extend the history, the history here is in original format
+        # extend the history
         if params.use_gt_pre_text:
             history.append(ref_text)
         else:
             history.append(hyp)
         last_end = cut.end  # update the last end timestamp

         # append the current decoding result
         hyp = hyp.split()
         ref = ref_text.split()
@@ -431,45 +447,69 @@ def main():
         count += 1
         if count % 100 == 0:
             logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
-            logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
+            logging.info(
+                f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}"
+            )

     logging.info(f"A total of {count} cuts")
-    logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
+    logging.info(
+        f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}"
+    )

     results = sorted(results)
-    recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
+    recog_path = (
+        params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
+    )
     store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")

-    errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
+    errs_filename = (
+        params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
+    )
     with open(errs_filename, "w") as f:
         wer = write_error_stats(
-            f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
+            f,
+            f"long-audio-{params.method}",
+            results,
+            enable_log=True,
+            compute_CER=False,
         )

     logging.info("Wrote detailed error stats to {}".format(errs_filename))

     if params.post_normalization:
         params.suffix += "-post-normalization"

         new_res = []
         for item in results:
             id, ref, hyp = item
             hyp = upper_only_alpha(" ".join(hyp)).split()
             ref = upper_only_alpha(" ".join(ref)).split()
-            new_res.append((id,ref,hyp))
+            new_res.append((id, ref, hyp))

         new_res = sorted(new_res)
-        recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        recog_path = (
+            params.res_dir
+            / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        )
         store_transcripts(filename=recog_path, texts=new_res)
         logging.info(f"The transcripts are stored in {recog_path}")

-        errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        errs_filename = (
+            params.res_dir
+            / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
+                f,
+                f"long-audio-{params.method}",
+                new_res,
+                enable_log=True,
+                compute_CER=False,
             )

         logging.info("Wrote detailed error stats to {}".format(errs_filename))

-if __name__=="__main__":
-    main()
+
+if __name__ == "__main__":
+    main()
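The post-normalization pass above rescores after mapping both hypothesis and reference through upper_only_alpha, so the second WER ignores casing and punctuation differences. A compact sketch of that equivalence, with the helper restated inline:

import re

def upper_only_alpha(text: str) -> str:
    # as in text_normalization.py
    text = text.upper().replace("-", " ").replace("—", " ")
    return re.sub(r"[^a-zA-Z0-9\s']+", "", text)

hyp = "Hello, world!".split()
ref = "HELLO WORLD".split()
hyp = upper_only_alpha(" ".join(hyp)).split()
ref = upper_only_alpha(" ".join(ref)).split()
print(hyp == ref)  # True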