Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit: 838bf223f3 ("fix lint")
Parent: b970ba569a
@@ -57,8 +57,8 @@ from lhotse.cut import Cut
 from multi_dataset import MultiDataset
 from tn.chinese.normalizer import Normalizer
 from whisper.normalizers import BasicTextNormalizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 from zhconv import convert
 
 from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
@@ -297,6 +297,7 @@ def decode_one_batch(
     print(hyps)
     return {"beam-search": hyps}
 
+
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
@@ -314,6 +315,7 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "beam-search".
     """
+
     def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
         """
         Text normalization similar to M2MeT challenge baseline.
@@ -323,6 +325,7 @@ def decode_dataset(
             return text
         elif normalize == "m2met":
             import re
+
             text = text.replace(" ", "")
             text = text.replace("<sil>", "")
             text = text.replace("<%>", "")
@@ -348,6 +351,7 @@ def decode_dataset(
             text = text.replace("、", "")
             text = text.replace("?", "")
             return text
+
     results = []
 
     num_cuts = 0
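The hunks above only add blank lines inside normalize_text_alimeeting, which strips spaces, filler tags such as <sil> and <%>, and CJK punctuation from AliMeeting-style transcripts before scoring. A minimal, self-contained sketch of that kind of cleanup (simplified, with an invented sample string; it is not the recipe's full function):

```python
# Simplified sketch of the normalization done by normalize_text_alimeeting();
# only a few of the replacements visible in the diff are shown, and the
# sample transcript below is invented.
def normalize_m2met(text: str) -> str:
    for token in (" ", "<sil>", "<%>", "、", "?"):
        text = text.replace(token, "")
    return text


if __name__ == "__main__":
    print(normalize_m2met("你好 <sil> 世界、<%>?"))  # prints: 你好世界
```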
egs/multi_zh-hans/ASR/whisper/multi_dataset.py  (2 lines changed; Executable file → Normal file)
@@ -244,4 +244,4 @@ class MultiDataset:
             # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
             # "wenetspeech-net_test": wenetspeech_test_net_cuts,
             # "wenetspeech_dev": wenetspeech_dev_cuts,
         }
@@ -65,8 +65,8 @@ from torch.cuda.amp import GradScaler
 from torch.nn.functional import pad as pad_tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -458,6 +458,7 @@ def compute_loss(
             return text
         elif normalize == "m2met":
             import re
+
             text = text.replace(" ", "")
             text = text.replace("<sil>", "")
             text = text.replace("<%>", "")
@@ -1,11 +1,12 @@
+from typing import Dict, Iterable, Optional
+
+import numpy as np
 import torch
 import torch.nn.functional as F
 import whisper
-from torch import Tensor
-from torch import nn
-from typing import Dict, Iterable, Optional
-from whisper.model import ResidualAttentionBlock, LayerNorm
-import numpy as np
+from torch import Tensor, nn
+from whisper.model import LayerNorm, ResidualAttentionBlock
 
+
 def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
     """
@@ -19,10 +20,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
         self.token_embedding(x)
         + self.positional_embedding[offset : offset + x.shape[-1]]
     )
-    x = (
-        x
-        + self.positional_embedding[offset : offset + x.shape[1]]
-    )
+    x = x + self.positional_embedding[offset : offset + x.shape[1]]
     x = x.to(xa.dtype)
 
     # for block in self.blocks:
@@ -39,6 +37,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
 
     return logits
 
+
 def replace_whisper_decoder_forward():
     """
     This function monkey patches the forward method of the whisper encoder.
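The hunks above collapse the parenthesized positional-embedding sum into a single line and add a blank line before replace_whisper_decoder_forward(), whose job is to swap the custom forward defined in this file onto Whisper's decoder class at runtime. A minimal, self-contained sketch of that monkey-patching pattern (toy class and values, not Whisper's actual API):

```python
# Toy illustration of the monkey-patching pattern behind
# replace_whisper_decoder_forward(); the class and numbers are invented,
# only the method-rebinding trick mirrors the real file.
class ToyDecoder:
    def forward(self, x: int) -> int:
        return x


def patched_forward(self, x: int) -> int:
    # e.g. add a positional offset before returning, roughly analogous to
    # x = x + self.positional_embedding[offset : offset + x.shape[1]]
    return x + 1


def replace_toy_decoder_forward() -> None:
    # Rebinding the method on the class affects every instance,
    # including instances created before the patch was applied.
    ToyDecoder.forward = patched_forward


replace_toy_decoder_forward()
assert ToyDecoder().forward(41) == 42
```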
@@ -20,7 +20,7 @@
 | 10 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
 | 11 | baidu_pro_api_zh | 7.29% | 2023.12 |
 
 Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)
 
 <details><summary> Detail all models </summary><p>
 
@@ -29,9 +29,10 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import kaldialign
+from speechio_norm import TextNorm
 
 from icefall.utils import store_transcripts, write_error_stats
-from speechio_norm import TextNorm
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -140,6 +141,7 @@ def get_filenames(
         results.append(whisper_filename)
     return results
 
+
 def main():
     parser = get_parser()
     args = parser.parse_args()
(One file's diff was suppressed because it is too large.)
@@ -14,10 +14,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import argparse
+import logging
+
 from lhotse import CutSet, load_manifest_lazy
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -49,6 +51,7 @@ def get_parser():
 
     return parser
 
+
 def load_fixed_text(fixed_text_path):
     """
     fixed text format
@@ -57,32 +60,38 @@ def load_fixed_text(fixed_text_path):
     load into a dict
     """
     fixed_text_dict = {}
-    with open(fixed_text_path, 'r') as f:
+    with open(fixed_text_path, "r") as f:
         for line in f:
-            cut_id, text = line.strip().split(' ', 1)
+            cut_id, text = line.strip().split(" ", 1)
             fixed_text_dict[cut_id] = text
     return fixed_text_dict
 
+
 def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
     with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
         fixed_item = 0
         for i, cut in enumerate(manifest):
             if i % 10000 == 0:
-                logging.info(f'Processing cut {i}, fixed {fixed_item}')
+                logging.info(f"Processing cut {i}, fixed {fixed_item}")
             cut_id_orgin = cut.id
-            if cut_id_orgin.endswith('_sp0.9'):
+            if cut_id_orgin.endswith("_sp0.9"):
                 cut_id = cut_id_orgin[:-6]
-            elif cut_id_orgin.endswith('_sp1.1'):
+            elif cut_id_orgin.endswith("_sp1.1"):
                 cut_id = cut_id_orgin[:-6]
             else:
                 cut_id = cut_id_orgin
             if cut_id in fixed_text_dict:
-                assert len(cut.supervisions) == 1, f'cut {cut_id} has {len(cut.supervisions)} supervisions'
+                assert (
+                    len(cut.supervisions) == 1
+                ), f"cut {cut_id} has {len(cut.supervisions)} supervisions"
                 if cut.supervisions[0].text != fixed_text_dict[cut_id]:
-                    logging.info(f'Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}')
+                    logging.info(
+                        f"Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}"
+                    )
                     cut.supervisions[0].text = fixed_text_dict[cut_id]
                     fixed_item += 1
             manifest_writer.write(cut)
 
+
 def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
@@ -92,23 +101,26 @@ def main():
     args = parser.parse_args()
     logging.info(vars(args))
 
-    fixed_text_path = args.manifest_dir + 'text.fix'
+    fixed_text_path = args.manifest_dir + "text.fix"
     fixed_text_dict = load_fixed_text(fixed_text_path)
-    logging.info(f'Loaded {len(fixed_text_dict)} fixed texts')
+    logging.info(f"Loaded {len(fixed_text_dict)} fixed texts")
 
-    dev_manifest_path = args.manifest_dir + 'cuts_DEV.jsonl.gz'
-    fixed_dev_manifest_path = args.manifest_dir + 'cuts_DEV_fixed.jsonl.gz'
-    logging.info(f'Loading dev manifest from {dev_manifest_path}')
+    dev_manifest_path = args.manifest_dir + "cuts_DEV.jsonl.gz"
+    fixed_dev_manifest_path = args.manifest_dir + "cuts_DEV_fixed.jsonl.gz"
+    logging.info(f"Loading dev manifest from {dev_manifest_path}")
     cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
     fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
-    logging.info(f'Fixed dev manifest saved to {fixed_dev_manifest_path}')
+    logging.info(f"Fixed dev manifest saved to {fixed_dev_manifest_path}")
 
-    manifest_path = args.manifest_dir + f'cuts_{args.training_subset}.jsonl.gz'
-    fixed_manifest_path = args.manifest_dir + f'cuts_{args.training_subset}_fixed.jsonl.gz'
-    logging.info(f'Loading manifest from {manifest_path}')
+    manifest_path = args.manifest_dir + f"cuts_{args.training_subset}.jsonl.gz"
+    fixed_manifest_path = (
+        args.manifest_dir + f"cuts_{args.training_subset}_fixed.jsonl.gz"
+    )
+    logging.info(f"Loading manifest from {manifest_path}")
     cuts_manifest = load_manifest_lazy(manifest_path)
     fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
-    logging.info(f'Fixed training manifest saved to {fixed_manifest_path}')
+    logging.info(f"Fixed training manifest saved to {fixed_manifest_path}")
 
+
 if __name__ == "__main__":
     main()
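From load_fixed_text() above, each line of text.fix is expected to hold a cut id and the corrected transcript, separated by the first space; fix_manifest() then rewrites the matching supervision text, including for the _sp0.9/_sp1.1 speed-perturbed copies. A small, self-contained sketch of that parsing (the cut id and text are invented examples, not real manifest entries):

```python
# Parse text.fix-style lines into {cut_id: corrected_text}, mirroring
# load_fixed_text() in the diff above; the example line is made up.
from typing import Dict, Iterable


def parse_fixed_text(lines: Iterable[str]) -> Dict[str, str]:
    fixed = {}
    for line in lines:
        cut_id, text = line.strip().split(" ", 1)
        fixed[cut_id] = text
    return fixed


print(parse_fixed_text(["EXAMPLE_CUT_0001 这是修正后的文本"]))
# {'EXAMPLE_CUT_0001': '这是修正后的文本'}
```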
@@ -424,4 +424,4 @@ if [ $stage -le 23 ] && [ $stop_stage -ge 23 ]; then
   python local/fix_manifest.py \
     --fixed-transcript-path data/fbank/text.fix \
     --training-subset L
 fi
@@ -38,13 +38,13 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
 import argparse
 import copy
 import logging
+import os
 import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import os
 import deepspeed
 import k2
 import optim
@@ -25,7 +25,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     log "You need to run the prepare.sh first."
     exit -1
   fi
 
   python ./zipformer/train.py \
     --world-size 4 \
     --exp-dir zipformer/exp \
@@ -105,11 +105,11 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     --encoder-dim 128,128,128,128,128,128 \
     --encoder-unmasked-dim 128,128,128,128,128,128 \
     --causal 1
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 2: Finetune the model"
 
   # The following configuration of lr schedule should work well
   # You may also tune the following parameters to adjust learning rate schedule
   base_lr=0.0005
@@ -201,4 +201,4 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     --encoder-dim 128,128,128,128,128,128 \
     --encoder-unmasked-dim 128,128,128,128,128,128 \
     --causal 1
 fi