Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
The training script produces a WER of 2.57% on LibriSpeech test-clean (#13)
* Add grad_clip and weight-decay; small fixes to the dataloader and masking
* Add RESULTS.md
parent: caa0b9e942
commit: ef233486ae
egs/librispeech/ASR/RESULTS.md (new file, 23 lines added)
@@ -0,0 +1,23 @@

## Results

### LibriSpeech BPE training results (Conformer-CTC)

#### 2021-08-19

(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/13

TensorBoard log is available at https://tensorboard.dev/experiment/GnRzq8WWQW62dK4bklXBTg/#scalars

Pretrained model is available at https://huggingface.co/pkufool/conformer_ctc
The best decoding results (WER) are listed below. We obtained these results by averaging the models from epoch 15 to epoch 34 and decoding with the `attention-decoder` method, with `num_paths` equal to 100.

||test-clean|test-other|
|--|--|--|
|WER| 2.57% | 5.94% |
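For illustration, the epoch averaging mentioned above amounts to loading the saved checkpoints and averaging every parameter tensor element-wise before decoding. Below is a minimal sketch of that step; the `epoch-<n>.pt` filename pattern and the assumption that each file holds a plain `state_dict` are illustrative, not taken from the recipe.

```python
from pathlib import Path

import torch


def average_checkpoints(exp_dir: Path, start_epoch: int, end_epoch: int) -> dict:
    """Average state dicts saved as epoch-<n>.pt (assumed filename layout)."""
    filenames = [exp_dir / f"epoch-{i}.pt" for i in range(start_epoch, end_epoch + 1)]
    n = len(filenames)

    # Accumulate parameters from every checkpoint into the first one.
    avg = torch.load(filenames[0], map_location="cpu")
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")
        for k in avg:
            avg[k] = avg[k] + state[k]

    # Divide by the number of checkpoints (integer buffers use floor division).
    for k in avg:
        avg[k] = avg[k] / n if avg[k].is_floating_point() else avg[k] // n
    return avg


# Usage with hypothetical paths:
# model.load_state_dict(average_checkpoints(Path("exp"), 15, 34))
```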
To get more unique paths, we scaled `lattice.scores` by 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details). We then searched over `lm_score_scale` and `attention_score_scale` for the best results; the scales that produced the WERs above are listed below.

||lm_scale|attention_scale|
|--|--|--|
|test-clean|1.3|1.2|
|test-other|1.2|1.1|
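The scale search described above is a plain grid search: after multiplying `lattice.scores` by 0.5 so that n-best sampling yields more unique paths, each (lm_scale, attention_scale) pair is used for rescoring and the pair with the lowest WER is kept. The sketch below shows such a search loop; `wer_fn` is a stand-in for "decode the test set with these scales and compute the WER", and the grid values are illustrative, not the ones used in the recipe.

```python
import itertools
from typing import Callable, Iterable, Tuple


def search_scales(
    wer_fn: Callable[[float, float], float],
    lm_scales: Iterable[float],
    attention_scales: Iterable[float],
) -> Tuple[float, float, float]:
    """Grid search returning (best_wer, lm_scale, attention_scale)."""
    best = None
    for lm_scale, att_scale in itertools.product(lm_scales, attention_scales):
        wer = wer_fn(lm_scale, att_scale)
        if best is None or wer < best[0]:
            best = (wer, lm_scale, att_scale)
    return best


if __name__ == "__main__":
    # Toy stand-in: pretend the best WER (2.57%) is reached at lm_scale=1.3,
    # attention_scale=1.2, matching the test-clean row of the table above.
    def fake_wer(lm_scale: float, att_scale: float) -> float:
        return 2.57 + abs(lm_scale - 1.3) + abs(att_scale - 1.2)

    print(search_scales(fake_wer, [1.1, 1.2, 1.3, 1.4], [1.0, 1.1, 1.2, 1.3]))
```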
@@ -284,6 +284,7 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
+    tot_num_batches = len(dl)
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -313,6 +314,8 @@ def decode_dataset(
 
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
+                f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
+                f"{num_cuts}"
             )
     return results
@@ -373,7 +376,7 @@ def main():
     params = get_params()
     params.update(vars(args))
 
-    setup_logger(f"{params.exp_dir}/log/log-decode")
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
     logging.info("Decoding started")
     logging.info(params)
 
@@ -16,6 +16,7 @@ import torch.nn as nn
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_value_
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -145,7 +146,6 @@ def get_params() -> AttributeDict:
             "beam_size": 10,
             "reduction": "sum",
             "use_double_scores": True,
-            #
             "accum_grad": 1,
             "att_rate": 0.7,
             "attention_dim": 512,
@@ -463,7 +463,7 @@ def train_one_epoch(
 
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_norm_(model.parameters(), 5.0, 2.0)
+        clip_grad_value_(model.parameters(), 5.0)
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
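For reference, the two PyTorch utilities swapped in the hunk above behave differently: `clip_grad_norm_` rescales the whole gradient so its global norm does not exceed `max_norm`, while `clip_grad_value_` clamps each gradient element independently to `[-clip_value, clip_value]`. The toy comparison below (unrelated to the recipe's model) shows the effect of each call.

```python
import torch
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

# Two copies of a toy parameter with the same gradient (L2 norm = 13).
p_norm = torch.nn.Parameter(torch.zeros(3))
p_norm.grad = torch.tensor([3.0, 4.0, 12.0])
p_value = torch.nn.Parameter(torch.zeros(3))
p_value.grad = torch.tensor([3.0, 4.0, 12.0])

# clip_grad_norm_ scales the whole gradient by 5/13 so its 2-norm becomes 5.0.
clip_grad_norm_([p_norm], max_norm=5.0, norm_type=2.0)
print(p_norm.grad)  # ~tensor([1.1538, 1.5385, 4.6154])

# clip_grad_value_ clamps each element independently to [-5.0, 5.0].
clip_grad_value_([p_value], clip_value=5.0)
print(p_value.grad)  # tensor([3., 4., 5.])
```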
@@ -171,6 +171,8 @@ class AsrDataModule(DataModule):
                 max_duration=self.args.max_duration,
                 shuffle=True,
                 num_buckets=self.args.num_buckets,
+                bucket_method='equal_duration',
+                drop_last=True,
             )
         else:
             logging.info("Using SingleCutSampler.")
@@ -184,8 +186,8 @@ class AsrDataModule(DataModule):
             train,
             sampler=train_sampler,
             batch_size=None,
-            num_workers=4,
-            persistent_workers=True,
+            num_workers=2,
+            persistent_workers=False,
         )
         return train_dl
 
@@ -214,7 +216,7 @@ class AsrDataModule(DataModule):
             sampler=valid_sampler,
             batch_size=None,
             num_workers=2,
-            persistent_workers=True,
+            persistent_workers=False,
         )
         return valid_dl
 
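Taken together, the data-loading changes configure a lhotse `BucketingSampler` with `bucket_method='equal_duration'` and `drop_last=True`, and a `DataLoader` with fewer, non-persistent workers. The sketch below shows how these pieces fit; it assumes a lhotse version contemporary with this commit (which still provides `BucketingSampler`), and the `dataset` argument is a placeholder for whatever map-style dataset the recipe constructs.

```python
from torch.utils.data import DataLoader

from lhotse import CutSet
from lhotse.dataset import BucketingSampler  # present in lhotse releases of this era


def make_train_dl(cuts_train: CutSet, dataset, max_duration: float, num_buckets: int) -> DataLoader:
    """Wire the training sampler and DataLoader as configured after this commit."""
    train_sampler = BucketingSampler(
        cuts_train,
        max_duration=max_duration,
        shuffle=True,
        num_buckets=num_buckets,
        bucket_method="equal_duration",  # buckets balanced by total speech duration
        drop_last=True,                  # drop ragged final batches
    )
    return DataLoader(
        dataset,                      # placeholder for the recipe's dataset object
        sampler=train_sampler,
        batch_size=None,              # the lhotse sampler already yields whole batches
        num_workers=2,
        persistent_workers=False,     # workers are recreated every epoch
    )
```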
@@ -610,7 +610,7 @@ def rescore_with_attention_decoder(
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
     # to the input path index.
-    # new2old.numel() == unique_word_seqs.tot_size(1)
+    # new2old.numel() == unique_word_seq.tot_size(1)
     unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
         word_seq, need_num_repeats=True, need_new2old_indexes=True
     )