Add cr-ctc training for gigaspeech
parent cf796eefed
commit ca7dbb085e
egs/gigaspeech/ASR/zipformer/attention_decoder.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/attention_decoder.py
@@ -88,6 +88,8 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import GigaSpeechAsrDataModule
+
+from gigaspeech_scoring import asr_text_post_processing
 from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
@@ -274,6 +276,17 @@ def get_decoding_params() -> AttributeDict:
     return params
 
 
+def post_processing(
+    results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+    new_results = []
+    for key, ref, hyp in results:
+        new_ref = asr_text_post_processing(" ".join(ref)).split()
+        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+        new_results.append((key, new_ref, new_hyp))
+    return new_results
+
+
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
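A minimal usage sketch of the post_processing step added above (the cut id and transcripts are made up, and asr_text_post_processing is assumed to apply the standard GigaSpeech scoring normalization, i.e. dropping punctuation tags and conversational fillers):

# Sketch only; the example strings are hypothetical, not taken from this commit.
from gigaspeech_scoring import asr_text_post_processing

key = "POD0000001-0"  # hypothetical cut id
ref = "HELLO <COMMA> WORLD <PERIOD>".split()
hyp = "UH HELLO WORLD".split()

new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
print(new_ref, new_hyp)  # expected: ['HELLO', 'WORLD'] ['HELLO', 'WORLD']

Running the same normalization on both reference and hypothesis keeps the reported WER from being inflated by tags and fillers that the official GigaSpeech scoring ignores.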
@@ -567,6 +580,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        results = post_processing(results)
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -813,14 +827,12 @@ def main():
     args.return_cuts = True
     gigaspeech = GigaSpeechAsrDataModule(args)
 
-    test_clean_cuts = gigaspeech.test_clean_cuts()
-    test_other_cuts = gigaspeech.test_other_cuts()
+    test_cuts = gigaspeech.test_cuts()
 
-    test_clean_dl = gigaspeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = gigaspeech.test_dataloaders(test_other_cuts)
+    test_dl = gigaspeech.test_dataloaders(test_cuts)
 
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets = ["test"]
+    test_dl = [test_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
egs/gigaspeech/ASR/zipformer/label_smoothing.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/label_smoothing.py
egs/gigaspeech/ASR/zipformer/spec_augment.py (new symbolic link)
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/spec_augment.py
egs/gigaspeech/ASR/zipformer/train_cr.py (executable file, 1453 lines changed)
File diff suppressed because it is too large.
egs/gigaspeech/ASR/zipformer/train_cr_aed.py (executable file, 1542 lines changed)
File diff suppressed because it is too large.
@@ -17,6 +17,7 @@
 
 
 import math
+import warnings
 from typing import List, Optional
 
 import k2
@@ -234,9 +235,12 @@ class TransformerDecoder(nn.Module):
         # construct attn_mask for self-attn modules
         padding_mask = make_pad_mask(x_lens)  # (batch, tgt_len)
         causal_mask = subsequent_mask(x.shape[0], device=x.device)  # (seq_len, seq_len)
-        attn_mask = torch.logical_or(
-            padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
-            torch.logical_not(causal_mask).unsqueeze(0)  # (1, seq_len, seq_len)
-        )  # (batch, seq_len, seq_len)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            attn_mask = torch.logical_or(
+                padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
+                torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
+            )  # (batch, seq_len, seq_len)
 
         if memory is not None:
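A standalone sketch of the mask logic this hunk wraps (make_pad_mask and subsequent_mask are re-implemented minimally here for illustration; the recipe uses its own versions):

import torch

# Toy shapes; True marks positions that attention must ignore.
batch, seq_len = 2, 4
x_lens = torch.tensor([4, 2])

padding_mask = torch.arange(seq_len).unsqueeze(0) >= x_lens.unsqueeze(1)  # (batch, seq_len)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # (seq_len, seq_len), True = allowed

attn_mask = torch.logical_or(
    padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
    torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
)  # (batch, seq_len, seq_len)
print(attn_mask.shape)  # torch.Size([2, 4, 4])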
@@ -367,7 +371,9 @@ class MultiHeadAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = attention_dim // num_heads
         assert self.head_dim * num_heads == attention_dim, (
-            self.head_dim, num_heads, attention_dim
+            self.head_dim,
+            num_heads,
+            attention_dim,
         )
         self.dropout = dropout
         self.name = None  # will be overwritten in training code; for diagnostics.
@@ -437,15 +443,19 @@ class MultiHeadAttention(nn.Module):
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
             attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
             )
 
         if attn_mask is not None:
-            assert (
-                attn_mask.shape == (batch, 1, src_len)
-                or attn_mask.shape == (batch, tgt_len, src_len)
+            assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == (
+                batch,
+                tgt_len,
+                src_len,
             ), attn_mask.shape
-            attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
+            attn_weights = attn_weights.masked_fill(
+                attn_mask.unsqueeze(1), float("-inf")
+            )
 
         attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
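The reformatted masked_fill calls above are behaviour-preserving; a toy example (numbers invented) of why filling masked scores with -inf before softmax is the right operation:

import torch

scores = torch.tensor([[2.0, 1.0, 0.5]])     # one query over three keys
mask = torch.tensor([[False, False, True]])  # True = position to ignore

weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(weights)  # the masked key gets probability 0; the others renormalize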
@@ -456,7 +466,11 @@ class MultiHeadAttention(nn.Module):
 
         # (batch * head, tgt_len, head_dim)
         attn_output = torch.bmm(attn_weights, v)
-        assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
+        assert attn_output.shape == (
+            batch * num_heads,
+            tgt_len,
+            head_dim,
+        ), attn_output.shape
 
         attn_output = attn_output.transpose(0, 1).contiguous()
         attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
@@ -747,7 +747,7 @@ def get_model(params: AttributeDict) -> nn.Module:
 
 
 def get_spec_augment(params: AttributeDict) -> SpecAugment:
-    num_frame_masks = 10 * params.time_mask_ratio
+    num_frame_masks = int(10 * params.time_mask_ratio)
     max_frames_mask_fraction = 0.15 * params.time_mask_ratio
     logging.info(
         f"num_frame_masks: {num_frame_masks}, "
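A small sketch of what the int() cast changes when --time-mask-ratio is given as a float (the value 2.0 is only an example; whether SpecAugment rejects a float count depends on its implementation):

time_mask_ratio = 2.0  # example value, not necessarily the recipe default

num_frame_masks = int(10 * time_mask_ratio)         # 20, an integer count of masks
max_frames_mask_fraction = 0.15 * time_mask_ratio   # 0.3, fractions may stay float
print(num_frame_masks, max_frames_mask_fraction)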