Add cr-ctc training for gigaspeech

pkufool 2024-09-21 21:59:18 +08:00
parent cf796eefed
commit ca7dbb085e
8 changed files with 3042 additions and 18 deletions

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/attention_decoder.py

View File

@@ -88,6 +88,8 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import GigaSpeechAsrDataModule
+from gigaspeech_scoring import asr_text_post_processing
 from train import add_model_arguments, get_model, get_params
 from icefall.checkpoint import (
@@ -274,6 +276,17 @@ def get_decoding_params() -> AttributeDict:
     return params


+def post_processing(
+    results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+    new_results = []
+    for key, ref, hyp in results:
+        new_ref = asr_text_post_processing(" ".join(ref)).split()
+        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+        new_results.append((key, new_ref, new_hyp))
+    return new_results
+
+
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
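The new post_processing hook runs both references and hypotheses through GigaSpeech's official text normalization before WER scoring, so that filler words and punctuation tags are not counted as errors. A rough sketch of the kind of normalization asr_text_post_processing applies (the real tag and filler lists live in gigaspeech_scoring.py; the ones below are illustrative only):

# Illustrative stand-in for gigaspeech_scoring.asr_text_post_processing.
PUNCT_TAGS = {"<COMMA>", "<PERIOD>", "<QUESTIONMARK>", "<EXCLAMATIONPOINT>"}
FILLERS = {"UH", "UM", "HMM", "MHM"}

def toy_asr_text_post_processing(text: str) -> str:
    # uppercase, then drop punctuation tags and conversational fillers
    words = text.upper().split()
    return " ".join(w for w in words if w not in PUNCT_TAGS | FILLERS)

# e.g. toy_asr_text_post_processing("i uh went home <PERIOD>") == "I WENT HOME"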
@@ -567,6 +580,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        results = post_processing(results)
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -813,14 +827,12 @@ def main():
     args.return_cuts = True
     gigaspeech = GigaSpeechAsrDataModule(args)

-    test_clean_cuts = gigaspeech.test_clean_cuts()
-    test_other_cuts = gigaspeech.test_other_cuts()
+    test_cuts = gigaspeech.test_cuts()

-    test_clean_dl = gigaspeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = gigaspeech.test_dataloaders(test_other_cuts)
+    test_dl = gigaspeech.test_dataloaders(test_cuts)

-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets = ["test"]
+    test_dl = [test_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/label_smoothing.py

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/spec_augment.py

File diff suppressed because it is too large

File diff suppressed because it is too large
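The two suppressed diffs are where the bulk of the 3042 added lines live, presumably the new CR-CTC model and training code. As background, CR-CTC (consistency-regularized CTC) trains on two differently SpecAugment-ed views of each utterance: a standard CTC loss is computed on each view, plus a consistency term that pulls the two frame-level posterior distributions toward each other. A minimal sketch of that loss shape, not the code in this commit (the 0.2 weight is a placeholder):

import torch.nn.functional as F

def cr_ctc_loss(log_probs_a, log_probs_b, targets, input_lens, target_lens, cr_weight=0.2):
    # log_probs_{a,b}: (T, N, V) log-posteriors from the two augmented views
    ctc = F.ctc_loss(log_probs_a, targets, input_lens, target_lens) + F.ctc_loss(
        log_probs_b, targets, input_lens, target_lens
    )
    # symmetric KL between the two branches; each "teacher" is detached,
    # so each branch distils from the other
    kl = F.kl_div(
        log_probs_a, log_probs_b.detach(), log_target=True, reduction="batchmean"
    ) + F.kl_div(
        log_probs_b, log_probs_a.detach(), log_target=True, reduction="batchmean"
    )
    return 0.5 * ctc + cr_weight * 0.5 * kl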

View File

@@ -17,6 +17,7 @@
 import math
+import warnings
 from typing import List, Optional

 import k2
@@ -234,10 +235,13 @@ class TransformerDecoder(nn.Module):
         # construct attn_mask for self-attn modules
         padding_mask = make_pad_mask(x_lens)  # (batch, tgt_len)
         causal_mask = subsequent_mask(x.shape[0], device=x.device)  # (seq_len, seq_len)
-        attn_mask = torch.logical_or(
-            padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
-            torch.logical_not(causal_mask).unsqueeze(0)  # (1, seq_len, seq_len)
-        )  # (batch, seq_len, seq_len)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            attn_mask = torch.logical_or(
+                padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
+                torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
+            )  # (batch, seq_len, seq_len)

         if memory is not None:
             memory = memory.permute(1, 0, 2)  # (src_len, batch, memory_dim)
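The warnings.catch_warnings context only silences a PyTorch warning emitted by this mask arithmetic; the mask itself is unchanged. For readers outside the codebase, here is a self-contained sketch of the same construction, with make_pad_mask and subsequent_mask replaced by inline equivalents (a hypothetical helper, not part of this file):

import torch

def combined_attn_mask(x_lens: torch.Tensor, seq_len: int) -> torch.Tensor:
    # True marks positions a query must NOT attend to.
    # padding mask: True at padded frames, (batch, seq_len)
    padding_mask = torch.arange(seq_len).unsqueeze(0) >= x_lens.unsqueeze(1)
    # causal mask: True at allowed (non-future) positions, (seq_len, seq_len)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # a position is disallowed if it is padded OR lies in the future
    return torch.logical_or(
        padding_mask.unsqueeze(1),                    # (batch, 1, seq_len)
        torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
    )  # (batch, seq_len, seq_len)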
@@ -367,7 +371,9 @@ class MultiHeadAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = attention_dim // num_heads
         assert self.head_dim * num_heads == attention_dim, (
-            self.head_dim, num_heads, attention_dim
+            self.head_dim,
+            num_heads,
+            attention_dim,
         )
         self.dropout = dropout
         self.name = None  # will be overwritten in training code; for diagnostics.
@@ -437,15 +443,19 @@ class MultiHeadAttention(nn.Module):
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
             attn_weights = attn_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
             )

         if attn_mask is not None:
-            assert (
-                attn_mask.shape == (batch, 1, src_len)
-                or attn_mask.shape == (batch, tgt_len, src_len)
+            assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == (
+                batch,
+                tgt_len,
+                src_len,
             ), attn_mask.shape
-            attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
+            attn_weights = attn_weights.masked_fill(
+                attn_mask.unsqueeze(1), float("-inf")
+            )

         attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -456,7 +466,11 @@ class MultiHeadAttention(nn.Module):
         # (batch * head, tgt_len, head_dim)
         attn_output = torch.bmm(attn_weights, v)
-        assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
+        assert attn_output.shape == (
+            batch * num_heads,
+            tgt_len,
+            head_dim,
+        ), attn_output.shape
         attn_output = attn_output.transpose(0, 1).contiguous()
         attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)

View File

@@ -747,7 +747,7 @@ def get_model(params: AttributeDict) -> nn.Module:
 def get_spec_augment(params: AttributeDict) -> SpecAugment:
-    num_frame_masks = 10 * params.time_mask_ratio
+    num_frame_masks = int(10 * params.time_mask_ratio)
     max_frames_mask_fraction = 0.15 * params.time_mask_ratio
     logging.info(
         f"num_frame_masks: {num_frame_masks}, "