From cda6e06a853f2b2358dd1acf3626130eb1f737ab Mon Sep 17 00:00:00 2001
From: marcoyang1998
Date: Wed, 20 Sep 2023 10:35:37 +0800
Subject: [PATCH] updates

---
 .../ASR/zipformer_prompt_asr/decode_bert.py   |   2 +-
 .../zipformer_prompt_asr/model_baseline.py    |  21 +-
 .../zipformer_prompt_asr/model_with_BERT.py   |  84 ++---
 .../text_normalization.py                     |  37 ++-
 .../train_bert_encoder.py                     |   6 +-
 .../zipformer_prompt_asr/transcribe_bert.py   | 308 ++++++++++--------
 6 files changed, 257 insertions(+), 201 deletions(-)

diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py
index 18b3e9a14..0cd4efaed 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert.py
@@ -450,6 +450,7 @@ def decode_one_batch(
     else:
         pre_texts = ["" for _ in range(batch_size)]
 
+    # get the LibriSpeech biasing data
     if params.use_ls_context_list and params.use_ls_test_set:
         if params.biasing_level == "utterance":
             pre_texts = [biasing_dict[id] for id in cut_ids]
@@ -476,7 +477,6 @@ def decode_one_batch(
 
     # Get the text embedding
     if params.use_pre_text or params.use_style_prompt:
-        # apply style transform to the pre_text and style_text
         pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
         if not params.use_ls_context_list:
 
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py
index 22733ae2a..77b4057c4 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py
@@ -15,17 +15,18 @@
 # limitations under the License.
 
 
+import random
+import warnings
+from typing import Optional, Tuple
+
 import k2
 import torch
 import torch.nn as nn
-import random
-import warnings
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear, penalize_abs_values_gt
+from torch import Tensor
 
 from icefall.utils import add_sos, make_pad_mask
-from scaling import penalize_abs_values_gt, ScaledLinear
-from torch import Tensor
-from typing import Optional, Tuple
 
 
 class Transducer(nn.Module):
@@ -185,11 +186,6 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -257,11 +253,10 @@ class Transducer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         encoder_out, encoder_out_lens = self.encoder(
-            x=x, 
+            x=x,
             x_lens=x_lens,
             src_key_padding_mask=src_key_padding_mask,
         )
 
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
-        
-        return encoder_out, encoder_out_lens
 
+        return encoder_out, encoder_out_lens
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py
index c7b6c7338..8c121255b 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py
@@ -15,17 +15,18 @@
 # limitations under the License.
 
 
+import random
+import warnings
+from typing import Dict, Optional, Tuple
+
 import k2
 import torch
 import torch.nn as nn
-import random
-import warnings
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear, penalize_abs_values_gt
+from torch import Tensor
 
 from icefall.utils import add_sos, make_pad_mask
-from scaling import penalize_abs_values_gt, ScaledLinear
-from torch import Tensor
-from typing import Optional, Tuple, Dict
 
 
 class PromptedTransducer(nn.Module):
@@ -97,13 +98,21 @@ class PromptedTransducer(nn.Module):
             vocab_size,
             initial_scale=0.25,
         )
-        
-        self.use_BERT = use_BERT # if the text encoder is a pre-trained BERT
+
+        self.use_BERT = use_BERT  # if the text encoder is a pre-trained BERT
         self.context_fuser = context_fuser
-        
-        assert text_encoder_type in ("BERT","DistilBERT", "BERT-UNCASED"), f"Unseen text_encoder type {text_encoder_type}"
-        self.text_encoder_dim = self.text_encoder.config.hidden_size if text_encoder_type in ("BERT", "BERT-UNCASED") else self.text_encoder.config.dim
-        
+
+        assert text_encoder_type in (
+            "BERT",
+            "DistilBERT",
+            "BERT-UNCASED",
+        ), f"Unseen text_encoder type {text_encoder_type}"
+        self.text_encoder_dim = (
+            self.text_encoder.config.hidden_size
+            if text_encoder_type in ("BERT", "BERT-UNCASED")
+            else self.text_encoder.config.dim
+        )
+
         if text_encoder_adapter:
             self.text_encoder_adapter = nn.Sequential(
                 nn.Linear(self.text_encoder_dim, self.text_encoder_dim, bias=False),
@@ -111,8 +120,10 @@ class PromptedTransducer(nn.Module):
             )
         else:
             self.text_encoder_adapter = None
-            
-        self.style_prompt_embedding = nn.Parameter(torch.full((self.text_encoder_dim,), 0.5))
+
+        self.style_prompt_embedding = nn.Parameter(
+            torch.full((self.text_encoder_dim,), 0.5)
+        )
 
     def forward(
         self,
@@ -181,11 +192,10 @@ class PromptedTransducer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         # freeze the BERT text encoder
-        
+
         if use_pre_text:
             memory, memory_key_padding_mask = self.encode_text(
-                encoded_inputs,
-                style_lens=style_lens
+                encoded_inputs, style_lens=style_lens
             )
         else:
             memory = None
@@ -231,11 +241,6 @@ class PromptedTransducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -270,12 +275,12 @@ class PromptedTransducer(nn.Module):
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
         if self.context_fuser is not None and memory is not None:
-            memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
+            memory = memory.permute(1, 0, 2)  # (T,N,C) -> (N,T,C)
             context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
-            context = self.joiner.context_proj(context) 
+            context = self.joiner.context_proj(context)
         else:
             context = None
-        
+
         logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False)
 
         with torch.cuda.amp.autocast(enabled=False):
@@ -304,16 +309,17 @@ class PromptedTransducer(nn.Module):
         (memory_len, batch_size, embed_dim) = memory.shape
 
         indicator = (
-            torch.arange(memory_len, device=memory.device).unsqueeze(-1)
-            < style_lens
+            torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
         )
         indicator = indicator.to(memory.dtype)
 
         extra_term = torch.zeros_like(memory)
-        extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(memory_len, batch_size, self.text_encoder_dim)
+        extra_term += indicator.unsqueeze(-1) * self.style_prompt_embedding.expand(
+            memory_len, batch_size, self.text_encoder_dim
+        )
 
         return memory + extra_term
-    
+
     def encode_text(
         self,
         encoded_inputs: Dict,
@@ -326,25 +332,25 @@ class PromptedTransducer(nn.Module):
 
         Returns:
             Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
-            text_encoder and the attention mask
+                text_encoder and the attention mask
         """
-        text_lens = encoded_inputs.pop("length") # need to use pop to remove this item
-        
+        text_lens = encoded_inputs.pop("length")  # need to use pop to remove this item
+
         # Freeze the pre-trained text encoder
         with torch.no_grad():
-            memory = self.text_encoder(**encoded_inputs)["last_hidden_state"] # (B,T,C)
-            memory = memory.permute(1,0,2)
-        
+            memory = self.text_encoder(**encoded_inputs)["last_hidden_state"]  # (B,T,C)
+            memory = memory.permute(1, 0, 2)
+
         # Text encoder adapter
         if self.text_encoder_adapter is not None:
             memory = self.text_encoder_adapter(memory)
-        
+
         memory = self._add_style_indicator(memory, style_lens)
 
         memory_key_padding_mask = make_pad_mask(text_lens)
-        
+
         return memory, memory_key_padding_mask
-    
+
     def encode_audio(
         self,
         feature: Tensor,
@@ -368,14 +374,14 @@ class PromptedTransducer(nn.Module):
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
         encoder_out, encoder_out_lens = self.encoder(
-            x=x, 
+            x=x,
             x_lens=x_lens,
             src_key_padding_mask=src_key_padding_mask,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
         )
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
-        
+
         return encoder_out, encoder_out_lens
 
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py b/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py
index 024c444f1..657089f46 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/text_normalization.py
@@ -1,12 +1,29 @@
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 
 
 def train_text_normalization(s: str) -> str:
+    # replace full-width quotation marks with half-width ones
     s = s.replace("“", '"')
     s = s.replace("”", '"')
     s = s.replace("‘", "'")
     s = s.replace("’", "'")
-    if s[:2] == "\" ":  # remove the starting double quote
+    if s[:2] == '" ':  # remove the starting double quote
         s = s[2:]
 
     return s
@@ -17,42 +34,42 @@ def ref_text_normalization(ref_text: str) -> str:
     p = r"[FN#[0-9]*]"
     pattern = re.compile(p)
 
-    # ref_text = ref_text.replace("”", "\"")
-    # ref_text = ref_text.replace("’", "'")
     res = pattern.findall(ref_text)
     ref_text = re.sub(p, "", ref_text)
-    
+
     ref_text = train_text_normalization(ref_text)
 
     return ref_text
 
 
-def remove_non_alphabetic(text: str, strict: bool=True) -> str:
+def remove_non_alphabetic(text: str, strict: bool = True) -> str:
+    # It is recommended to set strict to False
     if not strict:
         # Note, this also keeps space, single quote(') and hypen (-)
         text = text.replace("-", " ")
         text = text.replace("—", " ")
-        return re.sub("[^a-zA-Z0-9\s']+", "", text)
+        return re.sub(r"[^a-zA-Z0-9\s']+", "", text)
     else:
         # only keeps space
-        return re.sub("[^a-zA-Z\s]+", "", text)
+        return re.sub(r"[^a-zA-Z\s]+", "", text)
 
 
-def recog_text_normalization(recog_text: str) -> str:
-    pass
-
 def upper_only_alpha(text: str) -> str:
     return remove_non_alphabetic(text.upper(), strict=False)
 
+
 def lower_only_alpha(text: str) -> str:
     return remove_non_alphabetic(text.lower(), strict=False)
 
+
 def lower_all_char(text: str) -> str:
     return text.lower()
 
+
 def upper_all_char(text: str) -> str:
     return text.upper()
 
+
 if __name__ == "__main__":
     ref_text = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
     print(ref_text)
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
index bde640fb6..56ed27a6a 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
@@ -1,8 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                            Wei Kang,
-#                                            Mingshuang Luo,)
-#                                            Zengwei Yao)
+# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py b/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py
index 461810d3c..ef0c48e8a 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_bert.py
@@ -1,18 +1,47 @@
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" +Usage: + +python ./zipformer_prompt_asr/transcribe_bert.py \ + --epoch 50 \ + --avg 10 \ + --exp-dir ./zipformer_prompt_asr/exp \ + --manifest-dir data/long_audios/long_audio.jsonl.gz \ + --pre-text-transform mixed-punc \ + --style-text-transform mixed-punc \ + --num-history 5 \ + --use-pre-text True \ + --use-gt-pre-text False + + +""" + import argparse import logging import math import warnings from pathlib import Path from typing import List -from tqdm import tqdm import k2 import kaldifeat import sentencepiece as spm import torch import torchaudio -from lhotse import load_manifest, Fbank - from beam_search import ( beam_search, fast_beam_search_one_best, @@ -20,21 +49,24 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) +from decode_bert import _apply_style_transform +from lhotse import Fbank, load_manifest from text_normalization import ( - ref_text_normalization, - remove_non_alphabetic, - upper_only_alpha, - upper_all_char, lower_all_char, lower_only_alpha, + ref_text_normalization, + remove_non_alphabetic, train_text_normalization, + upper_all_char, + upper_only_alpha, ) -from train_bert_encoder_with_style import ( +from tqdm import tqdm +from train_bert_encoder import ( + _encode_texts_as_bytes_with_tokenizer, add_model_arguments, get_params, get_tokenizer, get_transducer_model, - _encode_texts_as_bytes_with_tokenizer, ) from icefall.checkpoint import ( @@ -51,11 +83,12 @@ from icefall.utils import ( write_error_stats, ) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - + parser.add_argument( "--epoch", type=int, @@ -74,7 +107,7 @@ def get_parser(): You can specify --avg to use more checkpoints for model averaging. """, ) - + parser.add_argument( "--avg", type=int, @@ -83,22 +116,21 @@ def get_parser(): "consecutive checkpoints before the checkpoint specified by " "'--epoch' and '--iter'", ) - + parser.add_argument( "--exp-dir", type=str, default="pruned_transducer_stateless7/exp", help="The experiment dir", ) - - + parser.add_argument( "--bpe-model", type=str, default="data/lang_bpe_500/bpe.model", help="""Path to bpe.model.""", ) - + parser.add_argument( "--method", type=str, @@ -110,104 +142,76 @@ def get_parser(): - fast_beam_search """, ) - + parser.add_argument( "--beam-size", type=int, default=4, ) - + parser.add_argument( "--manifest-dir", type=str, - default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz", - help="""This is the manfiest for long audio transcription. - It is intended to be sored, i.e first sort by recording ID and then sort by - start timestamp""" + default="data/long_audios/long_audio.jsonl.gz", + help="""This is the manfiest for long audio transcription. 
+        The cuts are intended to be sorted, i.e. first sort by recording ID and
+        then sort by start timestamp""",
     )
-    
-    parser.add_argument(
-        "--segment-length",
-        type=float,
-        default=30.0,
-    )
-    
+
     parser.add_argument(
         "--use-pre-text",
         type=str2bool,
         default=False,
-        help="Whether use pre-text when decoding the current chunk"
+        help="Whether to use pre-text when decoding the current chunk",
     )
-    
+
     parser.add_argument(
         "--use-style-prompt",
        type=str2bool,
         default=True,
-        help="Use style prompt when evaluation"
+        help="Whether to use the style prompt during evaluation",
     )
-    
+
     parser.add_argument(
         "--pre-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of content prompt, i.e pre_text"
+        help="The style of the content prompt, i.e. the pre_text",
    )
-    
+
     parser.add_argument(
         "--style-text-transform",
         type=str,
-        choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
+        choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
         default="mixed-punc",
-        help="The style of style prompt, i.e style_text"
+        help="The style of the style prompt, i.e. the style_text",
     )
-    
+
     parser.add_argument(
         "--num-history",
         type=int,
         default=2,
-        help="How many previous chunks to look if using pre-text for decoding"
+        help="How many previous chunks to look at if using pre-text for decoding",
     )
-    
+
     parser.add_argument(
         "--use-gt-pre-text",
         type=str2bool,
         default=False,
         help="Whether use gt pre text when using content prompt",
     )
-    
+
     parser.add_argument(
         "--post-normalization",
         type=str2bool,
         default=True,
     )
-    
+
     add_model_arguments(parser)
-    
+
     return parser
 
-    
-def _apply_style_transform(text: List[str], transform: str) -> List[str]:
-    """Apply transform to a list of text. By default, the text are in
-    ground truth format, i.e mixed-punc.
-    Args:
-        text (List[str]): Input text string
-        transform (str): Transform to be applied
-
-    Returns:
-        List[str]: _description_
-    """
-    if transform == "mixed-punc":
-        return text
-    elif transform == "upper-no-punc":
-        return [upper_only_alpha(s) for s in text]
-    elif transform == "lower-no-punc":
-        return [lower_only_alpha(s) for s in text]
-    elif transform == "lower-punc":
-        return [lower_all_char(s) for s in text]
-    else:
-        raise NotImplementedError(f"Unseen transform: {transform}")
-
 
 @torch.no_grad()
 def main():
@@ -216,7 +220,7 @@ def main():
     args.exp_dir = Path(args.exp_dir)
 
     params = get_params()
-    
+
     params.update(vars(args))
 
     sp = spm.SentencePieceProcessor()
@@ -226,7 +230,7 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
-    
+
     params.res_dir = params.exp_dir / "long_audio_transcribe"
     params.res_dir.mkdir(exist_ok=True)
 
@@ -234,21 +238,22 @@ def main():
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    
+
     if "beam_search" in params.method:
-        params.suffix += (
-            f"-{params.method}-beam-size-{params.beam_size}"
-        )
-    
+        params.suffix += f"-{params.method}-beam-size-{params.beam_size}"
+
     if params.use_pre_text:
         if params.use_gt_pre_text:
             params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
         else:
-            params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
-    
-    
-    book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
-    setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
+            params.suffix += (
+                f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
+            )
+
+    book_name = params.manifest_dir.split("/")[-1].replace(".jsonl.gz", "")
+    setup_logger(
+        f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info"
+    )
     logging.info("Decoding started")
 
     device = torch.device("cpu")
@@ -265,13 +270,12 @@ def main():
     logging.info(f"Number of model parameters: {num_param}")
 
     if params.iter > 0:
-        filenames = find_checkpoints(
-            params.exp_dir, iteration=-params.iter
-        )[: params.avg + 1]
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg + 1
+        ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg + 1:
             raise ValueError(
@@ -310,22 +314,22 @@ def main():
                 device=device,
             )
         )
-    
+
     model.to(device)
     model.eval()
     model.device = device
-    
+
     # load manifest
     manifest = load_manifest(params.manifest_dir)
 
     results = []
     count = 0
-    
+
     last_recording = ""
     last_end = -1
     history = []
     num_pre_texts = []
-    
+
     for cut in manifest:
         if cut.has_features:
             feat = cut.load_features()
             feat_lens = cut.num_frames
@@ -333,45 +337,53 @@ def main():
         else:
             feat = cut.compute_features(extractor=Fbank())
             feat_lens = feat.shape[0]
-        
-        
+
         cur_recording = cut.recording.id
-        
+
         if cur_recording != last_recording:
             last_recording = cur_recording
-            history = [] # clean history
+            history = []  # clean up the history
             last_end = -1
-            logging.info(f"Moving on to the next recording")
+            logging.info("Moving on to the next recording")
         else:
-            if cut.start < last_end - 0.2: # overlap exits
-                logging.warning(f"An overlap exists between current cut and last cut")
+            if cut.start < last_end - 0.2:  # overlap with the previous cuts
+                logging.warning("An overlap exists between the current cut and the last cut")
                 logging.warning("Skipping this cut!")
                 continue
             if cut.start > last_end + 10:
-                logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
-            
+                logging.warning(
+                    f"Large time gap between the current and previous utterance: {cut.start - last_end}."
+                )
+
         # prepare input
         x = torch.tensor(feat, device=device).unsqueeze(0)
-        x_lens = torch.tensor([feat_lens,], device=device)
-        
+        x_lens = torch.tensor(
+            [
+                feat_lens,
+            ],
+            device=device,
+        )
+
         if params.use_pre_text:
             if params.num_history > 0:
-                pre_texts = history[-params.num_history:]
+                pre_texts = history[-params.num_history :]
             else:
                 pre_texts = []
             num_pre_texts.append(len(pre_texts))
             pre_texts = [train_text_normalization(" ".join(pre_texts))]
             fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
             style_texts = [fixed_sentence]
-            
+
             pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
             if params.use_style_prompt:
-                style_texts = _apply_style_transform(style_texts, params.style_text_transform)
-            
-            # encode pre_text
+                style_texts = _apply_style_transform(
+                    style_texts, params.style_text_transform
+                )
+
+            # encode prompts
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
-                
+
                 encoded_inputs, style_lens = _encode_texts_as_bytes_with_tokenizer(
                     pre_texts=pre_texts,
                     style_texts=style_texts,
@@ -380,16 +392,18 @@ def main():
                     no_limit=True,
                 )
                 if params.num_history > 5:
-                    logging.info(f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} ")
-                
+                    logging.info(
+                        f"Shape of encoded texts: {encoded_inputs['input_ids'].shape} "
+                    )
+
                 memory, memory_key_padding_mask = model.encode_text(
                     encoded_inputs=encoded_inputs,
                     style_lens=style_lens,
-                ) # (T,B,C)
+                )  # (T,B,C)
         else:
             memory = None
             memory_key_padding_mask = None
-        
+
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             encoder_out, encoder_out_lens = model.encode_audio(
@@ -398,7 +412,7 @@ def main():
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )
-        
+
         if params.method == "greedy_search":
             hyp_tokens = greedy_search_batch(
                 model=model,
@@ -412,17 +426,19 @@ def main():
                 encoder_out_lens=encoder_out_lens,
                 beam=params.beam_size,
            )
-        
-        hyp = sp.decode(hyp_tokens)[0] # in string format
-        ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
-        
-        # extend the history, the history here is in original format
+
+        hyp = sp.decode(hyp_tokens)[0]  # in string format
+        ref_text = ref_text_normalization(
+            cut.supervisions[0].texts[0]
+        )  # required to match the training
+
+        # extend the history
         if params.use_gt_pre_text:
-            history.append(ref_text) 
+            history.append(ref_text)
         else:
             history.append(hyp)
-        last_end = cut.end # update the last end timestamp
-        
+        last_end = cut.end  # update the last end timestamp
+
         # append the current decoding result
         hyp = hyp.split()
         ref = ref_text.split()
@@ -431,45 +447,69 @@ def main():
 
         count += 1
         if count % 100 == 0:
             logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
-            logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
-    
+            logging.info(
+                f"Average context count over the last 100 samples: {sum(num_pre_texts[-100:])/100}"
+            )
+
     logging.info(f"A total of {count} cuts")
-    logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
-    
+    logging.info(
+        f"Average context count over the whole set: {sum(num_pre_texts)/len(num_pre_texts)}"
+    )
+
     results = sorted(results)
-    recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
+    recog_path = (
+        params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
+    )
     store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")
-    
-    errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
+
+    errs_filename = (
+        params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
+    )
     with open(errs_filename, "w") as f:
         wer = write_error_stats(
-            f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
+            f,
+            f"long-audio-{params.method}",
+            results,
+            enable_log=True,
+            compute_CER=False,
        )
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     if params.post_normalization:
         params.suffix += "-post-normalization"
-        
+
         new_res = []
         for item in results:
             id, ref, hyp = item
             hyp = upper_only_alpha(" ".join(hyp)).split()
-            ref = upper_only_alpha(" ".join(ref)).split() 
-            new_res.append((id,ref,hyp))
-        
+            ref = upper_only_alpha(" ".join(ref)).split()
+            new_res.append((id, ref, hyp))
+
         new_res = sorted(new_res)
-        recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        recog_path = (
+            params.res_dir
+            / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+        )
         store_transcripts(filename=recog_path, texts=new_res)
         logging.info(f"The transcripts are stored in {recog_path}")
-        
-        errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
+
+        errs_filename = (
+            params.res_dir
+            / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
        )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
+                f,
+                f"long-audio-{params.method}",
+                new_res,
+                enable_log=True,
+                compute_CER=False,
             )
             logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
-if __name__=="__main__":
-    main()
\ No newline at end of file
+
+
+if __name__ == "__main__":
+    main()