mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00

fix lint

parent b970ba569a
commit 838bf223f3

@@ -57,8 +57,8 @@ from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint

@@ -297,6 +297,7 @@ def decode_one_batch(
    print(hyps)
    return {"beam-search": hyps}

+
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,

@@ -314,6 +315,7 @@ def decode_dataset(
      Returns:
        Return a dict, whose key may be "beam-search".
    """
+
    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
        """
        Text normalization similar to M2MeT challenge baseline.

@@ -323,6 +325,7 @@ def decode_dataset(
            return text
        elif normalize == "m2met":
            import re
+
            text = text.replace(" ", "")
            text = text.replace("<sil>", "")
            text = text.replace("<%>", "")

@@ -348,6 +351,7 @@ def decode_dataset(
            text = text.replace("、", "")
            text = text.replace("?", "")
            return text
+
    results = []

    num_cuts = 0
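
The two hunks above touch fragments of the nested normalize_text_alimeeting helper. For orientation, a minimal sketch of how those fragments fit together is given below; it assumes the untouched branch above `return text` is `if normalize == "none":` (that condition is not visible in this diff), and it only includes the replacements shown here, so it is not the full function from the recipe.

```python
# Sketch assembled from the fragments shown above; not the complete recipe function.
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
    """Text normalization similar to the M2MeT challenge baseline."""
    if normalize == "none":  # assumed branch; the condition is not shown in this diff
        return text
    elif normalize == "m2met":
        import re  # noqa: F401  (used by regex-based rules omitted from this sketch)

        text = text.replace(" ", "")
        text = text.replace("<sil>", "")
        text = text.replace("<%>", "")
        # ... further symbol removals elided ...
        text = text.replace("、", "")
        text = text.replace("?", "")
        return text
    return text
```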

egs/multi_zh-hans/ASR/whisper/multi_dataset.py (Executable file → Normal file)

@@ -244,4 +244,4 @@ class MultiDataset:
            # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
            # "wenetspeech-net_test": wenetspeech_test_net_cuts,
            # "wenetspeech_dev": wenetspeech_dev_cuts,
-        }
+        }

@@ -65,8 +65,8 @@ from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints

@@ -458,6 +458,7 @@ def compute_loss(
            return text
        elif normalize == "m2met":
            import re
+
            text = text.replace(" ", "")
            text = text.replace("<sil>", "")
            text = text.replace("<%>", "")

@@ -1,11 +1,12 @@
+from typing import Dict, Iterable, Optional
+
+import numpy as np
import torch
import torch.nn.functional as F
import whisper
-from torch import Tensor
-from torch import nn
-from typing import Dict, Iterable, Optional
-from whisper.model import ResidualAttentionBlock, LayerNorm
-import numpy as np
+from torch import Tensor, nn
+from whisper.model import LayerNorm, ResidualAttentionBlock

+
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
    """

@@ -19,10 +20,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        self.token_embedding(x)
        + self.positional_embedding[offset : offset + x.shape[-1]]
    )
-    x = (
-        x
-        + self.positional_embedding[offset : offset + x.shape[1]]
-    )
+    x = x + self.positional_embedding[offset : offset + x.shape[1]]
    x = x.to(xa.dtype)

    # for block in self.blocks:
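
The simplified line above adds positional embeddings starting at `offset`, which is nonzero when previously decoded tokens are already held in the kv_cache. A toy illustration of that slice, with made-up shapes and values (not taken from whisper or from this commit):

```python
# Toy illustration of the offset-based positional-embedding slice; values are made up.
import torch

positional_embedding = torch.arange(10).unsqueeze(-1).float()  # (10, 1) toy table
x = torch.zeros(1, 3, 1)  # a batch with 3 new token positions
offset = 4  # e.g. 4 tokens already present in kv_cache
x = x + positional_embedding[offset : offset + x.shape[1]]
print(x.squeeze(-1))  # rows 4, 5, 6 of the toy table
```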

@@ -39,6 +37,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):

    return logits

+
def replace_whisper_decoder_forward():
    """
    This function monkey patches the forward method of the whisper encoder.
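
For context on the helper this file defines: replace_whisper_decoder_forward rebinds a method on an openai-whisper model class. A minimal sketch of that pattern is below; the custom `forward` here is a placeholder standing in for the one defined earlier in this file, and the patched attribute (`whisper.model.TextDecoder.forward`) is the usual target for a decoder patch, an assumption rather than something quoted from this commit.

```python
# Sketch of the monkey-patch pattern; the real helper is
# whisper_decoder_forward_monkey_patch.py in this recipe.
from typing import Optional

import whisper
from torch import Tensor


def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
    """Placeholder for the custom decoder forward defined in this file (see hunks above)."""
    ...


def replace_whisper_decoder_forward():
    """Rebind whisper.model.TextDecoder.forward to the custom forward above (assumed target)."""
    whisper.model.TextDecoder.forward = forward
```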

@@ -20,7 +20,7 @@
| 10 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 11 | baidu_pro_api_zh | 7.29% | 2023.12 |

-Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)
+Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)

<details><summary> Detail all models </summary><p>


@@ -29,9 +29,10 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple

import kaldialign
+from speechio_norm import TextNorm

from icefall.utils import store_transcripts, write_error_stats
-from speechio_norm import TextNorm


def get_parser():
    parser = argparse.ArgumentParser(

@@ -140,6 +141,7 @@ def get_filenames(
        results.append(whisper_filename)
    return results

+
def main():
    parser = get_parser()
    args = parser.parse_args()
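
The scoring script whose imports appear above aligns reference and hypothesis transcripts with kaldialign before writing error statistics via icefall.utils. A small illustrative call to kaldialign.align is shown below; the sentences are made up and are not from this commit.

```python
# Minimal illustration of the alignment primitive used by such scoring scripts;
# "*" is the epsilon symbol that marks insertions and deletions.
import kaldialign

ref = ["今天", "天气", "不错"]
hyp = ["今天", "天气", "不", "错"]
ali = kaldialign.align(ref, hyp, "*")
errors = sum(1 for r, h in ali if r != h)
print(ali)
print(f"{errors} errors over {len(ref)} reference words")
```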

(One file diff suppressed because it is too large.)

@@ -14,10 +14,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import argparse
+import logging
+
from lhotse import CutSet, load_manifest_lazy

+
def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter

@@ -49,6 +51,7 @@ def get_parser():

    return parser

+
def load_fixed_text(fixed_text_path):
    """
    fixed text format

@@ -57,32 +60,38 @@ def load_fixed_text(fixed_text_path):
    load into a dict
    """
    fixed_text_dict = {}
-    with open(fixed_text_path, 'r') as f:
+    with open(fixed_text_path, "r") as f:
        for line in f:
-            cut_id, text = line.strip().split(' ', 1)
+            cut_id, text = line.strip().split(" ", 1)
            fixed_text_dict[cut_id] = text
    return fixed_text_dict

+
def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
    with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
        fixed_item = 0
        for i, cut in enumerate(manifest):
            if i % 10000 == 0:
-                logging.info(f'Processing cut {i}, fixed {fixed_item}')
+                logging.info(f"Processing cut {i}, fixed {fixed_item}")
            cut_id_orgin = cut.id
-            if cut_id_orgin.endswith('_sp0.9'):
+            if cut_id_orgin.endswith("_sp0.9"):
                cut_id = cut_id_orgin[:-6]
-            elif cut_id_orgin.endswith('_sp1.1'):
+            elif cut_id_orgin.endswith("_sp1.1"):
                cut_id = cut_id_orgin[:-6]
            else:
                cut_id = cut_id_orgin
            if cut_id in fixed_text_dict:
-                assert len(cut.supervisions) == 1, f'cut {cut_id} has {len(cut.supervisions)} supervisions'
+                assert (
+                    len(cut.supervisions) == 1
+                ), f"cut {cut_id} has {len(cut.supervisions)} supervisions"
                if cut.supervisions[0].text != fixed_text_dict[cut_id]:
-                    logging.info(f'Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}')
+                    logging.info(
+                        f"Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}"
+                    )
                    cut.supervisions[0].text = fixed_text_dict[cut_id]
                    fixed_item += 1
-            manifest_writer.write(cut)
+            manifest_writer.write(cut)

+
def main():
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
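
For orientation on the two functions above: load_fixed_text expects one `<cut_id> <text>` pair per line, splitting on the first space, and fix_manifest strips a trailing `_sp0.9`/`_sp1.1` suffix so that speed-perturbed cuts reuse the corrected transcript of their base cut. A small self-contained sketch of that lookup follows; the cut ids and transcripts are invented for illustration.

```python
# Illustration only: cut ids and transcripts below are made up, not from WenetSpeech.
fixed_text_dict = {}
for line in ["CUT_0001 今天天气不错", "CUT_0002 我们去爬山吧"]:
    cut_id, text = line.strip().split(" ", 1)
    fixed_text_dict[cut_id] = text

for cut_id_orgin in ["CUT_0001", "CUT_0001_sp0.9", "CUT_0002_sp1.1"]:
    # Speed-perturbed copies map back to the base cut id before lookup.
    if cut_id_orgin.endswith(("_sp0.9", "_sp1.1")):
        cut_id = cut_id_orgin[:-6]
    else:
        cut_id = cut_id_orgin
    print(cut_id_orgin, "->", fixed_text_dict[cut_id])
```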

@@ -92,23 +101,26 @@ def main():
    args = parser.parse_args()
    logging.info(vars(args))

-    fixed_text_path = args.manifest_dir + 'text.fix'
+    fixed_text_path = args.manifest_dir + "text.fix"
    fixed_text_dict = load_fixed_text(fixed_text_path)
-    logging.info(f'Loaded {len(fixed_text_dict)} fixed texts')
+    logging.info(f"Loaded {len(fixed_text_dict)} fixed texts")

-    dev_manifest_path = args.manifest_dir + 'cuts_DEV.jsonl.gz'
-    fixed_dev_manifest_path = args.manifest_dir + 'cuts_DEV_fixed.jsonl.gz'
-    logging.info(f'Loading dev manifest from {dev_manifest_path}')
+    dev_manifest_path = args.manifest_dir + "cuts_DEV.jsonl.gz"
+    fixed_dev_manifest_path = args.manifest_dir + "cuts_DEV_fixed.jsonl.gz"
+    logging.info(f"Loading dev manifest from {dev_manifest_path}")
    cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
    fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
-    logging.info(f'Fixed dev manifest saved to {fixed_dev_manifest_path}')
+    logging.info(f"Fixed dev manifest saved to {fixed_dev_manifest_path}")

-    manifest_path = args.manifest_dir + f'cuts_{args.training_subset}.jsonl.gz'
-    fixed_manifest_path = args.manifest_dir + f'cuts_{args.training_subset}_fixed.jsonl.gz'
-    logging.info(f'Loading manifest from {manifest_path}')
+    manifest_path = args.manifest_dir + f"cuts_{args.training_subset}.jsonl.gz"
+    fixed_manifest_path = (
+        args.manifest_dir + f"cuts_{args.training_subset}_fixed.jsonl.gz"
+    )
+    logging.info(f"Loading manifest from {manifest_path}")
    cuts_manifest = load_manifest_lazy(manifest_path)
    fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
-    logging.info(f'Fixed training manifest saved to {fixed_manifest_path}')
+    logging.info(f"Fixed training manifest saved to {fixed_manifest_path}")

+
if __name__ == "__main__":
-    main()
+    main()

@@ -424,4 +424,4 @@ if [ $stage -le 23 ] && [ $stop_stage -ge 23 ]; then
  python local/fix_manifest.py \
    --fixed-transcript-path data/fbank/text.fix \
    --training-subset L
-fi
+fi

@@ -38,13 +38,13 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
import argparse
import copy
import logging
+import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

-import os
import deepspeed
import k2
import optim

@@ -25,7 +25,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
    log "You need to run the prepare.sh first."
    exit -1
  fi
-
+
  python ./zipformer/train.py \
    --world-size 4 \
    --exp-dir zipformer/exp \

@@ -105,11 +105,11 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
    --encoder-dim 128,128,128,128,128,128 \
    --encoder-unmasked-dim 128,128,128,128,128,128 \
    --causal 1
fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 2: Finetune the model"


  # The following configuration of lr schedule should work well
  # You may also tune the following parameters to adjust learning rate schedule
  base_lr=0.0005

@@ -201,4 +201,4 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
    --encoder-dim 128,128,128,128,128,128 \
    --encoder-unmasked-dim 128,128,128,128,128,128 \
    --causal 1
-fi
+fi