root 2024-04-24 11:29:35 +00:00
parent b970ba569a
commit 838bf223f3
11 changed files with 738 additions and 559 deletions

View File

@@ -57,8 +57,8 @@ from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
@@ -297,6 +297,7 @@ def decode_one_batch(
print(hyps)
return {"beam-search": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
@@ -314,6 +315,7 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
@@ -323,6 +325,7 @@ def decode_dataset(
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
@@ -348,6 +351,7 @@ def decode_dataset(
text = text.replace("", "")
text = text.replace("", "")
return text
results = []
num_cuts = 0

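The normalize_text_alimeeting helper in the hunk above strips spaces and non-speech markers from Mandarin transcripts before scoring. A minimal standalone sketch of that M2MeT-style normalization, assuming only the replacements visible in this hunk; the regex fallback for other angle-bracket tags is an added assumption, not part of the original code:

import re

def normalize_m2met_sketch(text: str) -> str:
    # Mandarin transcripts are scored without spaces.
    text = text.replace(" ", "")
    # Drop the non-speech markers shown in the diff above.
    for marker in ("<sil>", "<%>"):
        text = text.replace(marker, "")
    # Assumption: remove any remaining angle-bracket tags the same way.
    text = re.sub(r"<[^>]+>", "", text)
    return text

print(normalize_m2met_sketch("喂 <sil> 你 好 <%>"))  # -> "喂你好"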
egs/multi_zh-hans/ASR/whisper/multi_dataset.py Executable file → Normal file
View File

View File

@@ -65,8 +65,8 @@ from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -458,6 +458,7 @@ def compute_loss(
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")

View File

@@ -1,11 +1,12 @@
from typing import Dict, Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
import whisper
from torch import Tensor
from torch import nn
from typing import Dict, Iterable, Optional
from whisper.model import ResidualAttentionBlock, LayerNorm
import numpy as np
from torch import Tensor, nn
from whisper.model import LayerNorm, ResidualAttentionBlock
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
@@ -19,10 +20,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = (
x
+ self.positional_embedding[offset : offset + x.shape[1]]
)
x = x + self.positional_embedding[offset : offset + x.shape[1]]
x = x.to(xa.dtype)
# for block in self.blocks:
@@ -39,6 +37,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
return logits
def replace_whisper_decoder_forward():
"""
This function monkey patches the forward method of the whisper decoder.

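replace_whisper_decoder_forward swaps the custom forward above into the installed whisper package at import time. The body of the patch function is not shown in this hunk; a plausible one-line sketch, assuming the target class is whisper.model.TextDecoder (whose stock forward has the same (x, xa, kv_cache) signature):

import whisper.model

def replace_whisper_decoder_forward():
    # Rebind the class attribute so every TextDecoder instance, including ones
    # created before this call, dispatches to the custom forward defined above.
    whisper.model.TextDecoder.forward = forward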
View File

@@ -29,9 +29,10 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import kaldialign
from speechio_norm import TextNorm
from icefall.utils import store_transcripts, write_error_stats
from speechio_norm import TextNorm
def get_parser():
parser = argparse.ArgumentParser(
@@ -140,6 +141,7 @@ def get_filenames(
results.append(whisper_filename)
return results
def main():
parser = get_parser()
args = parser.parse_args()

File diff suppressed because it is too large.

View File

@@ -14,10 +14,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import argparse
import logging
from lhotse import CutSet, load_manifest_lazy
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -49,6 +51,7 @@ def get_parser():
return parser
def load_fixed_text(fixed_text_path):
"""
fixed text format
@@ -57,33 +60,39 @@ def load_fixed_text(fixed_text_path):
load into a dict
"""
fixed_text_dict = {}
with open(fixed_text_path, 'r') as f:
with open(fixed_text_path, "r") as f:
for line in f:
cut_id, text = line.strip().split(' ', 1)
cut_id, text = line.strip().split(" ", 1)
fixed_text_dict[cut_id] = text
return fixed_text_dict
def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
fixed_item = 0
for i, cut in enumerate(manifest):
if i % 10000 == 0:
logging.info(f'Processing cut {i}, fixed {fixed_item}')
logging.info(f"Processing cut {i}, fixed {fixed_item}")
cut_id_orgin = cut.id
if cut_id_orgin.endswith('_sp0.9'):
if cut_id_orgin.endswith("_sp0.9"):
cut_id = cut_id_orgin[:-6]
elif cut_id_orgin.endswith('_sp1.1'):
elif cut_id_orgin.endswith("_sp1.1"):
cut_id = cut_id_orgin[:-6]
else:
cut_id = cut_id_orgin
if cut_id in fixed_text_dict:
assert len(cut.supervisions) == 1, f'cut {cut_id} has {len(cut.supervisions)} supervisions'
assert (
len(cut.supervisions) == 1
), f"cut {cut_id} has {len(cut.supervisions)} supervisions"
if cut.supervisions[0].text != fixed_text_dict[cut_id]:
logging.info(f'Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}')
logging.info(
f"Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}"
)
cut.supervisions[0].text = fixed_text_dict[cut_id]
fixed_item += 1
manifest_writer.write(cut)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
@@ -92,23 +101,26 @@ def main():
args = parser.parse_args()
logging.info(vars(args))
fixed_text_path = args.manifest_dir + 'text.fix'
fixed_text_path = args.manifest_dir + "text.fix"
fixed_text_dict = load_fixed_text(fixed_text_path)
logging.info(f'Loaded {len(fixed_text_dict)} fixed texts')
logging.info(f"Loaded {len(fixed_text_dict)} fixed texts")
dev_manifest_path = args.manifest_dir + 'cuts_DEV.jsonl.gz'
fixed_dev_manifest_path = args.manifest_dir + 'cuts_DEV_fixed.jsonl.gz'
logging.info(f'Loading dev manifest from {dev_manifest_path}')
dev_manifest_path = args.manifest_dir + "cuts_DEV.jsonl.gz"
fixed_dev_manifest_path = args.manifest_dir + "cuts_DEV_fixed.jsonl.gz"
logging.info(f"Loading dev manifest from {dev_manifest_path}")
cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
logging.info(f'Fixed dev manifest saved to {fixed_dev_manifest_path}')
logging.info(f"Fixed dev manifest saved to {fixed_dev_manifest_path}")
manifest_path = args.manifest_dir + f'cuts_{args.training_subset}.jsonl.gz'
fixed_manifest_path = args.manifest_dir + f'cuts_{args.training_subset}_fixed.jsonl.gz'
logging.info(f'Loading manifest from {manifest_path}')
manifest_path = args.manifest_dir + f"cuts_{args.training_subset}.jsonl.gz"
fixed_manifest_path = (
args.manifest_dir + f"cuts_{args.training_subset}_fixed.jsonl.gz"
)
logging.info(f"Loading manifest from {manifest_path}")
cuts_manifest = load_manifest_lazy(manifest_path)
fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
logging.info(f'Fixed training manifest saved to {fixed_manifest_path}')
logging.info(f"Fixed training manifest saved to {fixed_manifest_path}")
if __name__ == "__main__":
main()

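The manifest-fixing script above looks up corrected transcripts by cut id and maps speed-perturbed ids back to the original before the lookup. A small self-contained sketch of those two pieces, assuming text.fix holds one "cut_id corrected_text" pair per line (as implied by split(" ", 1)) and that the only perturbation suffixes are _sp0.9 and _sp1.1; the cut id in the last line is made up for illustration:

def load_fixed_text_sketch(path: str) -> dict:
    # One "cut_id corrected_text" pair per line, split on the first space.
    fixed = {}
    with open(path, "r") as f:
        for line in f:
            cut_id, text = line.strip().split(" ", 1)
            fixed[cut_id] = text
    return fixed

def strip_sp_suffix(cut_id: str) -> str:
    # Speed-perturbed copies of a cut share the original transcript, so map
    # "X_sp0.9" / "X_sp1.1" back to "X" before looking up the fixed text.
    for suffix in ("_sp0.9", "_sp1.1"):
        if cut_id.endswith(suffix):
            return cut_id[: -len(suffix)]
    return cut_id

assert strip_sp_suffix("CUT_0001_sp0.9") == "CUT_0001"  # hypothetical cut id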
View File

@@ -38,13 +38,13 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import os
import deepspeed
import k2
import optim