Merge remote-tracking branch 'dan/master' into nbest-oracle

Fangjun Kuang 2021-08-20 10:27:15 +08:00
commit 60211ce12a
5 changed files with 35 additions and 7 deletions

View File

@@ -0,0 +1,23 @@
## Results
### LibriSpeech BPE training results (Conformer-CTC)
#### 2021-08-19
(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/13
TensorBoard log is available at https://tensorboard.dev/experiment/GnRzq8WWQW62dK4bklXBTg/#scalars
Pretrained model is available at https://huggingface.co/pkufool/conformer_ctc
The best decoding results (WERs) are listed below. We obtained them by averaging the models from epoch 15 to 34 and decoding with the `attention-decoder` method with `num_paths` set to 100.
||test-clean|test-other|
|--|--|--|
|WER| 2.57% | 5.94% |
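
The model averaging mentioned above combines the checkpoints saved at epochs 15 to 34 into a single set of weights. Below is a minimal, generic sketch of the idea; icefall ships its own helper for this, and the checkpoint layout and paths used here are assumptions for illustration only.

```python
import torch

def average_checkpoints(filenames):
    """Average the parameter tensors of several checkpoints (illustrative only)."""
    avg = None
    for f in filenames:
        # Assumption: each checkpoint is a dict holding the model weights under "model".
        state = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    for k in avg:
        avg[k] /= len(filenames)
    return avg

# Hypothetical checkpoint paths for epochs 15-34.
filenames = [f"exp/epoch-{i}.pt" for i in range(15, 35)]
# model.load_state_dict(average_checkpoints(filenames))
```
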
To get more unique paths, we scaled `lattice.scores` by 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details). We then searched over `lm_scale` and `attention_scale` for the best results; the scales that produced the WERs above are listed below.
||lm_scale|attention_scale|
|--|--|--|
|test-clean|1.3|1.2|
|test-other|1.2|1.1|
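
The score scaling described above can be sketched as follows. Before sampling the n-best paths (here assumed to be drawn with `k2.random_paths`), the lattice scores are scaled down so the sampling distribution is flatter and more distinct paths come back; the original scores are restored afterwards so rescoring is unaffected. This is an illustration of the idea, not the exact icefall code.

```python
import k2

def sample_nbest_paths(lattice: k2.Fsa, num_paths: int = 100, scale: float = 0.5):
    """Sample n-best paths after flattening the score distribution (sketch)."""
    saved_scores = lattice.scores.clone()
    # Scaling the scores down flattens the posterior over paths, so random
    # sampling returns more unique paths.
    lattice.scores = saved_scores * scale
    paths = k2.random_paths(lattice, use_double_scores=True, num_paths=num_paths)
    # Restore the original scores before any rescoring is done.
    lattice.scores = saved_scores
    return paths
```

The searched `lm_scale` and `attention_scale` then weight the LM and attention-decoder scores when the sampled paths are rescored.
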

View File

@@ -317,6 +317,7 @@ def decode_dataset(
     results = []
     num_cuts = 0
+    tot_num_batches = len(dl)

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -346,6 +347,8 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             logging.info(
+                f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
+                f"{num_cuts}"
                 f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )
     return results
@@ -406,7 +409,7 @@ def main():
     params = get_params()
     params.update(vars(args))
-    setup_logger(f"{params.exp_dir}/log/log-decode")
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
     logging.info("Decoding started")
     logging.info(params)

View File

@@ -16,6 +16,7 @@ import torch.nn as nn
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_value_
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -145,7 +146,6 @@ def get_params() -> AttributeDict:
         "beam_size": 10,
         "reduction": "sum",
         "use_double_scores": True,
-        #
         "accum_grad": 1,
         "att_rate": 0.7,
         "attention_dim": 512,
@@ -463,7 +463,7 @@ def train_one_epoch(
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_norm_(model.parameters(), 5.0, 2.0)
+        clip_grad_value_(model.parameters(), 5.0)
         optimizer.step()

         loss_cpu = loss.detach().cpu().item()
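
For context on the last hunk: `clip_grad_norm_` rescales all gradients together when their global L2 norm exceeds the threshold, whereas `clip_grad_value_` clamps each gradient element independently. A minimal standalone comparison using the standard PyTorch utilities (unrelated to the icefall model itself):

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# Old call in this diff: rescale gradients so their total L2 norm is at most 5.0.
clip_grad_norm_(model.parameters(), 5.0, 2.0)

# New call in this diff: clamp every gradient element into [-5.0, 5.0].
clip_grad_value_(model.parameters(), 5.0)
```
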

View File

@@ -171,6 +171,8 @@ class AsrDataModule(DataModule):
                 max_duration=self.args.max_duration,
                 shuffle=True,
                 num_buckets=self.args.num_buckets,
+                bucket_method='equal_duration',
+                drop_last=True,
             )
         else:
             logging.info("Using SingleCutSampler.")
@@ -184,8 +186,8 @@ class AsrDataModule(DataModule):
             train,
             sampler=train_sampler,
             batch_size=None,
-            num_workers=4,
-            persistent_workers=True,
+            num_workers=2,
+            persistent_workers=False,
         )
         return train_dl
@@ -214,7 +216,7 @@ class AsrDataModule(DataModule):
             sampler=valid_sampler,
             batch_size=None,
             num_workers=2,
-            persistent_workers=True,
+            persistent_workers=False,
         )
         return valid_dl
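
To make the sampler and loader options above easier to read in one place, here is a hedged sketch of how they fit together; it assumes lhotse's `BucketingSampler` API of that era (later superseded by `DynamicBucketingSampler`) and a dataset whose `__getitem__` turns a batch of cuts into tensors, which is why `batch_size=None` is passed to the `DataLoader`.

```python
from torch.utils.data import DataLoader
from lhotse.dataset import BucketingSampler  # lhotse API as used in this diff

def make_train_dl(train_dataset, cuts_train, max_duration, num_buckets):
    # Batches are formed by the sampler: each batch holds at most
    # `max_duration` seconds of audio, drawn from duration-balanced buckets.
    train_sampler = BucketingSampler(
        cuts_train,
        max_duration=max_duration,
        shuffle=True,
        num_buckets=num_buckets,
        bucket_method="equal_duration",  # balance buckets by total duration
        drop_last=True,                  # drop ragged final batches
    )
    return DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=None,            # the sampler already yields whole batches
        num_workers=2,
        persistent_workers=False,   # recreate workers each epoch
    )
```
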

View File

@@ -750,7 +750,7 @@ def rescore_with_attention_decoder(
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
     # to the input path index.
-    # new2old.numel() == unique_word_seqs.tot_size(1)
+    # new2old.numel() == unique_word_seq.tot_size(1)
     unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
         word_seq, need_num_repeats=True, need_new2old_indexes=True
     )