Merge branch 'k2-fsa:master' into dev/zipformer_lstm

Yifan Yang 2024-07-10 15:21:23 +08:00 committed by GitHub
commit 61e60d90a6
18 changed files with 1439 additions and 597 deletions

View File

@ -50,7 +50,7 @@ We place an additional Conv1d layer right after the input embedding layer.
| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head |
| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty |
| `zipformer-ctc` | Zipformer | Use auxiliary attention head |
| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head | The latest recipe |
| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head | The latest recipe |
# MMI

View File

@ -1,5 +1,184 @@
## Results
### zipformer (zipformer + CTC/AED)
See <https://github.com/k2-fsa/icefall/pull/1389> for more details.
[zipformer](./zipformer)
#### Non-streaming
##### small-scale model, number of model parameters: 46282107, i.e., 46.3 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-small-ctc-attention-decoder-2024-07-09>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-decoding | 3.04 | 7.04 | --epoch 50 --avg 30 |
| attention-decoder-rescoring-no-ngram | 2.45 | 6.08 | --epoch 50 --avg 30 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# For non-streaming model training:
./zipformer/train.py \
--world-size 2 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-small \
--full-libri 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--ctc-loss-scale 0.1 \
--attention-decoder-loss-scale 0.9 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--base-lr 0.04 \
--max-duration 1700 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-decoding attention-decoder-rescoring-no-ngram; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 30 \
--exp-dir zipformer/exp-small \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--attention-decoder-loss-scale 0.9 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--max-duration 100 \
--causal 0 \
--num-paths 100 \
--decoding-method $m
done
```
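When the transducer head is disabled, as in the commands above, the CTC and attention-decoder loss scales are required to sum to 1.0 (see the `train.py` changes later in this commit), and the training objective is simply their weighted sum. Below is a minimal sketch of that weighting with made-up loss values; the variable names mirror the recipe, but the numbers are illustrative only.

```python
import torch

# Illustrative per-batch losses; in the recipe they come from the CTC head
# and from AttentionDecoderModel.calc_att_loss(), respectively.
ctc_loss = torch.tensor(120.0)
attention_decoder_loss = torch.tensor(350.0)

# Scales used in the commands above; train.py asserts they sum to 1.0
# when the transducer head is turned off.
ctc_loss_scale = 0.1
attention_decoder_loss_scale = 0.9
assert abs(ctc_loss_scale + attention_decoder_loss_scale - 1.0) < 1e-8

loss = ctc_loss_scale * ctc_loss + attention_decoder_loss_scale * attention_decoder_loss
print(loss)  # tensor(327.)
```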
##### medium-scale model, number of model parameters: 89987295, i.e., 90.0 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-ctc-attention-decoder-2024-07-08>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-decoding | 2.46 | 5.57 | --epoch 50 --avg 22 |
| attention-decoder-rescoring-no-ngram | 2.23 | 4.98 | --epoch 50 --avg 22 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--full-libri 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--ctc-loss-scale 0.1 \
--attention-decoder-loss-scale 0.9 \
--max-duration 1200 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-decoding attention-decoder-rescoring-no-ngram; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 22 \
--exp-dir zipformer/exp \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--attention-decoder-loss-scale 0.9 \
--max-duration 100 \
--causal 0 \
--num-paths 100 \
--decoding-method $m
done
```
##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-ctc-attention-decoder-2024-05-26>
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method | test-clean | test-other | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-decoding | 2.29 | 5.14 | --epoch 50 --avg 29 |
| attention-decoder-rescoring-no-ngram | 2.1 | 4.57 | --epoch 50 --avg 29 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large \
--full-libri 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--ctc-loss-scale 0.1 \
--attention-decoder-loss-scale 0.9 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 1200 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-decoding attention-decoder-rescoring-no-ngram; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 29 \
--exp-dir zipformer/exp-large \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--attention-decoder-loss-scale 0.9 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 100 \
--causal 0 \
--num-paths 100 \
--decoding-method $m
done
```
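For reference, `attention-decoder-rescoring-no-ngram` extracts n-best paths from the CTC decoding lattice and re-ranks them with the attention decoder; the actual implementation is `rescore_with_attention_decoder_no_ngram` (see the `ctc_decode.py` changes later in this commit). The snippet below is only a conceptual sketch of the re-ranking step with made-up scores: it assumes each path already has a lattice score and an attention-decoder negative log-likelihood, and it sweeps a small grid of attention scales, just as the decoding script reports one WER per scale.

```python
import torch

# Hypothetical 3-best paths for a single utterance.
lattice_scores = torch.tensor([-12.3, -13.1, -12.8])  # total lattice score per path (higher is better)
attn_nll = torch.tensor([9.4, 7.2, 8.8])              # attention-decoder NLL per path (lower is better)

for attn_scale in (1.0, 1.5, 2.0):
    # Combine the lattice score with the scaled attention-decoder log-likelihood.
    total = lattice_scores - attn_scale * attn_nll
    best = int(torch.argmax(total))
    print(f"attn_scale={attn_scale}: best path index = {best}")
```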
### zipformer (zipformer + pruned stateless transducer + CTC)
See <https://github.com/k2-fsa/icefall/pull/1111> for more details.

View File

@ -0,0 +1,573 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional
import k2
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from scaling import penalize_abs_values_gt
from icefall.utils import add_eos, add_sos, make_pad_mask
class AttentionDecoderModel(nn.Module):
"""
Args:
vocab_size (int): Number of classes.
decoder_dim (int): embedding dimension of the decoder
num_decoder_layers (int): number of decoder layers
attention_dim (int): total dimension of the multi-head attention
num_heads (int): number of attention heads
feedforward_dim (int): hidden dimension of the feed-forward module
memory_dim (int): dimension of the encoder (memory) embeddings
dropout (float): dropout rate
"""
def __init__(
self,
vocab_size: int,
decoder_dim: int = 512,
num_decoder_layers: int = 6,
attention_dim: int = 512,
num_heads: int = 8,
feedforward_dim: int = 2048,
memory_dim: int = 512,
sos_id: int = 1,
eos_id: int = 1,
dropout: float = 0.1,
ignore_id: int = -1,
label_smoothing: float = 0.1,
):
super().__init__()
self.eos_id = eos_id
self.sos_id = sos_id
self.ignore_id = ignore_id
# For the segment of the warmup period, we let the Embedding
# layer learn something. Then we start to warm up the other encoders.
self.decoder = TransformerDecoder(
vocab_size=vocab_size,
d_model=decoder_dim,
num_decoder_layers=num_decoder_layers,
attention_dim=attention_dim,
num_heads=num_heads,
feedforward_dim=feedforward_dim,
memory_dim=memory_dim,
dropout=dropout,
)
# Used to calculate attention-decoder loss
self.loss_fun = LabelSmoothingLoss(
ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum"
)
def _pre_ys_in_out(self, ys: k2.RaggedTensor, ys_lens: torch.Tensor):
"""Prepare ys_in_pad and ys_out_pad."""
ys_in = add_sos(ys, sos_id=self.sos_id)
# [B, S+1], start with SOS
ys_in_pad = ys_in.pad(mode="constant", padding_value=self.eos_id)
ys_in_lens = ys_lens + 1
ys_out = add_eos(ys, eos_id=self.eos_id)
# [B, S+1], end with EOS
ys_out_pad = ys_out.pad(mode="constant", padding_value=self.ignore_id)
return ys_in_pad.to(torch.int64), ys_in_lens, ys_out_pad.to(torch.int64)
def calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys: k2.RaggedTensor,
ys_lens: torch.Tensor,
) -> torch.Tensor:
"""Calculate attention-decoder loss.
Args:
encoder_out: (batch, num_frames, encoder_dim)
encoder_out_lens: (batch,)
ys: A ragged tensor containing the token IDs of each utterance.
ys_lens: A tensor of shape (batch,) with the number of tokens per utterance.
Return: The attention-decoder loss.
"""
ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)
# decoder forward
decoder_out = self.decoder(
x=ys_in_pad,
x_lens=ys_in_lens,
memory=encoder_out,
memory_lens=encoder_out_lens,
)
loss = self.loss_fun(x=decoder_out, target=ys_out_pad)
return loss
def nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
token_ids: List[List[int]],
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from attention-decoder.
Args:
encoder_out: (batch, num_frames, encoder_dim)
encoder_out_lens: (batch,)
token_ids: A list of token id list.
Return: A tensor of shape (batch, num_tokens).
"""
ys = k2.RaggedTensor(token_ids).to(device=encoder_out.device)
row_splits = ys.shape.row_splits(1)
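# The difference of consecutive row_splits entries is the number of tokens in each utterance.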
ys_lens = row_splits[1:] - row_splits[:-1]
ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)
# decoder forward
decoder_out = self.decoder(
x=ys_in_pad,
x_lens=ys_in_lens,
memory=encoder_out,
memory_lens=encoder_out_lens,
)
batch_size, _, num_classes = decoder_out.size()
nll = nn.functional.cross_entropy(
decoder_out.view(-1, num_classes),
ys_out_pad.view(-1),
ignore_index=self.ignore_id,
reduction="none",
)
nll = nll.view(batch_size, -1)
return nll
class TransformerDecoder(nn.Module):
"""Transfomer decoder module.
Args:
vocab_size: output dim
d_model: decoder dimension
num_decoder_layers: number of decoder layers
attention_dim: total dimension of multi head attention
num_heads: number of attention heads
feedforward_dim: hidden dimension of feed_forward module
dropout: dropout rate
"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_decoder_layers: int = 6,
attention_dim: int = 512,
num_heads: int = 8,
feedforward_dim: int = 2048,
memory_dim: int = 512,
dropout: float = 0.1,
):
super().__init__()
self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
# Absolute positional encoding
self.pos = PositionalEncoding(d_model, dropout_rate=0.1)
self.num_layers = num_decoder_layers
self.layers = nn.ModuleList(
[
DecoderLayer(
d_model=d_model,
attention_dim=attention_dim,
num_heads=num_heads,
feedforward_dim=feedforward_dim,
memory_dim=memory_dim,
dropout=dropout,
)
for _ in range(num_decoder_layers)
]
)
self.output_layer = nn.Linear(d_model, vocab_size)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
memory: Optional[torch.Tensor] = None,
memory_lens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x: Input tensor of shape (batch, tgt_len).
x_lens: A tensor of shape (batch,) containing the number of tokens in `x`
before padding.
memory:
Memory sequence of shape (batch, src_len, memory_dim).
memory_lens:
A tensor of shape (batch,) containing the number of frames in
`memory` before padding.
Returns:
Decoded token logits before softmax (batch, tgt_len, vocab_size)
"""
x = self.embed(x) # (batch, tgt_len, embed_dim)
x = self.pos(x) # (batch, tgt_len, embed_dim)
x = x.permute(1, 0, 2) # (tgt_len, batch, embed_dim)
# construct attn_mask for self-attn modules
padding_mask = make_pad_mask(x_lens) # (batch, tgt_len)
causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len)
attn_mask = torch.logical_or(
padding_mask.unsqueeze(1), # (batch, 1, seq_len)
torch.logical_not(causal_mask).unsqueeze(0) # (1, seq_len, seq_len)
) # (batch, seq_len, seq_len)
if memory is not None:
memory = memory.permute(1, 0, 2) # (src_len, batch, memory_dim)
# construct memory_attn_mask for cross-attn modules
memory_padding_mask = make_pad_mask(memory_lens) # (batch, src_len)
memory_attn_mask = memory_padding_mask.unsqueeze(1) # (batch, 1, src_len)
else:
memory_attn_mask = None
for i, mod in enumerate(self.layers):
x = mod(
x,
attn_mask=attn_mask,
memory=memory,
memory_attn_mask=memory_attn_mask,
)
x = x.permute(1, 0, 2) # (batch, tgt_len, vocab_size)
x = self.output_layer(x)
return x
class DecoderLayer(nn.Module):
"""Single decoder layer module.
Args:
d_model: equal to decoder_dim, total dimension of the decoder
attention_dim: total dimension of multi head attention
num_heads: number of attention heads
feedforward_dim: hidden dimension of feed_forward module
dropout: dropout rate
"""
def __init__(
self,
d_model: int = 512,
attention_dim: int = 512,
num_heads: int = 8,
feedforward_dim: int = 2048,
memory_dim: int = 512,
dropout: float = 0.1,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.norm_self_attn = nn.LayerNorm(d_model)
self.self_attn = MultiHeadAttention(
d_model, attention_dim, num_heads, dropout=0.0
)
self.norm_src_attn = nn.LayerNorm(d_model)
self.src_attn = MultiHeadAttention(
d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0
)
self.norm_ff = nn.LayerNorm(d_model)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, feedforward_dim),
Swish(),
nn.Dropout(dropout),
nn.Linear(feedforward_dim, d_model),
)
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
memory: Optional[torch.Tensor] = None,
memory_attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x: Input sequence of shape (seq_len, batch, embed_dim).
attn_mask: A binary mask for self-attention module indicating which
elements will be filled with -inf.
Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
memory: Memory sequence of shape (seq_len, batch, memory_dim).
memory_attn_mask: A binary mask for cross-attention module indicating which
elements will be filled with -inf.
Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
"""
# self-attn module
qkv = self.norm_self_attn(x)
self_attn_out = self.self_attn(
query=qkv, key=qkv, value=qkv, attn_mask=attn_mask
)
x = x + self.dropout(self_attn_out)
# cross-attn module
q = self.norm_src_attn(x)
src_attn_out = self.src_attn(
query=q, key=memory, value=memory, attn_mask=memory_attn_mask
)
x = x + self.dropout(src_attn_out)
# feed-forward module
x = x + self.dropout(self.feed_forward(self.norm_ff(x)))
return x
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
embed_dim: total dimension of the model.
attention_dim: dimension used in the attention module; must be a multiple of num_heads.
num_heads: number of parallel attention heads.
memory_dim: dimension of memory embedding, optional.
dropout: a Dropout layer on attn_output_weights.
"""
def __init__(
self,
embed_dim: int,
attention_dim: int,
num_heads: int,
memory_dim: Optional[int] = None,
dropout: float = 0.0,
):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.attention_dim = attention_dim
self.num_heads = num_heads
self.head_dim = attention_dim // num_heads
assert self.head_dim * num_heads == attention_dim, (
self.head_dim, num_heads, attention_dim
)
self.dropout = dropout
self.name = None # will be overwritten in training code; for diagnostics.
self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
self.linear_k = nn.Linear(
embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
)
self.linear_v = nn.Linear(
embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
)
self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute dot product attention.
Args:
query: Query tensor of shape (tgt_len, batch, embed_dim).
key: Key tensor of shape (src_len, batch, embed_dim or memory_dim).
value: Value tensor of shape (src_len, batch, embed_dim or memory_dim).
key_padding_mask: A binary mask indicating which elements are padding.
Its shape is (batch, src_len).
attn_mask: A binary mask indicating which elements will be filled with -inf.
Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
Returns:
Output tensor of shape (tgt_len, batch, embed_dim).
"""
num_heads = self.num_heads
head_dim = self.head_dim
tgt_len, batch, _ = query.shape
src_len = key.shape[0]
q = self.linear_q(query) # (tgt_len, batch, num_heads * head_dim)
k = self.linear_k(key) # (src_len, batch, num_heads * head_dim)
v = self.linear_v(value) # (src_len, batch, num_heads * head_dim)
q = q.reshape(tgt_len, batch, num_heads, head_dim)
q = q.permute(1, 2, 0, 3) # (batch, head, tgt_len, head_dim)
k = k.reshape(src_len, batch, num_heads, head_dim)
k = k.permute(1, 2, 3, 0) # (batch, head, head_dim, src_len)
v = v.reshape(src_len, batch, num_heads, head_dim)
v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1)
# Note: could remove the scaling operation when using ScaledAdam
# (batch, head, tgt_len, src_len)
attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)
# From zipformer.py:
# This is a harder way of limiting the attention scores to not be too large.
# It incurs a penalty if any of them has an absolute value greater than 50.0.
# this should be outside the normal range of the attention scores. We use
# this mechanism instead of, say, a limit on entropy, because once the entropy
# gets very small gradients through the softmax can become very small, and
# some mechanisms like that become ineffective.
attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04)
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
)
if attn_mask is not None:
assert (
attn_mask.shape == (batch, 1, src_len)
or attn_mask.shape == (batch, tgt_len, src_len)
), attn_mask.shape
attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
# (batch * head, tgt_len, head_dim)
attn_output = torch.bmm(attn_weights, v)
assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
attn_output = attn_output.transpose(0, 1).contiguous()
attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
# (batch, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class PositionalEncoding(nn.Module):
"""Positional encoding.
Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35.
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Construct an PositionalEncoding object."""
super(PositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1)]
return self.dropout(x)
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
def subsequent_mask(size, device="cpu", dtype=torch.bool):
"""Create mask for subsequent steps (size, size).
:param int size: size of mask
:param str device: "cpu" or "cuda" or torch.Tensor.device
:param torch.dtype dtype: result dtype
:rtype: torch.Tensor
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
ret = torch.ones(size, size, device=device, dtype=dtype)
return torch.tril(ret, out=ret)
def _test_attention_decoder_model():
m = AttentionDecoderModel(
vocab_size=500,
decoder_dim=512,
num_decoder_layers=6,
attention_dim=512,
num_heads=8,
feedforward_dim=2048,
memory_dim=384,
dropout=0.1,
sos_id=1,
eos_id=1,
ignore_id=-1,
)
num_param = sum([p.numel() for p in m.parameters()])
print(f"Number of model parameters: {num_param}")
m.eval()
encoder_out = torch.randn(2, 50, 384)
encoder_out_lens = torch.full((2,), 50)
token_ids = [[1, 2, 3, 4], [2, 3, 10]]
nll = m.nll(encoder_out, encoder_out_lens, token_ids)
print(nll)
if __name__ == "__main__":
_test_attention_decoder_model()

View File

@ -73,6 +73,29 @@ Usage:
--nbest-scale 1.0 \
--lm-dir data/lm \
--decoding-method whole-lattice-rescoring
(6) attention-decoder-rescoring-no-ngram
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--use-attention-decoder 1 \
--max-duration 100 \
--decoding-method attention-decoder-rescoring-no-ngram
(7) attention-decoder-rescoring-with-ngram
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--use-attention-decoder 1 \
--max-duration 100 \
--hlg-scale 0.6 \
--nbest-scale 1.0 \
--lm-dir data/lm \
--decoding-method attention-decoder-rescoring-with-ngram
"""
@ -101,6 +124,8 @@ from icefall.decode import (
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder_no_ngram,
rescore_with_attention_decoder_with_ngram,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
@ -212,6 +237,10 @@ def get_parser():
- (6) nbest-oracle. Its WER is the lower bound that any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring methods.
- (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
lattice, rescore them with the attention decoder.
- (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
rescored lattice, rescore them with the attention decoder.
""",
)
@ -406,6 +435,26 @@ def decode_one_batch(
key = "ctc-decoding"
return {key: hyps}
if params.decoding_method == "attention-decoder-rescoring-no-ngram":
best_path_dict = rescore_with_attention_decoder_no_ngram(
lattice=lattice,
num_paths=params.num_paths,
attention_decoder=model.attention_decoder,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
nbest_scale=params.nbest_scale,
)
ans = dict()
for a_scale_str, best_path in best_path_dict.items():
# token_ids is a list-of-lists of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
ans[a_scale_str] = hyps
return ans
if params.decoding_method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
@ -446,6 +495,7 @@ def decode_one_batch(
assert params.decoding_method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder-rescoring-with-ngram",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
@ -466,6 +516,21 @@ def decode_one_batch(
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.decoding_method == "attention-decoder-rescoring-with-ngram":
# The lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
best_path_dict = rescore_with_attention_decoder_with_ngram(
lattice=rescored_lattice,
num_paths=params.num_paths,
attention_decoder=model.attention_decoder,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
nbest_scale=params.nbest_scale,
)
else:
assert False, f"Unsupported decoding method: {params.decoding_method}"
@ -564,12 +629,21 @@ def save_results(
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
if params.decoding_method in (
"attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
):
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
@ -577,8 +651,8 @@ def save_results(
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
@ -616,6 +690,8 @@ def main():
"nbest-rescoring",
"whole-lattice-rescoring",
"nbest-oracle",
"attention-decoder-rescoring-no-ngram",
"attention-decoder-rescoring-with-ngram",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -654,8 +730,10 @@ def main():
params.vocab_size = num_classes
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = 0
params.eos_id = 1
params.sos_id = 1
if params.decoding_method == "ctc-decoding":
if params.decoding_method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
@ -679,6 +757,7 @@ def main():
if params.decoding_method in (
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder-rescoring-with-ngram",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
@ -710,7 +789,9 @@ def main():
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
if params.decoding_method == "whole-lattice-rescoring":
if params.decoding_method in [
"whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)

View File

@ -404,6 +404,7 @@ def main():
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.sos_id = params.eos_id = token_table["<sos/eos>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
@ -466,8 +467,6 @@ def main():
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg

View File

@ -0,0 +1,109 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
"""
Implement the LabelSmoothingLoss proposed in the following paper
https://arxiv.org/pdf/1512.00567.pdf
(Rethinking the Inception Architecture for Computer Vision)
"""
def __init__(
self,
ignore_index: int = -1,
label_smoothing: float = 0.1,
reduction: str = "sum",
) -> None:
"""
Args:
ignore_index:
ignored class id
label_smoothing:
smoothing rate (0.0 means the conventional cross entropy loss)
reduction:
It has the same meaning as the reduction in
`torch.nn.CrossEntropyLoss`. It can be one of the following three
values: (1) "none": No reduction will be applied. (2) "mean": the
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
assert reduction in ("none", "sum", "mean"), reduction
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.ignore_index of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.ndim == 3
assert target.ndim == 2
assert x.shape[:2] == target.shape
num_classes = x.size(-1)
x = x.reshape(-1, num_classes)
# Now x is of shape (N*T, C)
# We don't want to change target in-place below,
# so we make a copy of it here
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
# See https://github.com/k2-fsa/icefall/issues/240
# and https://github.com/k2-fsa/icefall/issues/297
# for why we don't use target[ignored] = 0 here
target = torch.where(ignored, torch.zeros_like(target), target)
true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x)
true_dist = (
true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes
)
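# For example, with label_smoothing=0.1 and num_classes=500, the correct class
# gets probability 0.9 + 0.1/500 and every other class gets 0.1/500.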
# Set the value of ignored indexes to 0
#
# See https://github.com/k2-fsa/icefall/issues/240
# and https://github.com/k2-fsa/icefall/issues/297
# for why we don't use true_dist[ignored] = 0 here
true_dist = torch.where(
ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
torch.zeros_like(true_dist),
true_dist,
)
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
return loss.sum()
elif self.reduction == "mean":
return loss.sum() / (~ignored).sum()
else:
return loss.sum(dim=-1)

View File

@ -34,11 +34,13 @@ class AsrModel(nn.Module):
encoder: EncoderInterface,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
attention_decoder: Optional[nn.Module] = None,
encoder_dim: int = 384,
decoder_dim: int = 512,
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
use_attention_decoder: bool = False,
):
"""A joint CTC & Transducer ASR model.
@ -70,6 +72,8 @@ class AsrModel(nn.Module):
Whether use transducer head. Default: True.
use_ctc:
Whether use CTC head. Default: False.
use_attention_decoder:
Whether use attention-decoder head. Default: False.
"""
super().__init__()
@ -111,6 +115,12 @@ class AsrModel(nn.Module):
nn.LogSoftmax(dim=-1),
)
self.use_attention_decoder = use_attention_decoder
if use_attention_decoder:
self.attention_decoder = attention_decoder
else:
assert attention_decoder is None
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -286,7 +296,7 @@ class AsrModel(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
@ -308,7 +318,7 @@ class AsrModel(nn.Module):
part
Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss)
in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
@ -322,6 +332,8 @@ class AsrModel(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
device = x.device
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
@ -333,7 +345,7 @@ class AsrModel(nn.Module):
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(x.device),
y=y.to(device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
@ -355,4 +367,14 @@ class AsrModel(nn.Module):
else:
ctc_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss
if self.use_attention_decoder:
attention_decoder_loss = self.attention_decoder.calc_att_loss(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
ys=y.to(device),
ys_lens=y_lens.to(device),
)
else:
attention_decoder_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss

View File

@ -81,6 +81,15 @@ Usage of this script:
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(5) attention-decoder-rescoring-no-ngram
./zipformer/pretrained_ctc.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method attention-decoder-rescoring-no-ngram \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
@ -100,6 +109,7 @@ from train import add_model_arguments, get_model, get_params
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_attention_decoder_no_ngram,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
@ -172,6 +182,8 @@ def get_parser():
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + whole-lattice n-gram LM rescoring.
(4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
lattice, rescore them with the attention decoder.
""",
)
@ -276,6 +288,7 @@ def main():
token_table = k2.SymbolTable.from_file(params.tokens)
params.vocab_size = num_tokens(token_table) + 1 # +1 for blank
params.blank_id = token_table["<blk>"]
params.sos_id = params.eos_id = token_table["<sos/eos>"]
assert params.blank_id == 0
logging.info(f"{params}")
@ -333,16 +346,13 @@ def main():
dtype=torch.int32,
)
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
max_token_id = params.vocab_size - 1
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
lattice = get_lattice(
nnet_output=ctc_output,
decoding_graph=H,
@ -354,9 +364,23 @@ def main():
subsampling_factor=params.subsampling_factor,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
else:
logging.info("Use attention decoder rescoring without ngram")
best_path_dict = rescore_with_attention_decoder_no_ngram(
lattice=lattice,
num_paths=params.num_paths,
attention_decoder=model.attention_decoder,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
nbest_scale=params.nbest_scale,
)
best_path = next(iter(best_path_dict.values()))
token_ids = get_texts(best_path)
hyps = [[token_table[i] for i in ids] for ids in token_ids]
elif params.method in [
@ -430,7 +454,7 @@ def main():
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
if params.method == "ctc-decoding":
if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
words = words.replace("▁", " ").strip()

View File

@ -48,6 +48,8 @@ It supports training with:
- transducer loss (default), with `--use-transducer True --use-ctc False`
- ctc loss (not recommended), with `--use-transducer False --use-ctc True`
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
- ctc loss & attention decoder loss, no transducer loss,
with `--use-transducer False --use-ctc True --use-attention-decoder True`
"""
@ -66,6 +68,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from attention_decoder import AttentionDecoderModel
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@ -221,6 +224,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
""",
)
parser.add_argument(
"--attention-decoder-dim",
type=int,
default=512,
help="""Dimension used in the attention decoder""",
)
parser.add_argument(
"--attention-decoder-num-layers",
type=int,
default=6,
help="""Number of transformer layers used in attention decoder""",
)
parser.add_argument(
"--attention-decoder-attention-dim",
type=int,
default=512,
help="""Attention dimension used in attention decoder""",
)
parser.add_argument(
"--attention-decoder-num-heads",
type=int,
default=8,
help="""Number of attention heads used in attention decoder""",
)
parser.add_argument(
"--attention-decoder-feedforward-dim",
type=int,
default=2048,
help="""Feedforward dimension used in attention decoder""",
)
parser.add_argument(
"--causal",
type=str2bool,
@ -259,6 +297,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="If True, use CTC head.",
)
parser.add_argument(
"--use-attention-decoder",
type=str2bool,
default=False,
help="If True, use attention-decoder head.",
)
def get_parser():
parser = argparse.ArgumentParser(
@ -404,6 +449,13 @@ def get_parser():
help="Scale for CTC loss.",
)
parser.add_argument(
"--attention-decoder-loss-scale",
type=float,
default=0.8,
help="Scale for attention-decoder loss.",
)
parser.add_argument(
"--seed",
type=int,
@ -528,6 +580,9 @@ def get_params() -> AttributeDict:
# parameters for zipformer
"feature_dim": 80,
"subsampling_factor": 4, # not passed in, this is fixed.
# parameters for attention-decoder
"ignore_id": -1,
"label_smoothing": 0.1,
"warm_step": 2000,
"env_info": get_env_info(),
}
@ -600,6 +655,23 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner
def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
decoder = AttentionDecoderModel(
vocab_size=params.vocab_size,
decoder_dim=params.attention_decoder_dim,
num_decoder_layers=params.attention_decoder_num_layers,
attention_dim=params.attention_decoder_attention_dim,
num_heads=params.attention_decoder_num_heads,
feedforward_dim=params.attention_decoder_feedforward_dim,
memory_dim=max(_to_int_tuple(params.encoder_dim)),
sos_id=params.sos_id,
eos_id=params.eos_id,
ignore_id=params.ignore_id,
label_smoothing=params.label_smoothing,
)
return decoder
def get_model(params: AttributeDict) -> nn.Module:
assert params.use_transducer or params.use_ctc, (
f"At least one of them should be True, "
@ -617,16 +689,23 @@ def get_model(params: AttributeDict) -> nn.Module:
decoder = None
joiner = None
if params.use_attention_decoder:
attention_decoder = get_attention_decoder_model(params)
else:
attention_decoder = None
model = AsrModel(
encoder_embed=encoder_embed,
encoder=encoder,
decoder=decoder,
joiner=joiner,
attention_decoder=attention_decoder,
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
use_transducer=params.use_transducer,
use_ctc=params.use_ctc,
use_attention_decoder=params.use_attention_decoder,
)
return model
@ -789,7 +868,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -819,6 +898,9 @@ def compute_loss(
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
if params.use_attention_decoder:
loss += params.attention_decoder_loss_scale * attention_decoder_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
@ -833,6 +915,8 @@ def compute_loss(
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.use_attention_decoder:
info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
return loss, info
@ -1112,10 +1196,16 @@ def run(rank, world_size, args):
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
if not params.use_transducer:
params.ctc_loss_scale = 1.0
if not params.use_attention_decoder:
params.ctc_loss_scale = 1.0
else:
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
params.ctc_loss_scale, params.attention_decoder_loss_scale
)
logging.info(params)

View File

@ -43,6 +43,61 @@ Fine-tuned models, training logs, decoding logs, tensorboard and decoding result
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>
### Multi Chinese datasets char-based training results (streaming) on zipformer large model
#### Streaming (with CTC head)
The training command for the large model (number of parameters: ~160M):
Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/prepare.sh) to prepare fbank features.
```
./zipformer/train.py \
--world-size 8 \
--num-epochs 20 \
--use-fp16 1 \
--max-duration 1200 \
--num-workers 8 \
--use-ctc 1 \
--exp-dir zipformer/exp-large \
--causal 1 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 768,1024,1536,2048,1536,768 \
--encoder-dim 256,384,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
```
The decoding command for transducer greedy search:
```
./zipformer/decode.py \
--epoch 999 \
--avg 1 \
--causal 1 \
--use-averaged-model False \
--chunk-size -1 \
--left-context-frames -1 \
--use-ctc 1 \
--exp-dir zipformer/exp-large \
--max-duration 1200 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 768,1024,1536,2048,1536,768 \
--encoder-dim 256,384,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
```
Character Error Rates (CERs) listed below are produced by the checkpoint from the 18th epoch, using a BPE model with 2000 tokens and byte fallback enabled.
| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| CTC Greedy Streaming | 26.50 | 28.10 | 1.71 | 1.97 | 3.89 | 4.06 | 17.23 | 3.69 | 2.87 | 8.14 | 3.61 | 9.51 | 6.11 | 8.13 | 10.62 |
| CTC Greedy Offline | 23.47 | 25.02 | 1.39 | 1.50 | 3.15 | 3.41 | 15.14 | 3.07 | 2.37 | 6.06 | 2.90 | 7.13 | 5.40 | 6.52 | 9.64 |
| Transducer Greedy Offline | 23.16 | 24.78 | 1.33 | 1.38 | 3.06 | 3.23 | 15.36 | 2.54 | 2.09 | 5.24 | 2.28 | 6.26 | 4.87 | 6.26 | 7.07 |
| Transducer Greedy Streaming | 26.83 | 28.74 | 1.75 | 1.91 | 3.84 | 4.12 | 17.83 | 3.23 | 2.71 | 7.31 | 3.16 | 8.69 | 5.71 | 7.91 | 8.54 |
The pre-trained model can be found here: <https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-large>
### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model

View File

@ -1,247 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import logging
import re
from pathlib import Path
from typing import Dict, List
import lhotse
from lhotse import CutSet, load_manifest_lazy
class MultiDataset:
def __init__(self, fbank_dir: str):
"""
Args:
manifest_dir:
It is expected to contain the following files:
- aishell_cuts_train.jsonl.gz
- aishell2_cuts_train.jsonl.gz
- aishell4_cuts_train_L.jsonl.gz
- aishell4_cuts_train_M.jsonl.gz
- aishell4_cuts_train_S.jsonl.gz
- alimeeting-far_cuts_train.jsonl.gz
- magicdata_cuts_train.jsonl.gz
- primewords_cuts_train.jsonl.gz
- stcmds_cuts_train.jsonl.gz
- thchs_30_cuts_train.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
- wenetspeech/cuts_L_fixed.jsonl.gz
"""
self.fbank_dir = Path(fbank_dir)
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
# THCHS-30
logging.info("Loading THCHS-30 in lazy mode")
thchs_30_cuts = load_manifest_lazy(
self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
)
# AISHELL-1
logging.info("Loading Aishell-1 in lazy mode")
aishell_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_train.jsonl.gz"
)
# AISHELL-2
logging.info("Loading Aishell-2 in lazy mode")
aishell_2_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
)
# AISHELL-4
logging.info("Loading Aishell-4 in lazy mode")
aishell_4_L_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
)
aishell_4_M_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
)
aishell_4_S_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
)
# ST-CMDS
logging.info("Loading ST-CMDS in lazy mode")
stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")
# Primewords
logging.info("Loading Primewords in lazy mode")
primewords_cuts = load_manifest_lazy(
self.fbank_dir / "primewords_cuts_train.jsonl.gz"
)
# MagicData
logging.info("Loading MagicData in lazy mode")
magicdata_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
)
# Ali-Meeting
logging.info("Loading Ali-Meeting in lazy mode")
alimeeting_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
)
# WeNetSpeech
logging.info("Loading WeNetSpeech in lazy mode")
wenetspeech_L_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
)
# KeSpeech
logging.info("Loading KeSpeech in lazy mode")
kespeech_1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
)
kespeech_2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
)
return CutSet.mux(
thchs_30_cuts,
aishell_cuts,
aishell_2_cuts,
aishell_4_L_cuts,
aishell_4_M_cuts,
aishell_4_S_cuts,
alimeeting_cuts,
stcmds_cuts,
primewords_cuts,
magicdata_cuts,
wenetspeech_L_cuts,
kespeech_1_cuts,
kespeech_2_cuts,
weights=[
len(thchs_30_cuts),
len(aishell_cuts),
len(aishell_2_cuts),
len(aishell_4_L_cuts),
len(aishell_4_M_cuts),
len(aishell_4_S_cuts),
len(alimeeting_cuts),
len(stcmds_cuts),
len(primewords_cuts),
len(magicdata_cuts),
len(wenetspeech_L_cuts),
len(kespeech_1_cuts),
len(kespeech_2_cuts),
],
)
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
# WeNetSpeech
logging.info("Loading WeNetSpeech DEV set in lazy mode")
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)
return wenetspeech_dev_cuts
def test_cuts(self) -> Dict[str, CutSet]:
logging.info("About to get multidataset test cuts")
# AISHELL
logging.info("Loading Aishell set in lazy mode")
aishell_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_test.jsonl.gz"
)
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)
# AISHELL-2
logging.info("Loading Aishell-2 set in lazy mode")
aishell2_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
)
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)
# AISHELL-4
logging.info("Loading Aishell-4 TEST set in lazy mode")
aishell4_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
)
# Ali-Meeting
logging.info("Loading Ali-Meeting set in lazy mode")
alimeeting_test_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
)
alimeeting_eval_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
)
# MagicData
logging.info("Loading MagicData set in lazy mode")
magicdata_test_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
)
magicdata_dev_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
)
# KeSpeech
logging.info("Loading KeSpeech set in lazy mode")
kespeech_test_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
)
kespeech_dev_phase1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
)
kespeech_dev_phase2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
)
# WeNetSpeech
logging.info("Loading WeNetSpeech set in lazy mode")
wenetspeech_test_meeting_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
)
wenetspeech_test_net_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
)
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)
return {
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "aishell_test": aishell_test_cuts,
# "aishell_dev": aishell_dev_cuts,
# "ali-meeting_test": alimeeting_test_cuts,
# "ali-meeting_eval": alimeeting_eval_cuts,
# "aishell-4_test": aishell4_test_cuts,
# "aishell-2_test": aishell2_test_cuts,
# "aishell-2_dev": aishell2_dev_cuts,
# "magicdata_test": magicdata_test_cuts,
# "magicdata_dev": magicdata_dev_cuts,
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
}

View File

@ -0,0 +1 @@
../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py

View File

@ -46,7 +46,7 @@ import torch.nn as nn
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting
from icefall.checkpoint import (
average_checkpoints,
@ -367,21 +367,18 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
G=G,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = list(ref_text.replace(" ", ""))
hyp_words = list("".join(hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
ref_text = normalize_text_alimeeting(ref_text)
hyp_text = "".join(hyp_words)
this_batch.append((cut_id, ref_text, hyp_text))
results[name].extend(this_batch)
@ -583,7 +580,7 @@ def main():
data_module = AsrDataModule(args)
multi_dataset = MultiDataset(args.manifest_dir)
test_sets_cuts = multi_dataset.test_cuts()
test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()}
def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2

View File

@ -118,7 +118,7 @@ from beam_search import (
)
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting
from icefall.checkpoint import (
average_checkpoints,
@ -532,7 +532,6 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
@ -548,6 +547,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
hyp_text = "".join(hyp_words)
this_batch.append((cut_id, ref_text, hyp_text))
@ -795,7 +795,7 @@ def main():
)
return T > 0
test_sets_cuts = multi_dataset.test_cuts()
test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()}
test_sets = test_sets_cuts.keys()
test_dl = [

View File

@ -1,316 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import logging
import re
from pathlib import Path
from typing import Dict, List
import lhotse
from lhotse import CutSet, load_manifest_lazy
class MultiDataset:
def __init__(self, fbank_dir: str):
"""
Args:
manifest_dir:
It is expected to contain the following files:
- aidatatang_cuts_train.jsonl.gz
- aishell_cuts_train.jsonl.gz
- aishell2_cuts_train.jsonl.gz
- aishell4_cuts_train_L.jsonl.gz
- aishell4_cuts_train_M.jsonl.gz
- aishell4_cuts_train_S.jsonl.gz
- alimeeting-far_cuts_train.jsonl.gz
- magicdata_cuts_train.jsonl.gz
- primewords_cuts_train.jsonl.gz
- stcmds_cuts_train.jsonl.gz
- thchs_30_cuts_train.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
- wenetspeech/cuts_L.jsonl.gz
"""
self.fbank_dir = Path(fbank_dir)
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
# THCHS-30
logging.info("Loading THCHS-30 in lazy mode")
thchs_30_cuts = load_manifest_lazy(
self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
)
# AISHELL-1
logging.info("Loading Aishell-1 in lazy mode")
aishell_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_train.jsonl.gz"
)
# AISHELL-2
logging.info("Loading Aishell-2 in lazy mode")
aishell_2_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
)
# AISHELL-4
logging.info("Loading Aishell-4 in lazy mode")
aishell_4_L_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
)
aishell_4_M_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
)
aishell_4_S_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
)
# ST-CMDS
logging.info("Loading ST-CMDS in lazy mode")
stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")
# Primewords
logging.info("Loading Primewords in lazy mode")
primewords_cuts = load_manifest_lazy(
self.fbank_dir / "primewords_cuts_train.jsonl.gz"
)
# MagicData
logging.info("Loading MagicData in lazy mode")
magicdata_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
)
# Aidatatang_200zh
logging.info("Loading Aidatatang_200zh in lazy mode")
aidatatang_200zh_cuts = load_manifest_lazy(
self.fbank_dir / "aidatatang_cuts_train.jsonl.gz"
)
# Ali-Meeting
logging.info("Loading Ali-Meeting in lazy mode")
alimeeting_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
)
# WeNetSpeech
logging.info("Loading WeNetSpeech in lazy mode")
wenetspeech_L_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
)
# KeSpeech
logging.info("Loading KeSpeech in lazy mode")
kespeech_1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
)
kespeech_2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
)
return CutSet.mux(
thchs_30_cuts,
aishell_cuts,
aishell_2_cuts,
aishell_4_L_cuts,
aishell_4_M_cuts,
aishell_4_S_cuts,
stcmds_cuts,
primewords_cuts,
magicdata_cuts,
aidatatang_200zh_cuts,
alimeeting_cuts,
wenetspeech_L_cuts,
kespeech_1_cuts,
kespeech_2_cuts,
weights=[
len(thchs_30_cuts),
len(aishell_cuts),
len(aishell_2_cuts),
len(aishell_4_L_cuts),
len(aishell_4_M_cuts),
len(aishell_4_S_cuts),
len(stcmds_cuts),
len(primewords_cuts),
len(magicdata_cuts),
len(aidatatang_200zh_cuts),
len(alimeeting_cuts),
len(wenetspeech_L_cuts),
len(kespeech_1_cuts),
len(kespeech_2_cuts),
],
)
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
# Aidatatang_200zh
logging.info("Loading Aidatatang_200zh DEV set in lazy mode")
aidatatang_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz"
)
# AISHELL
logging.info("Loading Aishell DEV set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)
# AISHELL-2
logging.info("Loading Aishell-2 DEV set in lazy mode")
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)
# Ali-Meeting
logging.info("Loading Ali-Meeting DEV set in lazy mode")
alimeeting_dev_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
)
# MagicData
logging.info("Loading MagicData DEV set in lazy mode")
magicdata_dev_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
)
# KeSpeech
logging.info("Loading KeSpeech DEV set in lazy mode")
kespeech_dev_phase1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
)
kespeech_dev_phase2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
)
# WeNetSpeech
logging.info("Loading WeNetSpeech DEV set in lazy mode")
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
)
return wenetspeech_dev_cuts
# return [
# aidatatang_dev_cuts,
# aishell_dev_cuts,
# aishell2_dev_cuts,
# alimeeting_dev_cuts,
# magicdata_dev_cuts,
# kespeech_dev_phase1_cuts,
# kespeech_dev_phase2_cuts,
# wenetspeech_dev_cuts,
# ]
def test_cuts(self) -> Dict[str, CutSet]:
logging.info("About to get multidataset test cuts")
# Aidatatang_200zh
logging.info("Loading Aidatatang_200zh set in lazy mode")
aidatatang_test_cuts = load_manifest_lazy(
self.fbank_dir / "aidatatang_cuts_test.jsonl.gz"
)
aidatatang_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz"
)
# AISHELL
logging.info("Loading Aishell set in lazy mode")
aishell_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_test.jsonl.gz"
)
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)
# AISHELL-2
logging.info("Loading Aishell-2 set in lazy mode")
aishell2_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
)
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)
# AISHELL-4
logging.info("Loading Aishell-4 TEST set in lazy mode")
aishell4_test_cuts = load_manifest_lazy(
self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
)
# Ali-Meeting
logging.info("Loading Ali-Meeting set in lazy mode")
alimeeting_test_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
)
alimeeting_eval_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
)
# MagicData
logging.info("Loading MagicData set in lazy mode")
magicdata_test_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
)
magicdata_dev_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
)
# KeSpeech
logging.info("Loading KeSpeech set in lazy mode")
kespeech_test_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
)
kespeech_dev_phase1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
)
kespeech_dev_phase2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
)
# WeNetSpeech
logging.info("Loading WeNetSpeech set in lazy mode")
wenetspeech_test_meeting_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
)
wenetspeech_test_net_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
)
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
)
return {
"aidatatang_test": aidatatang_test_cuts,
"aidatatang_dev": aidatatang_dev_cuts,
"alimeeting_test": alimeeting_test_cuts,
"alimeeting_eval": alimeeting_eval_cuts,
"aishell_test": aishell_test_cuts,
"aishell_dev": aishell_dev_cuts,
"aishell-2_test": aishell2_test_cuts,
"aishell-2_dev": aishell2_dev_cuts,
"aishell-4": aishell4_test_cuts,
"magicdata_test": magicdata_test_cuts,
"magicdata_dev": magicdata_dev_cuts,
"kespeech-asr_test": kespeech_test_cuts,
"kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
"kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
"wenetspeech-net_test": wenetspeech_test_net_cuts,
"wenetspeech_dev": wenetspeech_dev_cuts,
}
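
The deleted recipe-local copy above mixes its training corpora with `CutSet.mux`, weighting each source by its number of cuts so that larger corpora are sampled proportionally more often; the same pattern lives on in the shared module that the symlink below points to. A minimal sketch of that pattern, with manifest paths assumed purely for illustration:

```python
# Hedged sketch of the CutSet.mux mixing pattern used above; the manifest
# paths are assumptions and must exist under data/fbank for this to run.
from lhotse import CutSet, load_manifest_lazy

fbank_dir = "data/fbank"
sources = [
    load_manifest_lazy(f"{fbank_dir}/aishell_cuts_train.jsonl.gz"),
    load_manifest_lazy(f"{fbank_dir}/stcmds_cuts_train.jsonl.gz"),
]
# Weights proportional to corpus size give roughly size-proportional sampling.
train_cuts = CutSet.mux(*sources, weights=[len(s) for s in sources])
```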

View File

@ -0,0 +1 @@
../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py

View File

@ -539,6 +539,43 @@ def get_params() -> AttributeDict:
return params
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("Ａ", "A")
text = text.replace("ａ", "A")
text = text.replace("ｂ", "B")
text = text.replace("ｃ", "C")
text = text.replace("ｋ", "K")
text = text.replace("ｔ", "T")
text = text.replace("，", "")
text = text.replace("丶", "")
text = text.replace("。", "")
text = text.replace("、", "")
text = text.replace("？", "")
return text
def _to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
@ -788,6 +825,9 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
# remove spaces in texts
texts = [normalize_text_alimeeting(text) for text in texts]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
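
A quick way to see what the m2met normalization above does is to run it on a raw AliMeeting-style transcript. The sketch below assumes you run it from the recipe directory so that `train.py` is importable, and the expected output assumes the fullwidth-character table shown above:

```python
# Hedged usage sketch; the import path is an assumption (run from the
# recipe directory that contains the train.py shown in this diff).
from train import normalize_text_alimeeting

raw = "今 天 <sil> 开 会 ， 用 ｂ超 吗 ？"
print(normalize_text_alimeeting(raw))
# Expected: "今天开会用B超吗" -- spaces, <sil> tags and fullwidth punctuation
# are stripped, and the fullwidth letter is mapped to its ASCII form.
```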

View File

@ -114,7 +114,8 @@ def extract_hyp_ref_wavname(filename):
for line in f:
if "ref" in line:
ref = line.split("ref=")[1].strip()
ref = ref[2:-2]
if ref[0] == "[":
ref = ref[2:-2]
list_elements = ref.split("', '")
ref = "".join(list_elements)
refs.append(ref)

View File

@ -1083,6 +1083,238 @@ def rescore_with_attention_decoder(
return ans
def rescore_with_attention_decoder_with_ngram(
lattice: k2.Fsa,
num_paths: int,
attention_decoder: torch.nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
nbest_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""This function extracts `num_paths` paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest score is
the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
attention_decoder:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
encoder_out:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `(N, T, C)`.
encoder_out_lens:
Length of encoder outputs, with shape of `(N,)`.
nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
Returns:
A dict of FsaVec, whose key is a string of the form
ngram_lm_scale_xxx_attention_scale_xxx and whose value is the
best decoding path for each utterance in the lattice.
"""
max_loop_count = 10
loop_count = 0
while loop_count <= max_loop_count:
try:
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
break
except RuntimeError as e:
logging.info(f"Caught exception:\n{e}\n")
logging.info(f"num_paths before decreasing: {num_paths}")
num_paths = int(num_paths / 2)
if loop_count >= max_loop_count or num_paths <= 0:
logging.info("Return None as the resulting lattice is too large.")
return None
logging.info(
"This OOM is not an error. You can ignore it. "
"If your model does not converge well, or --max-duration "
"is too large, or the input sound file is difficult to "
"decode, you will meet this exception."
)
logging.info(f"num_paths after decreasing: {num_paths}")
loop_count += 1
# Now nbest.fsa has its scores set.
# Also, nbest.fsa inherits the attributes from `lattice`.
assert hasattr(nbest.fsa, "lm_scores")
am_scores = nbest.compute_am_scores()
ngram_lm_scores = nbest.compute_lm_scores()
# The `tokens` attribute is set inside `compile_hlg.py`
assert hasattr(nbest.fsa, "tokens")
assert isinstance(nbest.fsa.tokens, torch.Tensor)
path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
# the shape of encoder_out is (N, T, C), so we use axis=0 here
expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)
# remove axis corresponding to states.
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
tokens = tokens.remove_values_leq(0)
token_ids = tokens.tolist()
nll = attention_decoder.nll(
encoder_out=expanded_encoder_out,
encoder_out_lens=expanded_encoder_out_lens,
token_ids=token_ids,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
attention_scale_list = [attention_scale]
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
tot_scores = (
am_scores.values
+ n_scale * ngram_lm_scores.values
+ a_scale * attention_scores
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_path
return ans
def rescore_with_attention_decoder_no_ngram(
lattice: k2.Fsa,
num_paths: int,
attention_decoder: torch.nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
nbest_scale: float = 1.0,
attention_scale: Optional[float] = None,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""This function extracts `num_paths` paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest score is
the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
attention_decoder:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
encoder_out:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `(N, T, C)`.
encoder_out_lens:
Length of encoder outputs, with shape of `(N,)`.
nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
Returns:
A dict of FsaVec, whose key is a string of the form
attention_scale_xxx and whose value is the
best decoding path for each utterance in the lattice.
"""
# path is a ragged tensor with dtype torch.int32.
# It has three axes [utt][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
# Note that labels, aux_labels and scores contains 0s and -1s.
# The last entry in each sublist is -1.
# The axes are [path][token_id]
labels = k2.ragged.index(lattice.labels.contiguous(), path).remove_axis(0)
aux_labels = k2.ragged.index(lattice.aux_labels.contiguous(), path).remove_axis(0)
scores = k2.ragged.index(lattice.scores.contiguous(), path).remove_axis(0)
# Remove -1 from labels as we will use it to construct a linear FSA
labels = labels.remove_values_eq(-1)
fsa = k2.linear_fsa(labels)
fsa.aux_labels = aux_labels.values
# utt_to_path_shape has axes [utt][path]
utt_to_path_shape = path.shape.get_layer(0)
scores = k2.RaggedTensor(utt_to_path_shape, scores.sum())
path_to_utt_map = utt_to_path_shape.row_ids(1).to(torch.long)
# the shape of memory is (N, T, C), so we use axis=0 here
expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)
token_ids = aux_labels.remove_values_leq(0).tolist()
nll = attention_decoder.nll(
encoder_out=expanded_encoder_out,
encoder_out_lens=expanded_encoder_out_lens,
token_ids=token_ids,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0]
else:
attention_scale_list = [attention_scale]
ans = dict()
for a_scale in attention_scale_list:
tot_scores = scores.values + a_scale * attention_scores
ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(fsa, max_indexes)
key = f"attention_scale_{a_scale}"
ans[key] = best_path
return ans
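
For context, a rough sketch of how a CTC decoding script might call the new helper. Here `lattice`, `model`, `encoder_out` and `encoder_out_lens` are assumed to come from the surrounding decoding pipeline, and the attention decoder must expose the `nll()` interface used above:

```python
# Hedged usage sketch; assumes a CTC lattice and a model whose
# attention_decoder implements nll(encoder_out, encoder_out_lens, token_ids).
from icefall.decode import rescore_with_attention_decoder_no_ngram
from icefall.utils import get_texts

best_path_dict = rescore_with_attention_decoder_no_ngram(
    lattice=lattice,                    # k2.Fsa with axes [utt][state][arc]
    num_paths=100,
    attention_decoder=model.attention_decoder,
    encoder_out=encoder_out,            # (N, T, C)
    encoder_out_lens=encoder_out_lens,  # (N,)
    nbest_scale=1.0,
)
# One entry per tried attention_scale; each value holds the best path
# for every utterance under that scale.
for key, best_path in best_path_dict.items():
    hyp_token_ids = get_texts(best_path)
```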
def rescore_with_rnn_lm(
lattice: k2.Fsa,
num_paths: int,

View File

@ -28,5 +28,6 @@ multi_quantization
onnx
onnxmltools
onnxruntime
onnxconverter_common
kaldifst
kaldi-decoder