root 2024-04-24 11:29:35 +00:00
parent b970ba569a
commit 838bf223f3
11 changed files with 738 additions and 559 deletions


@@ -57,8 +57,8 @@ from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
@@ -297,6 +297,7 @@ def decode_one_batch(
print(hyps)
return {"beam-search": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
@@ -314,6 +315,7 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
@@ -323,6 +325,7 @@ def decode_dataset(
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
@@ -348,6 +351,7 @@ def decode_dataset(
text = text.replace("", "")
text = text.replace("", "")
return text
results = []
num_cuts = 0
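Editorial aside, not part of the diff: the replacements visible above strip spaces and filler tags from AliMeeting-style transcripts before scoring. A minimal sketch of the effect on a made-up string:

# Sketch only; applies just the " ", "<sil>" and "<%>" replacements shown in
# this hunk to an invented transcript.
text = "对 的 <sil> 就 是 这 样 <%>"
for token in (" ", "<sil>", "<%>"):
    text = text.replace(token, "")
print(text)  # -> 对的就是这样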

egs/multi_zh-hans/ASR/whisper/multi_dataset.py Executable file → Normal file

@@ -244,4 +244,4 @@ class MultiDataset:
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
}


@@ -65,8 +65,8 @@ from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -458,6 +458,7 @@ def compute_loss(
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")


@@ -1,11 +1,12 @@
+from typing import Dict, Iterable, Optional
+import numpy as np
import torch
import torch.nn.functional as F
import whisper
-from torch import Tensor
-from torch import nn
-from typing import Dict, Iterable, Optional
-from whisper.model import ResidualAttentionBlock, LayerNorm
-import numpy as np
+from torch import Tensor, nn
+from whisper.model import LayerNorm, ResidualAttentionBlock
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
@@ -19,10 +19,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
-x = (
-x
-+ self.positional_embedding[offset : offset + x.shape[1]]
-)
+x = x + self.positional_embedding[offset : offset + x.shape[1]]
x = x.to(xa.dtype)
# for block in self.blocks:
@@ -39,6 +37,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
return logits
def replace_whisper_decoder_forward():
"""
This function monkey patches the forward method of the whisper decoder.
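Editorial aside, not part of the diff: the hunk does not show the body of replace_whisper_decoder_forward, but the monkey patch presumably amounts to rebinding the patched forward defined above onto the upstream class, roughly:

import whisper

def replace_whisper_decoder_forward():
    # Assumption: `forward` is the patched function defined earlier in this
    # file; rebinding it here makes every TextDecoder instance use it.
    whisper.model.TextDecoder.forward = forward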


@@ -20,7 +20,7 @@
| 10 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 11 | baidu_pro_api_zh | 7.29% | 2023.12 |
Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)
<details><summary> Detail all models </summary><p>


@@ -29,9 +29,10 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import kaldialign
+from speechio_norm import TextNorm
from icefall.utils import store_transcripts, write_error_stats
-from speechio_norm import TextNorm
def get_parser():
parser = argparse.ArgumentParser(
@@ -140,6 +141,7 @@ def get_filenames(
results.append(whisper_filename)
return results
def main():
parser = get_parser()
args = parser.parse_args()
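Editorial aside, not part of the diff: once hypotheses and references are normalized, error counting with kaldialign typically looks like the sketch below; `normalize` here is only a stand-in, not the real speechio_norm.TextNorm interface.

import kaldialign

def normalize(s: str) -> str:
    # Stand-in for the SpeechIO-style normalization; the actual rules live in
    # speechio_norm.TextNorm and are not reproduced here.
    return s.replace(" ", "")

ref = list(normalize("今天 天气 不错"))  # score per character for Chinese
hyp = list(normalize("今天天七不错"))
stats = kaldialign.edit_distance(ref, hyp)
print(stats)  # e.g. {'ins': 0, 'del': 0, 'sub': 1, 'total': 1}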

File diff suppressed because it is too large.


@@ -14,10 +14,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import argparse
+import logging
from lhotse import CutSet, load_manifest_lazy
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -49,6 +51,7 @@ def get_parser():
return parser
def load_fixed_text(fixed_text_path):
"""
fixed text format
@@ -57,32 +60,38 @@ def load_fixed_text(fixed_text_path):
load into a dict
"""
fixed_text_dict = {}
-with open(fixed_text_path, 'r') as f:
+with open(fixed_text_path, "r") as f:
for line in f:
-cut_id, text = line.strip().split(' ', 1)
+cut_id, text = line.strip().split(" ", 1)
fixed_text_dict[cut_id] = text
return fixed_text_dict
def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
fixed_item = 0
for i, cut in enumerate(manifest):
if i % 10000 == 0:
-logging.info(f'Processing cut {i}, fixed {fixed_item}')
+logging.info(f"Processing cut {i}, fixed {fixed_item}")
cut_id_orgin = cut.id
-if cut_id_orgin.endswith('_sp0.9'):
+if cut_id_orgin.endswith("_sp0.9"):
cut_id = cut_id_orgin[:-6]
-elif cut_id_orgin.endswith('_sp1.1'):
+elif cut_id_orgin.endswith("_sp1.1"):
cut_id = cut_id_orgin[:-6]
else:
cut_id = cut_id_orgin
if cut_id in fixed_text_dict:
-assert len(cut.supervisions) == 1, f'cut {cut_id} has {len(cut.supervisions)} supervisions'
+assert (
+len(cut.supervisions) == 1
+), f"cut {cut_id} has {len(cut.supervisions)} supervisions"
if cut.supervisions[0].text != fixed_text_dict[cut_id]:
-logging.info(f'Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}')
+logging.info(
+f"Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}"
+)
cut.supervisions[0].text = fixed_text_dict[cut_id]
fixed_item += 1
manifest_writer.write(cut)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
@@ -92,23 +101,26 @@ def main():
args = parser.parse_args()
logging.info(vars(args))
-fixed_text_path = args.manifest_dir + 'text.fix'
+fixed_text_path = args.manifest_dir + "text.fix"
fixed_text_dict = load_fixed_text(fixed_text_path)
-logging.info(f'Loaded {len(fixed_text_dict)} fixed texts')
+logging.info(f"Loaded {len(fixed_text_dict)} fixed texts")
-dev_manifest_path = args.manifest_dir + 'cuts_DEV.jsonl.gz'
+dev_manifest_path = args.manifest_dir + "cuts_DEV.jsonl.gz"
-fixed_dev_manifest_path = args.manifest_dir + 'cuts_DEV_fixed.jsonl.gz'
+fixed_dev_manifest_path = args.manifest_dir + "cuts_DEV_fixed.jsonl.gz"
-logging.info(f'Loading dev manifest from {dev_manifest_path}')
+logging.info(f"Loading dev manifest from {dev_manifest_path}")
cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
-logging.info(f'Fixed dev manifest saved to {fixed_dev_manifest_path}')
+logging.info(f"Fixed dev manifest saved to {fixed_dev_manifest_path}")
-manifest_path = args.manifest_dir + f'cuts_{args.training_subset}.jsonl.gz'
+manifest_path = args.manifest_dir + f"cuts_{args.training_subset}.jsonl.gz"
-fixed_manifest_path = args.manifest_dir + f'cuts_{args.training_subset}_fixed.jsonl.gz'
-logging.info(f'Loading manifest from {manifest_path}')
+fixed_manifest_path = (
+args.manifest_dir + f"cuts_{args.training_subset}_fixed.jsonl.gz"
+)
+logging.info(f"Loading manifest from {manifest_path}")
cuts_manifest = load_manifest_lazy(manifest_path)
fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
-logging.info(f'Fixed training manifest saved to {fixed_manifest_path}')
+logging.info(f"Fixed training manifest saved to {fixed_manifest_path}")
if __name__ == "__main__":
main()
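Editorial aside, not part of the diff: load_fixed_text expects one corrected transcript per line, cut id first and text after the first space, and fix_manifest strips any _sp0.9/_sp1.1 suffix from a manifest cut id before the lookup. A tiny sketch with made-up ids:

# Hypothetical text.fix content; ids and transcripts are invented.
lines = ["CUT_0001 今天天气不错", "CUT_0002 我们明天再说"]
fixed_text_dict = {}
for line in lines:
    cut_id, text = line.strip().split(" ", 1)
    fixed_text_dict[cut_id] = text

# A speed-perturbed manifest entry "CUT_0001_sp0.9" maps back to "CUT_0001".
print(fixed_text_dict["CUT_0001_sp0.9"[:-6]])  # -> 今天天气不错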


@@ -424,4 +424,4 @@ if [ $stage -le 23 ] && [ $stop_stage -ge 23 ]; then
python local/fix_manifest.py \
--fixed-transcript-path data/fbank/text.fix \
--training-subset L
fi


@@ -38,13 +38,13 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
import argparse
import copy
import logging
+import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
-import os
import deepspeed
import k2
import optim


@@ -25,7 +25,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "You need to run the prepare.sh first."
exit -1
fi
python ./zipformer/train.py \
--world-size 4 \
--exp-dir zipformer/exp \
@@ -105,11 +105,11 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 2: Finetune the model"
# The following configuration of lr schedule should work well
# You may also tune the following parameters to adjust learning rate schedule
base_lr=0.0005
@@ -201,4 +201,4 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi