mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

fix lint

parent b970ba569a
commit 838bf223f3
@@ -57,8 +57,8 @@ from lhotse.cut import Cut
 from multi_dataset import MultiDataset
 from tn.chinese.normalizer import Normalizer
 from whisper.normalizers import BasicTextNormalizer
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 from zhconv import convert
 
 from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
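
Note: the changes in this "fix lint" commit are style-only: isort puts the two monkey-patch imports into alphabetical order, and black normalizes blank lines, quoting, and line wrapping; no behavior is changed. A minimal sketch of how such a cleanup is typically reproduced locally, assuming black and isort are installed (file paths below are placeholders, and icefall's own lint configuration may differ):

    pip install black isort
    isort path/to/changed_file.py   # group and alphabetize imports
    black path/to/changed_file.py   # apply black formatting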
@@ -297,6 +297,7 @@ def decode_one_batch(
     print(hyps)
     return {"beam-search": hyps}
 
+
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
@@ -314,6 +315,7 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "beam-search".
     """
+
     def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
         """
        Text normalization similar to M2MeT challenge baseline.
@@ -323,6 +325,7 @@ def decode_dataset(
             return text
         elif normalize == "m2met":
             import re
+
             text = text.replace(" ", "")
             text = text.replace("<sil>", "")
             text = text.replace("<%>", "")
@@ -348,6 +351,7 @@ def decode_dataset(
             text = text.replace("、", "")
             text = text.replace("?", "")
             return text
+
     results = []
 
     num_cuts = 0
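
The hunks above show only fragments of normalize_text_alimeeting. For readability, here is a hedged sketch of what the helper plausibly looks like, assembled solely from the replacement rules visible in this diff; the recipe's real function contains additional rules (and uses the re module imported above) that are not shown here:

    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
        """Text normalization similar to the M2MeT challenge baseline (sketch)."""
        if normalize == "none":  # assumed condition; the first branch is outside this diff
            return text
        elif normalize == "m2met":
            import re  # the recipe applies regex rules here that this diff does not show

            text = text.replace(" ", "")
            text = text.replace("<sil>", "")
            text = text.replace("<%>", "")
            # ... further tag and punctuation removals elided in the diff ...
            text = text.replace("、", "")
            text = text.replace("?", "")
            return text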
egs/multi_zh-hans/ASR/whisper/multi_dataset.py (0 changes): Executable file → Normal file
@@ -65,8 +65,8 @@ from torch.cuda.amp import GradScaler
 from torch.nn.functional import pad as pad_tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -458,6 +458,7 @@ def compute_loss(
             return text
         elif normalize == "m2met":
             import re
+
             text = text.replace(" ", "")
             text = text.replace("<sil>", "")
             text = text.replace("<%>", "")
@@ -1,11 +1,12 @@
+from typing import Dict, Iterable, Optional
+
+import numpy as np
 import torch
 import torch.nn.functional as F
 import whisper
-from torch import Tensor
-from torch import nn
-from typing import Dict, Iterable, Optional
-from whisper.model import ResidualAttentionBlock, LayerNorm
-import numpy as np
+from torch import Tensor, nn
+from whisper.model import LayerNorm, ResidualAttentionBlock
 
+
 def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
     """
@@ -19,10 +20,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
         self.token_embedding(x)
         + self.positional_embedding[offset : offset + x.shape[-1]]
     )
-    x = (
-        x
-        + self.positional_embedding[offset : offset + x.shape[1]]
-    )
+    x = x + self.positional_embedding[offset : offset + x.shape[1]]
     x = x.to(xa.dtype)
 
     # for block in self.blocks:
@@ -39,6 +37,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
 
     return logits
 
+
 def replace_whisper_decoder_forward():
     """
     This function monkey patches the forward method of the whisper encoder.
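
The two monkey-patch modules replace whisper's encoder and decoder forward methods, for example to change the positional-embedding slicing shown in the hunk above. A hedged sketch of how they are wired in, based only on the imports visible in this diff (the call site, arguments, and model name in the recipe may differ):

    import whisper

    from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
    from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

    # Patch the forward methods before building the model, so that any
    # whisper model instantiated afterwards uses the modified behavior.
    replace_whisper_encoder_forward()
    replace_whisper_decoder_forward()

    model = whisper.load_model("large-v2")  # model name is a placeholder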
@@ -29,9 +29,10 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import kaldialign
+from speechio_norm import TextNorm
 
 from icefall.utils import store_transcripts, write_error_stats
-from speechio_norm import TextNorm
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -140,6 +141,7 @@ def get_filenames(
         results.append(whisper_filename)
     return results
 
+
 def main():
     parser = get_parser()
     args = parser.parse_args()
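
This file appears to be a SpeechIO-style scoring script: it normalizes texts with speechio_norm.TextNorm and writes per-utterance error statistics through icefall's store_transcripts and write_error_stats helpers. As a purely illustrative sketch (not the recipe's own code), the underlying alignment step can be done with kaldialign, which is imported above:

    import kaldialign

    EPS = "*"  # filler symbol marking insertions and deletions in the alignment

    ref = "今天 天气 很 好".split()
    hyp = "今天 天气 好".split()

    # kaldialign.align returns a list of (ref_token, hyp_token) pairs,
    # with EPS standing in for inserted or deleted tokens.
    ali = kaldialign.align(ref, hyp, EPS)
    errors = sum(1 for r, h in ali if r != h)
    print(f"errors={errors}, WER={errors / len(ref):.2%}")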
(File diff suppressed because it is too large.)
@@ -14,10 +14,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
+import argparse
+import logging
 
 from lhotse import CutSet, load_manifest_lazy
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -49,6 +51,7 @@ def get_parser():
 
     return parser
 
+
 def load_fixed_text(fixed_text_path):
     """
     fixed text format
@@ -57,33 +60,39 @@ def load_fixed_text(fixed_text_path):
     load into a dict
     """
     fixed_text_dict = {}
-    with open(fixed_text_path, 'r') as f:
+    with open(fixed_text_path, "r") as f:
         for line in f:
-            cut_id, text = line.strip().split(' ', 1)
+            cut_id, text = line.strip().split(" ", 1)
             fixed_text_dict[cut_id] = text
     return fixed_text_dict
 
 
 def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
     with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
         fixed_item = 0
         for i, cut in enumerate(manifest):
             if i % 10000 == 0:
-                logging.info(f'Processing cut {i}, fixed {fixed_item}')
+                logging.info(f"Processing cut {i}, fixed {fixed_item}")
             cut_id_orgin = cut.id
-            if cut_id_orgin.endswith('_sp0.9'):
+            if cut_id_orgin.endswith("_sp0.9"):
                 cut_id = cut_id_orgin[:-6]
-            elif cut_id_orgin.endswith('_sp1.1'):
+            elif cut_id_orgin.endswith("_sp1.1"):
                 cut_id = cut_id_orgin[:-6]
             else:
                 cut_id = cut_id_orgin
             if cut_id in fixed_text_dict:
-                assert len(cut.supervisions) == 1, f'cut {cut_id} has {len(cut.supervisions)} supervisions'
+                assert (
+                    len(cut.supervisions) == 1
+                ), f"cut {cut_id} has {len(cut.supervisions)} supervisions"
                 if cut.supervisions[0].text != fixed_text_dict[cut_id]:
-                    logging.info(f'Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}')
+                    logging.info(
+                        f"Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}"
+                    )
                     cut.supervisions[0].text = fixed_text_dict[cut_id]
                     fixed_item += 1
             manifest_writer.write(cut)
 
 def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
@@ -92,23 +101,26 @@ def main():
     args = parser.parse_args()
     logging.info(vars(args))
 
-    fixed_text_path = args.manifest_dir + 'text.fix'
+    fixed_text_path = args.manifest_dir + "text.fix"
     fixed_text_dict = load_fixed_text(fixed_text_path)
-    logging.info(f'Loaded {len(fixed_text_dict)} fixed texts')
+    logging.info(f"Loaded {len(fixed_text_dict)} fixed texts")
 
-    dev_manifest_path = args.manifest_dir + 'cuts_DEV.jsonl.gz'
-    fixed_dev_manifest_path = args.manifest_dir + 'cuts_DEV_fixed.jsonl.gz'
-    logging.info(f'Loading dev manifest from {dev_manifest_path}')
+    dev_manifest_path = args.manifest_dir + "cuts_DEV.jsonl.gz"
+    fixed_dev_manifest_path = args.manifest_dir + "cuts_DEV_fixed.jsonl.gz"
+    logging.info(f"Loading dev manifest from {dev_manifest_path}")
     cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
     fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
-    logging.info(f'Fixed dev manifest saved to {fixed_dev_manifest_path}')
+    logging.info(f"Fixed dev manifest saved to {fixed_dev_manifest_path}")
 
-    manifest_path = args.manifest_dir + f'cuts_{args.training_subset}.jsonl.gz'
-    fixed_manifest_path = args.manifest_dir + f'cuts_{args.training_subset}_fixed.jsonl.gz'
-    logging.info(f'Loading manifest from {manifest_path}')
+    manifest_path = args.manifest_dir + f"cuts_{args.training_subset}.jsonl.gz"
+    fixed_manifest_path = (
+        args.manifest_dir + f"cuts_{args.training_subset}_fixed.jsonl.gz"
+    )
+    logging.info(f"Loading manifest from {manifest_path}")
     cuts_manifest = load_manifest_lazy(manifest_path)
     fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
-    logging.info(f'Fixed training manifest saved to {fixed_manifest_path}')
+    logging.info(f"Fixed training manifest saved to {fixed_manifest_path}")
 
 
 if __name__ == "__main__":
     main()
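
The script above patches supervision texts in Lhotse cut manifests from a text.fix file whose lines are "<cut_id> <corrected text>"; ids of speed-perturbed cuts (suffixes _sp0.9 / _sp1.1) are mapped back to the original cut id before lookup. A small hedged sketch of driving the two helpers defined above directly, with placeholder paths and a placeholder cut id (the recipe normally goes through main() and its manifest_dir / training_subset arguments):

    from lhotse import load_manifest_lazy

    # Hypothetical text.fix content, one correction per line:
    #   SOME_CUT_ID_0001 今天天气很好
    fixed_text_dict = load_fixed_text("data/fbank/text.fix")        # path is an assumption
    cuts_dev = load_manifest_lazy("data/fbank/cuts_DEV.jsonl.gz")   # path is an assumption
    fix_manifest(cuts_dev, fixed_text_dict, "data/fbank/cuts_DEV_fixed.jsonl.gz")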
@@ -38,13 +38,13 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
 import argparse
 import copy
 import logging
+import os
 import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import os
 import deepspeed
 import k2
 import optim