update speechio whisper ft results

Yuekai Zhang 2024-04-24 18:57:34 +08:00
parent df36f93bd8
commit b970ba569a
13 changed files with 1629 additions and 213 deletions

egs/multi_zh-hans/ASR/RESULTS.md

@@ -1,5 +1,48 @@
## Results
### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2
#### Whisper
[./whisper](./whisper)
Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.
| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Greedy Search | 23.45 | 25.42 | 0.78 | 0.83 | 2.75 | 2.93 | 17.11 | 2.68 | 2.33 | 4.97 | 2.02 | 6.34 | 5.06 | 8.38 | 6.94 |
Command for training is:
```bash
pip install -r whisper/requirements.txt
# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
```
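The DeepSpeed config file itself is not shown in this diff. As a rough sketch, a ZeRO stage-1 config of the kind `./whisper/ds_config_zero1.json` points to has the following general shape; every value below is an illustrative assumption, not the recipe's actual setting:
```bash
# Illustrative only -- the recipe ships its own ds_config_zero1.json.
cat > ./whisper/ds_config_zero1.json <<'EOF'
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 1,
  "gradient_clipping": 1.0,
  "fp16": {"enabled": true},
  "zero_optimization": {"stage": 1}
}
EOF
```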
Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
ln -s icefall_asr_multi-hans-zh_whisper/v1/epoch-2-avg4.pt whisper/exp_large_v2/epoch-999.pt
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
```
Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>
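Note that `--epoch 999 --avg 1` simply loads the symlinked `epoch-999.pt` without further averaging at decode time; the averaging implied by the name `epoch-2-avg4.pt` happens beforehand. As a minimal sketch of plain parameter averaging (icefall's `average_checkpoints_with_averaged_model` is the real implementation; the checkpoint layout and filenames below are assumptions):
```python
import torch

def average_checkpoints(paths):
    # Accumulate and average the state dicts of several checkpoints.
    # Assumes icefall-style checkpoints storing weights under the "model" key.
    avg = None
    for path in paths:
        state = torch.load(path, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(paths) for k, v in avg.items()}

# hypothetical usage:
# averaged = average_checkpoints([f"whisper/exp_large_v2/epoch-{i}.pt" for i in (1, 2)])
```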
### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model
This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.

egs/multi_zh-hans/ASR/prepare.sh

@@ -226,8 +226,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare WenetSpeech" log "Stage 11: Prepare WenetSpeech"
if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then
cd data/fbank cd data/fbank
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) . ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) . ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) . ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) . ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
@@ -299,15 +299,6 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Compute KeSpeech fbank for test/dev" log "Compute KeSpeech fbank for test/dev"
./local/compute_fbank_kespeech_dev_test.py ./local/compute_fbank_kespeech_dev_test.py
-if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then
-pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz")
-lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
-fi
-if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then
-pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz")
-lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
-fi
touch data/fbank/.kespeech.done
fi
fi

egs/multi_zh-hans/ASR/whisper/decode.py Normal file → Executable file

@@ -58,6 +58,7 @@ from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from zhconv import convert
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
@@ -214,7 +215,7 @@ def get_parser():
"--model-name", "--model-name",
type=str, type=str,
default="large-v2", default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], choices=["large-v2", "large-v3", "medium", "small", "tiny"],
help="""The model name to use. help="""The model name to use.
""", """,
) )
@@ -226,6 +227,13 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction", help="replace whisper encoder forward method to remove input length restriction",
) )
parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)
return parser
@@ -289,7 +297,6 @@ def decode_one_batch(
print(hyps)
return {"beam-search": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
@@ -307,6 +314,40 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
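# e.g. normalize_text_alimeeting("<sil>你好 ok<%>") -> "你好OK" (hypothetical example)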
results = []
num_cuts = 0
@@ -331,6 +372,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
@@ -430,6 +472,8 @@ def main():
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
if params.avg > 1:

egs/multi_zh-hans/ASR/whisper/multi_dataset.py Normal file → Executable file

@@ -43,7 +43,7 @@ class MultiDataset:
- thchs_30_cuts_train.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
-- wenetspeech/cuts_L.jsonl.gz
+- wenetspeech/cuts_L_fixed.jsonl.gz
"""
self.fbank_dir = Path(fbank_dir)
@@ -105,7 +105,7 @@
# WeNetSpeech
logging.info("Loading WeNetSpeech in lazy mode")
wenetspeech_L_cuts = load_manifest_lazy(
-self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
+self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
)
# KeSpeech
@@ -124,10 +124,10 @@
aishell_4_L_cuts,
aishell_4_M_cuts,
aishell_4_S_cuts,
-alimeeting_cuts,
stcmds_cuts,
primewords_cuts,
magicdata_cuts,
+alimeeting_cuts,
wenetspeech_L_cuts,
kespeech_1_cuts,
kespeech_2_cuts,
@@ -138,10 +138,10 @@
len(aishell_4_L_cuts),
len(aishell_4_M_cuts),
len(aishell_4_S_cuts),
-len(alimeeting_cuts),
len(stcmds_cuts),
len(primewords_cuts),
len(magicdata_cuts),
+len(alimeeting_cuts),
len(wenetspeech_L_cuts),
len(kespeech_1_cuts),
len(kespeech_2_cuts),
@@ -151,55 +151,13 @@
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
-# AISHELL
-logging.info("Loading Aishell DEV set in lazy mode")
-aishell_dev_cuts = load_manifest_lazy(
-self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
-)
-# AISHELL-2
-logging.info("Loading Aishell-2 DEV set in lazy mode")
-aishell2_dev_cuts = load_manifest_lazy(
-self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
-)
-# Ali-Meeting
-logging.info("Loading Ali-Meeting DEV set in lazy mode")
-alimeeting_dev_cuts = load_manifest_lazy(
-self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
-)
-# MagicData
-logging.info("Loading MagicData DEV set in lazy mode")
-magicdata_dev_cuts = load_manifest_lazy(
-self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
-)
-# KeSpeech
-logging.info("Loading KeSpeech DEV set in lazy mode")
-kespeech_dev_phase1_cuts = load_manifest_lazy(
-self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
-)
-kespeech_dev_phase2_cuts = load_manifest_lazy(
-self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
-)
# WeNetSpeech
logging.info("Loading WeNetSpeech DEV set in lazy mode")
wenetspeech_dev_cuts = load_manifest_lazy(
-self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
+self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)
return wenetspeech_dev_cuts
-# return [
-# aishell_dev_cuts,
-# aishell2_dev_cuts,
-# alimeeting_dev_cuts,
-# magicdata_dev_cuts,
-# kespeech_dev_phase1_cuts,
-# kespeech_dev_phase2_cuts,
-# wenetspeech_dev_cuts,
-# ]
def test_cuts(self) -> Dict[str, CutSet]:
logging.info("About to get multidataset test cuts")
@@ -267,30 +225,23 @@ def test_cuts(self) -> Dict[str, CutSet]:
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz" self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
) )
wenetspeech_dev_cuts = load_manifest_lazy( wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
) )
return { return {
"aishell-2_test": aishell2_test_cuts, "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
"aishell-4": aishell4_test_cuts, # "aishell_test": aishell_test_cuts,
"magicdata_test": magicdata_test_cuts, # "aishell_dev": aishell_dev_cuts,
"kespeech-asr_test": kespeech_test_cuts, # "ali-meeting_test": alimeeting_test_cuts,
} # "ali-meeting_eval": alimeeting_eval_cuts,
# "aishell-4_test": aishell4_test_cuts,
# return { # "aishell-2_test": aishell2_test_cuts,
# "alimeeting_test": alimeeting_test_cuts, # "aishell-2_dev": aishell2_dev_cuts,
# "alimeeting_eval": alimeeting_eval_cuts, # "magicdata_test": magicdata_test_cuts,
# "aishell_test": aishell_test_cuts, # "magicdata_dev": magicdata_dev_cuts,
# "aishell_dev": aishell_dev_cuts, # "kespeech-asr_test": kespeech_test_cuts,
# "aishell-2_test": aishell2_test_cuts, # "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "aishell-2_dev": aishell2_dev_cuts, # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "aishell-4": aishell4_test_cuts, # "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "magicdata_test": magicdata_test_cuts, # "wenetspeech_dev": wenetspeech_dev_cuts,
# "magicdata_dev": magicdata_dev_cuts, }
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
# }

egs/multi_zh-hans/ASR/whisper/train.py Normal file → Executable file

@@ -66,6 +66,7 @@ from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -146,7 +147,7 @@ def get_parser():
"--model-name", "--model-name",
type=str, type=str,
default="large-v2", default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], choices=["large-v2", "large-v3", "medium", "small", "tiny"],
help="""The model name to use. help="""The model name to use.
""", """,
) )
@@ -232,6 +233,13 @@ def get_parser():
help="Whether to use half precision training.", help="Whether to use half precision training.",
) )
parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)
parser = deepspeed.add_config_arguments(parser)
return parser
@@ -441,6 +449,41 @@ def compute_loss(
padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
return torch.stack([tensor for tensor in padded_tensors], dim=0)
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
@@ -459,7 +502,7 @@ def compute_loss(
texts = batch["supervisions"]["text"]
# remove spaces in texts
-texts = [text.replace(" ", "") for text in texts]
+texts = [normalize_text_alimeeting(text) for text in texts]
text_tokens_list = [
list(tokenizer.sot_sequence_including_notimestamps)
@@ -759,6 +802,8 @@ def run(rank, world_size, args):
logging.info("About to create model") logging.info("About to create model")
replace_whisper_encoder_forward() replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
del model.alignment_heads
@@ -824,7 +869,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
-512
+2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

egs/multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py New file

@@ -0,0 +1,47 @@
import torch
import torch.nn.functional as F
import whisper
from torch import Tensor
from torch import nn
from typing import Dict, Iterable, Optional
from whisper.model import ResidualAttentionBlock, LayerNorm
import numpy as np
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
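# offset = number of tokens already cached when decoding incrementally with kv_cache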
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
# for block in self.blocks:
# x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
# use architecture from the distill whisper model
# see https://github.com/huggingface/distil-whisper
x = self.blocks[0](x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.blocks[-1](x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
def replace_whisper_decoder_forward():
"""
This function monkey patches the forward method of the whisper decoder.
To be called before the model is loaded; it makes the decoder run only its first and last blocks, following distil-whisper.
"""
whisper.model.TextDecoder.forward = forward
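A minimal sketch of how the patch is meant to be applied, mirroring the ordering used in `train.py` and `decode.py` above (the model name is just an example):
```python
import whisper
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward

replace_whisper_decoder_forward()  # must run before load_model
model = whisper.load_model("large-v2", "cpu")
# forward passes now use only the first and last decoder blocks
```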

egs/speechio/ASR/RESULTS.md

@@ -2,41 +2,70 @@
### SpeechIO Test Set Decoding Results
-##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whisper-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper).
-| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
-|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
-| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
-| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
-| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
-| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
-| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
-| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
-| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
-| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
-| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
-| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
-| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
-| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
-| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
-| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
-| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
-| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
-| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
-| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
-| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
-| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
-| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
-| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
-| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
-| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
-| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
-| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
-| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
-| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |
#### **Unlocked** SpeechIO test sets (ZH00001 ~ ZH00026)
| Rank 排名 | Model 模型 | CER 字错误率 | Date 时间 |
| --- | --- | --- | --- |
| 1 | ximalaya_api_zh | 1.72% | 2023.12 |
| 2 | aliyun_ftasr_api_zh | 1.85% | 2023.12 |
| 3 | microsoft_batch_zh | 2.40% | 2023.12 |
| 4 | bilibili_api_zh | 2.90% | 2023.09 |
| 5 | tencent_api_zh | 3.18% | 2023.12 |
| 6 | iflytek_lfasr_api_zh | 3.32% | 2023.12 |
| 7 | aispeech_api_zh | 3.62% | 2023.12 |
| 8 | **whisper-large-ft-v1** | **4.45%** | 2024.04 |
| 9 | **zipformer (70Mb)** | **6.17%** | 2023.10 |
| 10 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 11 | baidu_pro_api_zh | 7.29% | 2023.12 |
Note: the API results above are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results use the default [normalization method](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67).
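The updated `whisper_zipformer_fusion.py` below applies a comparable normalization through the new `speechio_norm.py` before scoring. A minimal usage sketch, assuming only the callable `TextNorm` interface visible in that script (the example strings are hypothetical):
```python
from speechio_norm import TextNorm

normalizer = TextNorm()
# hypothetical reference/hypothesis pair; both are normalized before CER is computed
ref = normalizer("今天是二零二四年四月二十四日")
hyp = normalizer("今天是2024年4月24日")
```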
<details><summary> Details for all models </summary><p>
| Model | Training Set | Note |
|----------------------------------------------------------------------------------------------------------|---------------|-----------------------------------------------------|
|[zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24)| multi-hans-zh | decoding with transducer head and blank penalty 2.0 |
|[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper)| wenetspeech | |
|[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper)| wenetspeech (updated) plus other multi-hans-zh data, excluding datatang 200h |[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py)|
</details>
<details><summary> Details for all results (CER %) </summary><p>
| Test Set ID | Test Scenario & Domain | bilibili_api_zh (2023.09) | whisper-large-ft-v0 | whisper-large-ft-v1 | zipformer |
|----------------------|-------------------------------|-----------------|---------|-----------|-----------|
| Avg (01-26) | | 2.9 | 6.34 | 4.45 | 6.17 |
| SPEECHIO_ASR_ZH00001 | 新闻联播 | 0.54 | 1.42 | 1.11 | 1.37 |
| SPEECHIO_ASR_ZH00002 | 访谈 鲁豫有约 | 2.78 | 4.76 | 3.25 | 4.67 |
| SPEECHIO_ASR_ZH00003 | 电视节目 天下足球 | 0.81 | 2.17 | 1.64 | 2.71 |
| SPEECHIO_ASR_ZH00004 | 场馆演讲 罗振宇跨年 | 1.48 | 2.53 | 1.91 | 2.54 |
| SPEECHIO_ASR_ZH00005 | 在线教育 李永乐 科普 | 1.47 | 4.27 | 1.98 | 3.12 |
| SPEECHIO_ASR_ZH00006 | 直播 王者荣耀 张大仙&骚白 | 5.85 | 12.55 | 9.28 | 12.86 |
| SPEECHIO_ASR_ZH00007 | 直播 带货 李佳琪&薇娅 | 6.19 | 13.38 | 11.14 | 14.58 |
| SPEECHIO_ASR_ZH00008 | 线下培训 老罗语录 | 3.68 | 9.56 | 7.47 | 9.05 |
| SPEECHIO_ASR_ZH00009 | 播客 故事FM | 3.18 | 5.66 | 3.78 | 5.4 |
| SPEECHIO_ASR_ZH00010 | 播客 创业内幕 | 3.51 | 7.84 | 4.32 | 6.4 |
| SPEECHIO_ASR_ZH00011 | 在线教育 罗翔 刑法法考 | 1.77 | 3.22 | 2.49 | 3.12 |
| SPEECHIO_ASR_ZH00012 | 在线教育 张雪峰 考研 | 2.11 | 5.98 | 3.09 | 4.41 |
| SPEECHIO_ASR_ZH00013 | 短视频 影剪 谷阿莫&牛叔说电影 | 2.97 | 5.91 | 3.78 | 6.56 |
| SPEECHIO_ASR_ZH00014 | 短视频 美式&烹饪 | 3.56 | 6.03 | 5.28 | 8.14 |
| SPEECHIO_ASR_ZH00015 | 评书 单田芳 白眉大侠 | 4.72 | 8.77 | 7.97 | 9.1 |
| SPEECHIO_ASR_ZH00016 | 相声 德云社专场 | 3.01 | 5.24 | 4.41 | 5.59 |
| SPEECHIO_ASR_ZH00017 | 脱口秀 吐槽大会 | 2.93 | 7.05 | 3.27 | 5.17 |
| SPEECHIO_ASR_ZH00018 | 少儿卡通 小猪佩奇&熊出没 | 1.98 | 3.53 | 3.21 | 4.15 |
| SPEECHIO_ASR_ZH00019 | 体育赛事解说 NBA比赛 | 2.32 | 6.89 | 4.24 | 6.66 |
| SPEECHIO_ASR_ZH00020 | 纪录片 篮球人物 | 1.51 | 4.16 | 2.96 | 4.2 |
| SPEECHIO_ASR_ZH00021 | 短视频 汽车之家 汽车评测 | 1.75 | 4.77 | 2.77 | 4.17 |
| SPEECHIO_ASR_ZH00022 | 短视频 小艾大叔 豪宅带看 | 3.29 | 6.35 | 5.66 | 6.72 |
| SPEECHIO_ASR_ZH00023 | 短视频 开箱视频 Zeal&无聊开箱 | 2.18 | 8.99 | 4.45 | 7.94 |
| SPEECHIO_ASR_ZH00024 | 短视频 付老师 农业种植 | 4.80 | 10.81 | 6.25 | 8.64 |
| SPEECHIO_ASR_ZH00025 | 线下课堂 石国鹏 古希腊哲学 | 3.32 | 8.41 | 5.8 | 8.54 |
| SPEECHIO_ASR_ZH00026 | 广播电台节目 张震鬼故事 | 3.70 | 4.52 | 4.11 | 4.67 |
</details>
Command for decoding using fine-tuned whisper:
```bash
@@ -76,16 +105,6 @@ mv words.txt ./data/lang_bpe_2000/
--manifest-dir data/fbank_kaldi \
--decoding-method greedy_search
```
-Command for fusion the above decoding results from whisper and zipformer:
-```bash
-python local/whisper_zipformer_fusion.py \
-  --whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
-  --zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
-  --output-log-dir ./results_fusion
-```
-See why the fusion helps [here](./local/whisper_zipformer_fusion.py).
SpeechIO fbank features, decoding scripts, logs, and decoding results
are available at

egs/speechio/ASR/local/whisper_zipformer_fusion.py

@@ -21,7 +21,7 @@ Since whisper model is more likely to make deletion errors and zipformer model i
we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors.
Usage:
-python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
+python whisper_zipformer_fusion.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm
""" """
import argparse import argparse
@@ -31,38 +31,43 @@ from typing import Dict, List, Optional, Tuple
import kaldialign
from icefall.utils import store_transcripts, write_error_stats
from speechio_norm import TextNorm
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--whisper-log-dir", "--model-log-dir",
type=str,
default="./recogs_whisper",
help="The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
)
-parser.add_argument(
-"--zipformer-log-dir",
-type=str,
-default="./recogs_zipformer",
-help="The directory to store the zipformer logs",
-)
parser.add_argument(
"--output-log-dir",
type=str,
default="./results_fusion", default="./results_whisper_norm",
help="The directory to store the fusion logs", help="The directory to store the normalized whisper logs",
)
return parser
-def save_results(
+def save_results_with_speechio_text_norm(
res_dir: Path,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
normalizer = TextNorm()
# normalize items in results_dict
for key, results in results_dict.items():
results_norm = []
for item in results:
wav_name, ref, hyp = item
ref = normalizer(ref)
hyp = normalizer(hyp)
results_norm.append((wav_name, ref, hyp))
results_dict[key] = results_norm
test_set_wers = dict()
suffix = "epoch-999-avg-1"
@@ -120,11 +125,9 @@ def extract_hyp_ref_wavname(filename):
return hyps, refs, wav_name
-def get_pair_filenames(
+def get_filenames(
whisper_log_dir,
-zipformer_log_dir,
whisper_suffix="beam-search-epoch-999-avg-1",
-zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
):
results = []
start_index, end_index = 0, 26
@@ -134,80 +137,23 @@ def get_pair_filenames(
dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
for partition in dataset_parts:
whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
-zipformer_filename = (
-f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt"
-)
-results.append((whisper_filename, zipformer_filename))
+results.append(whisper_filename)
return results
-def fusion_hyps_trust_substituion_insertion(
-hyps_whisper, hyps_zipformer, refs, ERR="*"
-):
-"""
-alignment example:
-[('', ''), ('', '*'), ('', ''), ('', ''), ('', ''), ('', ''), ('*', '')]
-left is whisper, right is zipformer
-for whisper substitution, use left
-for whisper insertion, use left
-for whisper deletion, use right
-"""
-hyps_fusion = []
-for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
-ali = kaldialign.align(hyp_w, hyp_z, ERR)
-hyp_f = ""
-for a in ali:
-if a[0] == ERR:
-hyp_f += a[1]
-else:
-hyp_f += a[0]
-hyps_fusion.append(hyp_f)
-return hyps_fusion
-def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"):
-"""
-alignment example:
-[('', ''), ('', '*'), ('', ''), ('', ''), ('', ''), ('', ''), ('*', '')]
-left is whisper, right is zipformer
-for whisper substitution, use left
-for whisper insertion, use right
-for whisper deletion, use right
-"""
-hyps_fusion = []
-for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
-ali = kaldialign.align(hyp_w, hyp_z, ERR)
-hyp_f = ""
-for a in ali:
-if a[0] == ERR:
-hyp_f += a[1]
-elif a[1] == ERR:
-pass
-else:
-hyp_f += a[0]
-hyps_fusion.append(hyp_f)
-return hyps_fusion
def main():
parser = get_parser()
args = parser.parse_args()
# mkdir output_log_dir
Path(args.output_log_dir).mkdir(parents=True, exist_ok=True)
-pair_logs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
-for pair in pair_logs:
-hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(pair[0])
-hyps_zipformer, _, _ = extract_hyp_ref_wavname(pair[1])
-hyps_fusion = fusion_hyps_trust_substituion_insertion(
-hyps_whisper, hyps_zipformer, refs
-)
-partition_name = pair[0].split("/")[-1].split("-")[1]
-save_results(
-Path(args.output_log_dir),
-partition_name,
-{"fusion": list(zip(wav_name, refs, hyps_fusion))},
-)
+filenames = get_filenames(args.model_log_dir)
+for filename in filenames:
+hyps, refs, wav_name = extract_hyp_ref_wavname(filename)
+partition_name = filename.split("/")[-1].split("-")[1]
+save_results_with_speechio_text_norm(
+Path(args.output_log_dir),
+partition_name,
+{"norm": list(zip(wav_name, refs, hyps))},
+)
print(f"Processed {partition_name}") print(f"Processed {partition_name}")

egs/speechio/ASR/local/speechio_norm.py (file diff suppressed because it is too large)

egs/wenetspeech/ASR/local/fix_manifest.py New file

@@ -0,0 +1,114 @@
#!/usr/bin/env python3
# Copyright 2024 author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import argparse
from lhotse import CutSet, load_manifest_lazy
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--fixed-transcript-path",
type=str,
default="data/fbank/text.fix",
help="""
See https://github.com/wenet-e2e/WenetSpeech/discussions/54
wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix
""",
)
parser.add_argument(
"--manifest-dir",
type=str,
default="data/fbank/",
help="Directory to store the manifest files",
)
parser.add_argument(
"--training-subset",
type=str,
default="L",
help="The training subset for wenetspeech.",
)
return parser
def load_fixed_text(fixed_text_path):
"""
fixed text format
X0000016287_92761015_S00001 我是徐涛
X0000016287_92761015_S00002 狄更斯的PICK WEEK PAPERS斯
load into a dict
"""
fixed_text_dict = {}
with open(fixed_text_path, 'r') as f:
for line in f:
cut_id, text = line.strip().split(' ', 1)
fixed_text_dict[cut_id] = text
return fixed_text_dict
def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
fixed_item = 0
for i, cut in enumerate(manifest):
if i % 10000 == 0:
logging.info(f'Processing cut {i}, fixed {fixed_item}')
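# speed-perturbed cuts carry an _sp0.9 / _sp1.1 suffix but share the original cut's transcript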
cut_id_origin = cut.id
if cut_id_origin.endswith('_sp0.9'):
cut_id = cut_id_origin[:-6]
elif cut_id_origin.endswith('_sp1.1'):
cut_id = cut_id_origin[:-6]
else:
cut_id = cut_id_origin
if cut_id in fixed_text_dict:
assert len(cut.supervisions) == 1, f'cut {cut_id} has {len(cut.supervisions)} supervisions'
if cut.supervisions[0].text != fixed_text_dict[cut_id]:
logging.info(f'Fixed text for cut {cut_id_origin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}')
cut.supervisions[0].text = fixed_text_dict[cut_id]
fixed_item += 1
manifest_writer.write(cut)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
fixed_text_path = args.fixed_transcript_path
fixed_text_dict = load_fixed_text(fixed_text_path)
logging.info(f'Loaded {len(fixed_text_dict)} fixed texts')
dev_manifest_path = args.manifest_dir + 'cuts_DEV.jsonl.gz'
fixed_dev_manifest_path = args.manifest_dir + 'cuts_DEV_fixed.jsonl.gz'
logging.info(f'Loading dev manifest from {dev_manifest_path}')
cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
logging.info(f'Fixed dev manifest saved to {fixed_dev_manifest_path}')
manifest_path = args.manifest_dir + f'cuts_{args.training_subset}.jsonl.gz'
fixed_manifest_path = args.manifest_dir + f'cuts_{args.training_subset}_fixed.jsonl.gz'
logging.info(f'Loading manifest from {manifest_path}')
cuts_manifest = load_manifest_lazy(manifest_path)
fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
logging.info(f'Fixed training manifest saved to {fixed_manifest_path}')
if __name__ == "__main__":
main()

egs/wenetspeech/ASR/prepare.sh

@@ -416,3 +416,12 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then
python ./local/compile_lg.py --lang-dir $lang_dir
done
fi
if [ $stage -le 23 ] && [ $stop_stage -ge 23 ]; then
log "Stage 23: Modify transcript according to fixed results"
# See https://github.com/wenet-e2e/WenetSpeech/discussions/54
wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix -O data/fbank/text.fix
python local/fix_manifest.py \
--fixed-transcript-path data/fbank/text.fix \
--training-subset L
fi


@@ -390,14 +390,14 @@ class WenetSpeechAsrDataModule:
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest_lazy(
self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz" self.args.manifest_dir / f"cuts_{self.args.training_subset}_fixed.jsonl.gz"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
-return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV_fixed.jsonl.gz")
@lru_cache()
def test_net_cuts(self) -> List[CutSet]:

egs/wenetspeech/ASR/whisper/train.py

@@ -44,6 +44,7 @@ from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import os
import deepspeed
import k2
import optim
@@ -145,7 +146,7 @@ def get_parser():
"--model-name", "--model-name",
type=str, type=str,
default="large-v2", default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], choices=["large-v2", "large-v3", "medium", "small", "tiny"],
help="""The model name to use. help="""The model name to use.
""", """,
) )
@@ -616,7 +617,9 @@ def train_one_epoch(
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
) )
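# after converting the ZeRO checkpoint to a single fp32 .pt above, delete the
# sharded DeepSpeed checkpoint directory to save disk space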
os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
@@ -803,7 +806,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
-512
+2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -893,6 +896,7 @@ def run(rank, world_size, args):
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}", tag=f"epoch-{params.cur_epoch}",
) )
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
else:
save_checkpoint(
params=params, params=params,