Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)
Add cr-ctc training for gigaspeech
parent cf796eefed
commit ca7dbb085e
1  egs/gigaspeech/ASR/zipformer/attention_decoder.py  Symbolic link
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/attention_decoder.py
@@ -88,6 +88,8 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import GigaSpeechAsrDataModule
+
+from gigaspeech_scoring import asr_text_post_processing
 from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
@@ -274,6 +276,17 @@ def get_decoding_params() -> AttributeDict:
     return params
 
 
+def post_processing(
+    results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+    new_results = []
+    for key, ref, hyp in results:
+        new_ref = asr_text_post_processing(" ".join(ref)).split()
+        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+        new_results.append((key, new_ref, new_hyp))
+    return new_results
+
+
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
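The post_processing helper added above normalizes both references and hypotheses with gigaspeech_scoring.asr_text_post_processing before WER scoring, so that tokens such as GigaSpeech punctuation tags do not count as errors. A minimal, self-contained sketch of the idea; toy_normalize is a hypothetical stand-in, not the real normalizer (which, roughly, also strips conversational fillers and garbage-utterance tags):

from typing import List, Tuple


def toy_normalize(text: str) -> str:
    # Crude stand-in: drop GigaSpeech-style punctuation tags and two fillers.
    drop = {"<COMMA>", "<PERIOD>", "<QUESTIONMARK>", "<EXCLAMATIONPOINT>", "UH", "UM"}
    return " ".join(w for w in text.split() if w not in drop)


def post_processing(
    results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
    # Same shape as the function in the diff, but using the toy normalizer.
    return [
        (key, toy_normalize(" ".join(ref)).split(), toy_normalize(" ".join(hyp)).split())
        for key, ref, hyp in results
    ]


print(post_processing([("utt-1", "HELLO <COMMA> WORLD".split(), "UH HELLO WORLD".split())]))
# -> [('utt-1', ['HELLO', 'WORLD'], ['HELLO', 'WORLD'])]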
@@ -567,6 +580,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        results = post_processing(results)
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -813,14 +827,12 @@ def main():
     args.return_cuts = True
     gigaspeech = GigaSpeechAsrDataModule(args)
 
-    test_clean_cuts = gigaspeech.test_clean_cuts()
-    test_other_cuts = gigaspeech.test_other_cuts()
+    test_cuts = gigaspeech.test_cuts()
 
-    test_clean_dl = gigaspeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = gigaspeech.test_dataloaders(test_other_cuts)
+    test_dl = gigaspeech.test_dataloaders(test_cuts)
 
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets = ["test"]
+    test_dl = [test_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
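GigaSpeech ships a single held-out TEST split (plus DEV), so the LibriSpeech-style test-clean/test-other pair collapses to one entry and the existing zip-driven decoding loop is reused unchanged. A toy illustration of that loop shape; decode_dataset and the dataloaders here are placeholders, not the recipe's objects:

def decode_dataset(dl):
    # Stub standing in for the recipe's decode_dataset().
    return {"num_batches": len(dl)}


test_sets = ["test"]                 # GigaSpeech: one combined test set
test_dls = [["batch0", "batch1"]]    # one placeholder dataloader per test set

for test_set, test_dl in zip(test_sets, test_dls):
    print(test_set, decode_dataset(test_dl))  # -> test {'num_batches': 2}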
1  egs/gigaspeech/ASR/zipformer/label_smoothing.py  Symbolic link
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/label_smoothing.py
1  egs/gigaspeech/ASR/zipformer/spec_augment.py  Symbolic link
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/spec_augment.py
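All three files above are plain relative symlinks, the usual icefall way of sharing zipformer components across recipes: the GigaSpeech recipe reuses the LibriSpeech attention decoder, label smoothing, and SpecAugment modules unchanged. Purely for illustration (the repository simply checks the links in), the links could be recreated like this:

from pathlib import Path

links = {
    "egs/gigaspeech/ASR/zipformer/attention_decoder.py": "../../../librispeech/ASR/zipformer/attention_decoder.py",
    "egs/gigaspeech/ASR/zipformer/label_smoothing.py": "../../../librispeech/ASR/zipformer/label_smoothing.py",
    "egs/gigaspeech/ASR/zipformer/spec_augment.py": "../../../librispeech/ASR/zipformer/spec_augment.py",
}
for link, target in links.items():
    p = Path(link)
    if not p.exists():
        # The target is relative to the link's own directory, as in the diff.
        p.symlink_to(target)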
1453  egs/gigaspeech/ASR/zipformer/train_cr.py  Executable file
(File diff suppressed because it is too large)
1542  egs/gigaspeech/ASR/zipformer/train_cr_aed.py  Executable file
(File diff suppressed because it is too large)
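The two training scripts are too large to render inline. At a high level, CR-CTC (consistency-regularized CTC) feeds two differently SpecAugment-ed views of each utterance through the model and, on top of the per-view CTC losses, adds a consistency term between the two frame-level CTC posteriors; train_cr_aed.py presumably combines this with an attention-encoder-decoder head (hence the new attention_decoder.py symlink). A minimal sketch of that loss shape, not the recipe's actual implementation, with a made-up cr_weight default:

import torch
import torch.nn.functional as F


def cr_ctc_loss(log_probs_a, log_probs_b, targets, input_lens, target_lens, cr_weight=0.2):
    # log_probs_a / log_probs_b: (T, N, C) log-probabilities from two augmented views.
    ctc = 0.5 * (
        F.ctc_loss(log_probs_a, targets, input_lens, target_lens, reduction="sum")
        + F.ctc_loss(log_probs_b, targets, input_lens, target_lens, reduction="sum")
    )
    # Consistency term: symmetric KL between the two posteriors, each direction
    # treating the other branch as a detached target.
    kl = 0.5 * (
        F.kl_div(log_probs_a, log_probs_b.detach(), log_target=True, reduction="sum")
        + F.kl_div(log_probs_b, log_probs_a.detach(), log_target=True, reduction="sum")
    )
    return ctc + cr_weight * kl


# Toy usage with random tensors (T frames, N utterances, C symbols, U target tokens).
T, N, C, U = 50, 2, 30, 10
lp_a = torch.randn(T, N, C).log_softmax(-1)
lp_b = torch.randn(T, N, C).log_softmax(-1)
targets = torch.randint(1, C, (N, U))
loss = cr_ctc_loss(lp_a, lp_b, targets, torch.full((N,), T), torch.full((N,), U))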
@@ -17,6 +17,7 @@
 
 
 import math
+import warnings
 from typing import List, Optional
 
 import k2
@@ -234,10 +235,13 @@ class TransformerDecoder(nn.Module):
         # construct attn_mask for self-attn modules
         padding_mask = make_pad_mask(x_lens)  # (batch, tgt_len)
         causal_mask = subsequent_mask(x.shape[0], device=x.device)  # (seq_len, seq_len)
-        attn_mask = torch.logical_or(
-            padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
-            torch.logical_not(causal_mask).unsqueeze(0)  # (1, seq_len, seq_len)
-        )  # (batch, seq_len, seq_len)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            attn_mask = torch.logical_or(
+                padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
+                torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
+            )  # (batch, seq_len, seq_len)
 
         if memory is not None:
             memory = memory.permute(1, 0, 2)  # (src_len, batch, memory_dim)
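The block above only wraps the existing mask construction in warnings.catch_warnings() (apparently to silence a warning raised inside it) and reflows the call for the formatter; the resulting mask is unchanged. For reference, a small self-contained sketch of what the broadcast produces, assuming make_pad_mask marks padded frames True and subsequent_mask marks allowed (non-future) positions True, which is consistent with the logical_not used here:

import torch

batch, seq_len = 2, 5
lens = torch.tensor([5, 3])

# padding_mask: True at padded positions, shape (batch, seq_len)
padding_mask = torch.arange(seq_len).unsqueeze(0) >= lens.unsqueeze(1)
# causal_mask: True at allowed (non-future) positions, shape (seq_len, seq_len)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

attn_mask = torch.logical_or(
    padding_mask.unsqueeze(1),                    # (batch, 1, seq_len)
    torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
)  # (batch, seq_len, seq_len); True = attention blocked

print(attn_mask.shape)  # torch.Size([2, 5, 5])
print(attn_mask[1])     # second sequence: future positions and padded frames 3..4 masked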
@@ -367,7 +371,9 @@ class MultiHeadAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = attention_dim // num_heads
         assert self.head_dim * num_heads == attention_dim, (
-            self.head_dim, num_heads, attention_dim
+            self.head_dim,
+            num_heads,
+            attention_dim,
         )
         self.dropout = dropout
         self.name = None  # will be overwritten in training code; for diagnostics.
@@ -437,15 +443,19 @@ class MultiHeadAttention(nn.Module):
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
             attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
             )
 
         if attn_mask is not None:
-            assert (
-                attn_mask.shape == (batch, 1, src_len)
-                or attn_mask.shape == (batch, tgt_len, src_len)
+            assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == (
+                batch,
+                tgt_len,
+                src_len,
             ), attn_mask.shape
-            attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
+            attn_weights = attn_weights.masked_fill(
+                attn_mask.unsqueeze(1), float("-inf")
+            )
 
         attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -456,7 +466,11 @@ class MultiHeadAttention(nn.Module):
 
         # (batch * head, tgt_len, head_dim)
         attn_output = torch.bmm(attn_weights, v)
-        assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
+        assert attn_output.shape == (
+            batch * num_heads,
+            tgt_len,
+            head_dim,
+        ), attn_output.shape
 
         attn_output = attn_output.transpose(0, 1).contiguous()
         attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
@@ -747,7 +747,7 @@ def get_model(params: AttributeDict) -> nn.Module:
 
 
 def get_spec_augment(params: AttributeDict) -> SpecAugment:
-    num_frame_masks = 10 * params.time_mask_ratio
+    num_frame_masks = int(10 * params.time_mask_ratio)
     max_frames_mask_fraction = 0.15 * params.time_mask_ratio
     logging.info(
         f"num_frame_masks: {num_frame_masks}, "
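The functional change here is the int() cast: num_frame_masks is a count handed to SpecAugment, so a fractional time_mask_ratio would otherwise turn it into a float, while max_frames_mask_fraction legitimately stays a float. A small sketch of the scaling, with a hypothetical SpecAugmentConfig standing in for the recipe's SpecAugment arguments:

from dataclasses import dataclass


@dataclass
class SpecAugmentConfig:
    # Hypothetical container; the recipe passes these values to its SpecAugment class.
    num_frame_masks: int
    max_frames_mask_fraction: float


def scaled_config(time_mask_ratio: float) -> SpecAugmentConfig:
    # Mirrors the diff: the mask count must be an integer, the fraction stays a float.
    return SpecAugmentConfig(
        num_frame_masks=int(10 * time_mask_ratio),
        max_frames_mask_fraction=0.15 * time_mask_ratio,
    )


print(scaled_config(2.5))
# -> SpecAugmentConfig(num_frame_masks=25, max_frames_mask_fraction=0.375)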