Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00

Commit 61e60d90a6: Merge branch 'k2-fsa:master' into dev/zipformer_lstm

@ -50,7 +50,7 @@ We place an additional Conv1d layer right after the input embedding layer.
| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head |                   |
| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty |              |
| `zipformer-ctc`  | Zipformer          | Use auxiliary attention head |                   |
| `zipformer`      | Upgraded Zipformer | Use auxiliary transducer head | The latest recipe |
| `zipformer`      | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head | The latest recipe |

# MMI

@ -1,5 +1,184 @@
## Results

### zipformer (zipformer + CTC/AED)

See <https://github.com/k2-fsa/icefall/pull/1389> for more details.

[zipformer](./zipformer)

#### Non-streaming

##### small-scale model, number of model parameters: 46282107, i.e., 46.3 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-small-ctc-attention-decoder-2024-07-09>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method                       | test-clean | test-other | comment             |
|---------------------------------------|------------|------------|---------------------|
| ctc-decoding                          | 3.04       | 7.04       | --epoch 50 --avg 30 |
| attention-decoder-rescoring-no-ngram  | 2.45       | 6.08       | --epoch 50 --avg 30 |

The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# For non-streaming model training:
./zipformer/train.py \
  --world-size 2 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp-small \
  --full-libri 1 \
  --use-ctc 1 \
  --use-transducer 0 \
  --use-attention-decoder 1 \
  --ctc-loss-scale 0.1 \
  --attention-decoder-loss-scale 0.9 \
  --num-encoder-layers 2,2,2,2,2,2 \
  --feedforward-dim 512,768,768,768,768,768 \
  --encoder-dim 192,256,256,256,256,256 \
  --encoder-unmasked-dim 192,192,192,192,192,192 \
  --base-lr 0.04 \
  --max-duration 1700 \
  --master-port 12345
```
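
With `--use-transducer 0` and `--use-attention-decoder 1`, the total training loss is a weighted sum of the CTC loss and the attention-decoder loss, and `train.py` (see its diff further below) requires the two scales to sum to 1.0. A minimal sketch of that weighting, with made-up loss values for illustration only:

```python
# Sketch of the weighting behind --ctc-loss-scale / --attention-decoder-loss-scale.
# The loss values below are illustrative, not real training numbers.
ctc_loss_scale = 0.1
attention_decoder_loss_scale = 0.9
assert abs(ctc_loss_scale + attention_decoder_loss_scale - 1.0) < 1e-8  # enforced by train.py

ctc_loss, attention_decoder_loss = 35.2, 18.7
loss = ctc_loss_scale * ctc_loss + attention_decoder_loss_scale * attention_decoder_loss
print(loss)  # ≈ 20.35
```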

The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-decoding attention-decoder-rescoring-no-ngram; do
  ./zipformer/ctc_decode.py \
    --epoch 50 \
    --avg 30 \
    --exp-dir zipformer/exp-small \
    --use-ctc 1 \
    --use-transducer 0 \
    --use-attention-decoder 1 \
    --attention-decoder-loss-scale 0.9 \
    --num-encoder-layers 2,2,2,2,2,2 \
    --feedforward-dim 512,768,768,768,768,768 \
    --encoder-dim 192,256,256,256,256,256 \
    --encoder-unmasked-dim 192,192,192,192,192,192 \
    --max-duration 100 \
    --causal 0 \
    --num-paths 100 \
    --decoding-method $m
done
```
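
For `attention-decoder-rescoring-no-ngram`, `ctc_decode.py` extracts `--num-paths` paths from the CTC decoding lattice and rescores them with the attention decoder (see the `ctc_decode.py` and `attention_decoder.py` changes below). The sketch below only illustrates the idea; `rescore_nbest` and its arguments are hypothetical, not the icefall API, and it assumes `encoder_out` has already been repeated once per path:

```python
# Hypothetical n-best rescoring sketch (not the icefall implementation).
import torch

def rescore_nbest(token_id_paths, lattice_scores, attention_decoder,
                  encoder_out, encoder_out_lens, attn_scale=1.0):
    # Per-token negative log-likelihood from the attention decoder: (num_paths, num_tokens).
    nll = attention_decoder.nll(encoder_out, encoder_out_lens, token_ids=token_id_paths)
    attn_scores = -nll.sum(dim=-1)  # per-path log-likelihood
    totals = torch.tensor(lattice_scores, device=attn_scores.device) + attn_scale * attn_scores
    return token_id_paths[int(totals.argmax())]  # best path under the combined score
```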

##### medium-scale model, number of model parameters: 89987295, i.e., 90.0 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-ctc-attention-decoder-2024-07-08>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method                       | test-clean | test-other | comment             |
|---------------------------------------|------------|------------|---------------------|
| ctc-decoding                          | 2.46       | 5.57       | --epoch 50 --avg 22 |
| attention-decoder-rescoring-no-ngram  | 2.23       | 4.98       | --epoch 50 --avg 22 |

The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --full-libri 1 \
  --use-ctc 1 \
  --use-transducer 0 \
  --use-attention-decoder 1 \
  --ctc-loss-scale 0.1 \
  --attention-decoder-loss-scale 0.9 \
  --max-duration 1200 \
  --master-port 12345
```

The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-decoding attention-decoder-rescoring-no-ngram; do
  ./zipformer/ctc_decode.py \
    --epoch 50 \
    --avg 22 \
    --exp-dir zipformer/exp \
    --use-ctc 1 \
    --use-transducer 0 \
    --use-attention-decoder 1 \
    --attention-decoder-loss-scale 0.9 \
    --max-duration 100 \
    --causal 0 \
    --num-paths 100 \
    --decoding-method $m
done
```

##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-ctc-attention-decoder-2024-05-26>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method                       | test-clean | test-other | comment             |
|---------------------------------------|------------|------------|---------------------|
| ctc-decoding                          | 2.29       | 5.14       | --epoch 50 --avg 29 |
| attention-decoder-rescoring-no-ngram  | 2.1        | 4.57       | --epoch 50 --avg 29 |

The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp-large \
  --full-libri 1 \
  --use-ctc 1 \
  --use-transducer 0 \
  --use-attention-decoder 1 \
  --ctc-loss-scale 0.1 \
  --attention-decoder-loss-scale 0.9 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 512,768,1536,2048,1536,768 \
  --encoder-dim 192,256,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192 \
  --max-duration 1200 \
  --master-port 12345
```

The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-decoding attention-decoder-rescoring-no-ngram; do
  ./zipformer/ctc_decode.py \
    --epoch 50 \
    --avg 29 \
    --exp-dir zipformer/exp-large \
    --use-ctc 1 \
    --use-transducer 0 \
    --use-attention-decoder 1 \
    --attention-decoder-loss-scale 0.9 \
    --num-encoder-layers 2,2,4,5,4,2 \
    --feedforward-dim 512,768,1536,2048,1536,768 \
    --encoder-dim 192,256,512,768,512,256 \
    --encoder-unmasked-dim 192,192,256,320,256,192 \
    --max-duration 100 \
    --causal 0 \
    --num-paths 100 \
    --decoding-method $m
done
```


### zipformer (zipformer + pruned stateless transducer + CTC)

See <https://github.com/k2-fsa/icefall/pull/1111> for more details.

573 egs/librispeech/ASR/zipformer/attention_decoder.py (new file)
@ -0,0 +1,573 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from scaling import penalize_abs_values_gt
|
||||
|
||||
from icefall.utils import add_eos, add_sos, make_pad_mask
|
||||
|
||||
|
||||
class AttentionDecoderModel(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
vocab_size (int): Number of classes.
|
||||
decoder_dim: (int,int): embedding dimension of 2 encoder stacks
|
||||
attention_dim: (int,int): attention dimension of 2 encoder stacks
|
||||
num_heads (int, int): number of heads
|
||||
dim_feedforward (int, int): feedforward dimension in 2 encoder stacks
|
||||
num_encoder_layers (int): number of encoder layers
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int = 512,
|
||||
num_decoder_layers: int = 6,
|
||||
attention_dim: int = 512,
|
||||
num_heads: int = 8,
|
||||
feedforward_dim: int = 2048,
|
||||
memory_dim: int = 512,
|
||||
sos_id: int = 1,
|
||||
eos_id: int = 1,
|
||||
dropout: float = 0.1,
|
||||
ignore_id: int = -1,
|
||||
label_smoothing: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self.eos_id = eos_id
|
||||
self.sos_id = sos_id
|
||||
self.ignore_id = ignore_id
|
||||
|
||||
# The transformer decoder that attends to the encoder output (memory).
self.decoder = TransformerDecoder(
|
||||
vocab_size=vocab_size,
|
||||
d_model=decoder_dim,
|
||||
num_decoder_layers=num_decoder_layers,
|
||||
attention_dim=attention_dim,
|
||||
num_heads=num_heads,
|
||||
feedforward_dim=feedforward_dim,
|
||||
memory_dim=memory_dim,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# Used to calculate attention-decoder loss
|
||||
self.loss_fun = LabelSmoothingLoss(
|
||||
ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum"
|
||||
)
|
||||
|
||||
def _pre_ys_in_out(self, ys: k2.RaggedTensor, ys_lens: torch.Tensor):
|
||||
"""Prepare ys_in_pad and ys_out_pad."""
|
||||
ys_in = add_sos(ys, sos_id=self.sos_id)
|
||||
# [B, S+1], start with SOS
|
||||
ys_in_pad = ys_in.pad(mode="constant", padding_value=self.eos_id)
|
||||
ys_in_lens = ys_lens + 1
|
||||
|
||||
ys_out = add_eos(ys, eos_id=self.eos_id)
|
||||
# [B, S+1], end with EOS
|
||||
ys_out_pad = ys_out.pad(mode="constant", padding_value=self.ignore_id)
|
||||
|
||||
return ys_in_pad.to(torch.int64), ys_in_lens, ys_out_pad.to(torch.int64)
|
||||
|
||||
def calc_att_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys: k2.RaggedTensor,
|
||||
ys_lens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate attention-decoder loss.
|
||||
Args:
|
||||
encoder_out: (batch, num_frames, encoder_dim)
|
||||
encoder_out_lens: (batch,)
|
||||
token_ids: A list of token id list.
|
||||
|
||||
Return: The attention-decoder loss.
|
||||
"""
|
||||
ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)
|
||||
|
||||
# decoder forward
|
||||
decoder_out = self.decoder(
|
||||
x=ys_in_pad,
|
||||
x_lens=ys_in_lens,
|
||||
memory=encoder_out,
|
||||
memory_lens=encoder_out_lens,
|
||||
)
|
||||
|
||||
loss = self.loss_fun(x=decoder_out, target=ys_out_pad)
|
||||
return loss
|
||||
|
||||
def nll(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
token_ids: List[List[int]],
|
||||
) -> torch.Tensor:
|
||||
"""Compute negative log likelihood(nll) from attention-decoder.
|
||||
Args:
|
||||
encoder_out: (batch, num_frames, encoder_dim)
|
||||
encoder_out_lens: (batch,)
|
||||
token_ids: A list of token-id lists.
|
||||
|
||||
Return: A tensor of shape (batch, num_tokens).
|
||||
"""
|
||||
ys = k2.RaggedTensor(token_ids).to(device=encoder_out.device)
|
||||
row_splits = ys.shape.row_splits(1)
|
||||
ys_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)
|
||||
|
||||
# decoder forward
|
||||
decoder_out = self.decoder(
|
||||
x=ys_in_pad,
|
||||
x_lens=ys_in_lens,
|
||||
memory=encoder_out,
|
||||
memory_lens=encoder_out_lens,
|
||||
)
|
||||
|
||||
batch_size, _, num_classes = decoder_out.size()
|
||||
nll = nn.functional.cross_entropy(
|
||||
decoder_out.view(-1, num_classes),
|
||||
ys_out_pad.view(-1),
|
||||
ignore_index=self.ignore_id,
|
||||
reduction="none",
|
||||
)
|
||||
nll = nll.view(batch_size, -1)
|
||||
return nll
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
"""Transfomer decoder module.
|
||||
|
||||
Args:
|
||||
vocab_size: output dim
|
||||
d_model: decoder dimension
|
||||
num_decoder_layers: number of decoder layers
|
||||
attention_dim: total dimension of multi head attention
|
||||
num_heads: number of attention heads
|
||||
feedforward_dim: hidden dimension of feed_forward module
|
||||
dropout: dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
d_model: int = 512,
|
||||
num_decoder_layers: int = 6,
|
||||
attention_dim: int = 512,
|
||||
num_heads: int = 8,
|
||||
feedforward_dim: int = 2048,
|
||||
memory_dim: int = 512,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
|
||||
|
||||
# Absolute positional encoding
|
||||
self.pos = PositionalEncoding(d_model, dropout_rate=0.1)
|
||||
|
||||
self.num_layers = num_decoder_layers
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DecoderLayer(
|
||||
d_model=d_model,
|
||||
attention_dim=attention_dim,
|
||||
num_heads=num_heads,
|
||||
feedforward_dim=feedforward_dim,
|
||||
memory_dim=memory_dim,
|
||||
dropout=dropout,
|
||||
)
|
||||
for _ in range(num_decoder_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.output_layer = nn.Linear(d_model, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
memory: Optional[torch.Tensor] = None,
|
||||
memory_lens: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: Input tensor of shape (batch, tgt_len).
|
||||
x_lens: A tensor of shape (batch,) containing the number of tokens in `x`
|
||||
before padding.
|
||||
memory:
|
||||
Memory sequence of shape (batch, src_len, memory_dim).
|
||||
memory_lens:
|
||||
A tensor of shape (batch,) containing the number of frames in
|
||||
`memory` before padding.
|
||||
|
||||
Returns:
|
||||
Decoded token logits before softmax (batch, tgt_len, vocab_size)
|
||||
"""
|
||||
x = self.embed(x) # (batch, tgt_len, embed_dim)
|
||||
x = self.pos(x) # (batch, tgt_len, embed_dim)
|
||||
|
||||
x = x.permute(1, 0, 2) # (tgt_len, batch, embed_dim)
|
||||
|
||||
# construct attn_mask for self-attn modules
|
||||
padding_mask = make_pad_mask(x_lens) # (batch, tgt_len)
|
||||
causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len)
|
||||
attn_mask = torch.logical_or(
|
||||
padding_mask.unsqueeze(1), # (batch, 1, seq_len)
|
||||
torch.logical_not(causal_mask).unsqueeze(0) # (1, seq_len, seq_len)
|
||||
) # (batch, seq_len, seq_len)
|
||||
|
||||
if memory is not None:
|
||||
memory = memory.permute(1, 0, 2) # (src_len, batch, memory_dim)
|
||||
# construct memory_attn_mask for cross-attn modules
|
||||
memory_padding_mask = make_pad_mask(memory_lens) # (batch, src_len)
|
||||
memory_attn_mask = memory_padding_mask.unsqueeze(1) # (batch, 1, src_len)
|
||||
else:
|
||||
memory_attn_mask = None
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
x = mod(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
memory=memory,
|
||||
memory_attn_mask=memory_attn_mask,
|
||||
)
|
||||
|
||||
x = x.permute(1, 0, 2)  # (batch, tgt_len, d_model)
|
||||
x = self.output_layer(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
"""Single decoder layer module.
|
||||
|
||||
Args:
|
||||
d_model: equal to decoder_dim, total dimension of the decoder
|
||||
attention_dim: total dimension of multi head attention
|
||||
num_heads: number of attention heads
|
||||
feedforward_dim: hidden dimension of feed_forward module
|
||||
dropout: dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 512,
|
||||
attention_dim: int = 512,
|
||||
num_heads: int = 8,
|
||||
feedforward_dim: int = 2048,
|
||||
memory_dim: int = 512,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
"""Construct an DecoderLayer object."""
|
||||
super(DecoderLayer, self).__init__()
|
||||
|
||||
self.norm_self_attn = nn.LayerNorm(d_model)
|
||||
self.self_attn = MultiHeadAttention(
|
||||
d_model, attention_dim, num_heads, dropout=0.0
|
||||
)
|
||||
|
||||
self.norm_src_attn = nn.LayerNorm(d_model)
|
||||
self.src_attn = MultiHeadAttention(
|
||||
d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0
|
||||
)
|
||||
|
||||
self.norm_ff = nn.LayerNorm(d_model)
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, feedforward_dim),
|
||||
Swish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(feedforward_dim, d_model),
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
memory: Optional[torch.Tensor] = None,
|
||||
memory_attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: Input sequence of shape (seq_len, batch, embed_dim).
|
||||
attn_mask: A binary mask for self-attention module indicating which
|
||||
elements will be filled with -inf.
|
||||
Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
|
||||
memory: Memory sequence of shape (seq_len, batch, memory_dim).
|
||||
memory_attn_mask: A binary mask for cross-attention module indicating which
|
||||
elements will be filled with -inf.
|
||||
Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
|
||||
"""
|
||||
# self-attn module
|
||||
qkv = self.norm_self_attn(x)
|
||||
self_attn_out = self.self_attn(
|
||||
query=qkv, key=qkv, value=qkv, attn_mask=attn_mask
|
||||
)
|
||||
x = x + self.dropout(self_attn_out)
|
||||
|
||||
# cross-attn module
|
||||
q = self.norm_src_attn(x)
|
||||
src_attn_out = self.src_attn(
|
||||
query=q, key=memory, value=memory, attn_mask=memory_attn_mask
|
||||
)
|
||||
x = x + self.dropout(src_attn_out)
|
||||
|
||||
# feed-forward module
|
||||
x = x + self.dropout(self.feed_forward(self.norm_ff(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model.
|
||||
attention_dim: total dimension of the attention module; must be a multiple of num_heads.
|
||||
num_heads: number of parallel attention heads.
|
||||
memory_dim: dimension of memory embedding, optional.
|
||||
dropout: a Dropout layer on attn_output_weights.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
attention_dim: int,
|
||||
num_heads: int,
|
||||
memory_dim: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.attention_dim = attention_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = attention_dim // num_heads
|
||||
assert self.head_dim * num_heads == attention_dim, (
|
||||
self.head_dim, num_heads, attention_dim
|
||||
)
|
||||
self.dropout = dropout
|
||||
self.name = None # will be overwritten in training code; for diagnostics.
|
||||
|
||||
self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
|
||||
self.linear_k = nn.Linear(
|
||||
embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
|
||||
)
|
||||
self.linear_v = nn.Linear(
|
||||
embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
|
||||
)
|
||||
|
||||
self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_padding_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Compute dot product attention.
|
||||
|
||||
Args:
|
||||
query: Query tensor of shape (tgt_len, batch, embed_dim).
|
||||
key: Key tensor of shape (src_len, batch, embed_dim or memory_dim).
|
||||
value: Value tensor of shape (src_len, batch, embed_dim or memory_dim).
|
||||
key_padding_mask: A binary mask indicating which elements are padding.
|
||||
Its shape is (batch, src_len).
|
||||
attn_mask: A binary mask indicating which elements will be filled with -inf.
|
||||
Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (tgt_len, batch, embed_dim).
|
||||
"""
|
||||
num_heads = self.num_heads
|
||||
head_dim = self.head_dim
|
||||
|
||||
tgt_len, batch, _ = query.shape
|
||||
src_len = key.shape[0]
|
||||
|
||||
q = self.linear_q(query) # (tgt_len, batch, num_heads * head_dim)
|
||||
k = self.linear_k(key) # (src_len, batch, num_heads * head_dim)
|
||||
v = self.linear_v(value) # (src_len, batch, num_heads * head_dim)
|
||||
|
||||
q = q.reshape(tgt_len, batch, num_heads, head_dim)
|
||||
q = q.permute(1, 2, 0, 3) # (batch, head, tgt_len, head_dim)
|
||||
k = k.reshape(src_len, batch, num_heads, head_dim)
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, head_dim, src_len)
|
||||
v = v.reshape(src_len, batch, num_heads, head_dim)
|
||||
v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1)
|
||||
|
||||
# Note: could remove the scaling operation when using ScaledAdam
|
||||
# (batch, head, tgt_len, src_len)
|
||||
attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)
|
||||
|
||||
# From zipformer.py:
|
||||
# This is a harder way of limiting the attention scores to not be too large.
|
||||
# It incurs a penalty if any of them has an absolute value greater than 50.0.
|
||||
# this should be outside the normal range of the attention scores. We use
|
||||
# this mechanism instead of, say, a limit on entropy, because once the entropy
|
||||
# gets very small gradients through the softmax can become very small, and
|
||||
# some mechanisms like that become ineffective.
|
||||
attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
|
||||
attn_weights = attn_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
|
||||
)
|
||||
|
||||
if attn_mask is not None:
|
||||
assert (
|
||||
attn_mask.shape == (batch, 1, src_len)
|
||||
or attn_mask.shape == (batch, tgt_len, src_len)
|
||||
), attn_mask.shape
|
||||
attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
|
||||
|
||||
attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
# (batch * head, tgt_len, head_dim)
|
||||
attn_output = torch.bmm(attn_weights, v)
|
||||
assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
|
||||
|
||||
attn_output = attn_output.transpose(0, 1).contiguous()
|
||||
attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
|
||||
|
||||
# (tgt_len, batch, embed_dim)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
"""Positional encoding.
|
||||
Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35.
|
||||
|
||||
Args:
|
||||
d_model (int): Embedding dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
max_len (int): Maximum input length.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x):
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
if self.pe.size(1) >= x.size(1):
|
||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
pe = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale + self.pe[:, : x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
"""Construct an Swish object."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Return Swich activation function."""
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def subsequent_mask(size, device="cpu", dtype=torch.bool):
|
||||
"""Create mask for subsequent steps (size, size).
|
||||
|
||||
:param int size: size of mask
|
||||
:param str device: "cpu" or "cuda" or torch.Tensor.device
|
||||
:param torch.dtype dtype: result dtype
|
||||
:rtype: torch.Tensor
|
||||
>>> subsequent_mask(3)
|
||||
[[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1]]
|
||||
"""
|
||||
ret = torch.ones(size, size, device=device, dtype=dtype)
|
||||
return torch.tril(ret, out=ret)
|
||||
|
||||
|
||||
def _test_attention_decoder_model():
|
||||
m = AttentionDecoderModel(
|
||||
vocab_size=500,
|
||||
decoder_dim=512,
|
||||
num_decoder_layers=6,
|
||||
attention_dim=512,
|
||||
num_heads=8,
|
||||
feedforward_dim=2048,
|
||||
memory_dim=384,
|
||||
dropout=0.1,
|
||||
sos_id=1,
|
||||
eos_id=1,
|
||||
ignore_id=-1,
|
||||
)
|
||||
|
||||
num_param = sum([p.numel() for p in m.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
m.eval()
|
||||
encoder_out = torch.randn(2, 50, 384)
|
||||
encoder_out_lens = torch.full((2,), 50)
|
||||
token_ids = [[1, 2, 3, 4], [2, 3, 10]]
|
||||
|
||||
nll = m.nll(encoder_out, encoder_out_lens, token_ids)
|
||||
print(nll)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_attention_decoder_model()
|
@ -73,6 +73,29 @@ Usage:
|
||||
--nbest-scale 1.0 \
|
||||
--lm-dir data/lm \
|
||||
--decoding-method whole-lattice-rescoring
|
||||
|
||||
(6) attention-decoder-rescoring-no-ngram
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--use-attention-decoder 1 \
|
||||
--max-duration 100 \
|
||||
--decoding-method attention-decoder-rescoring-no-ngram
|
||||
|
||||
(7) attention-decoder-rescoring-with-ngram
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--use-attention-decoder 1 \
|
||||
--max-duration 100 \
|
||||
--hlg-scale 0.6 \
|
||||
--nbest-scale 1.0 \
|
||||
--lm-dir data/lm \
|
||||
--decoding-method attention-decoder-rescoring-with-ngram
|
||||
"""
|
||||
|
||||
|
||||
@ -101,6 +124,8 @@ from icefall.decode import (
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder_no_ngram,
|
||||
rescore_with_attention_decoder_with_ngram,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
@ -212,6 +237,10 @@ def get_parser():
|
||||
- (6) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
- (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
|
||||
lattice, rescore them with the attention decoder.
|
||||
- (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
|
||||
rescored lattice, rescore them with the attention decoder.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -406,6 +435,26 @@ def decode_one_batch(
|
||||
key = "ctc-decoding"
|
||||
return {key: hyps}
|
||||
|
||||
if params.decoding_method == "attention-decoder-rescoring-no-ngram":
|
||||
best_path_dict = rescore_with_attention_decoder_no_ngram(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
attention_decoder=model.attention_decoder,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
ans = dict()
|
||||
for a_scale_str, best_path in best_path_dict.items():
|
||||
# token_ids is a list-of-lists of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
# hyps is a list of lists of str, e.g., [['xxx', 'yyy', 'zzz'], ...]
|
||||
hyps = [s.split() for s in hyps]
|
||||
ans[a_scale_str] = hyps
|
||||
return ans
|
||||
|
||||
if params.decoding_method == "nbest-oracle":
|
||||
# Note: You can also pass rescored lattices to it.
|
||||
# We choose the HLG decoded lattice for speed reasons
|
||||
@ -446,6 +495,7 @@ def decode_one_batch(
|
||||
assert params.decoding_method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder-rescoring-with-ngram",
|
||||
]
|
||||
|
||||
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
@ -466,6 +516,21 @@ def decode_one_batch(
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
elif params.decoding_method == "attention-decoder-rescoring-with-ngram":
|
||||
# The lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=None,
|
||||
)
|
||||
best_path_dict = rescore_with_attention_decoder_with_ngram(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
attention_decoder=model.attention_decoder,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
||||
@ -564,12 +629,21 @@ def save_results(
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
if params.decoding_method in (
|
||||
"attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
|
||||
):
|
||||
# Set it to False since there are too many logs.
|
||||
enable_log = False
|
||||
else:
|
||||
enable_log = True
|
||||
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
@ -577,8 +651,8 @@ def save_results(
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
if enable_log:
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
@ -616,6 +690,8 @@ def main():
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"nbest-oracle",
|
||||
"attention-decoder-rescoring-no-ngram",
|
||||
"attention-decoder-rescoring-with-ngram",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
@ -654,8 +730,10 @@ def main():
|
||||
params.vocab_size = num_classes
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = 0
|
||||
params.eos_id = 1
|
||||
params.sos_id = 1
|
||||
|
||||
if params.decoding_method == "ctc-decoding":
|
||||
if params.decoding_method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
@ -679,6 +757,7 @@ def main():
|
||||
if params.decoding_method in (
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder-rescoring-with-ngram",
|
||||
):
|
||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
@ -710,7 +789,9 @@ def main():
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method == "whole-lattice-rescoring":
|
||||
if params.decoding_method in [
|
||||
"whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
|
||||
]:
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
|
@ -404,6 +404,7 @@ def main():
|
||||
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.sos_id = params.eos_id = token_table["<sos/eos>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1
|
||||
|
||||
logging.info(params)
|
||||
@ -466,8 +467,6 @@ def main():
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
|
109 egs/librispeech/ASR/zipformer/label_smoothing.py (new file)
@ -0,0 +1,109 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LabelSmoothingLoss(torch.nn.Module):
|
||||
"""
|
||||
Implement the LabelSmoothingLoss proposed in the following paper
|
||||
https://arxiv.org/pdf/1512.00567.pdf
|
||||
(Rethinking the Inception Architecture for Computer Vision)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ignore_index: int = -1,
|
||||
label_smoothing: float = 0.1,
|
||||
reduction: str = "sum",
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
ignore_index:
|
||||
ignored class id
|
||||
label_smoothing:
|
||||
smoothing rate (0.0 means the conventional cross entropy loss)
|
||||
reduction:
|
||||
It has the same meaning as the reduction in
|
||||
`torch.nn.CrossEntropyLoss`. It can be one of the following three
|
||||
values: (1) "none": No reduction will be applied. (2) "mean": the
|
||||
mean of the output is taken. (3) "sum": the output will be summed.
|
||||
"""
|
||||
super().__init__()
|
||||
assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
|
||||
assert reduction in ("none", "sum", "mean"), reduction
|
||||
self.ignore_index = ignore_index
|
||||
self.label_smoothing = label_smoothing
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute loss between x and target.
|
||||
|
||||
Args:
|
||||
x:
|
||||
prediction of dimension
|
||||
(batch_size, input_length, number_of_classes).
|
||||
target:
|
||||
target masked with self.ignore_index of
|
||||
dimension (batch_size, input_length).
|
||||
|
||||
Returns:
|
||||
A scalar tensor containing the loss without normalization.
|
||||
"""
|
||||
assert x.ndim == 3
|
||||
assert target.ndim == 2
|
||||
assert x.shape[:2] == target.shape
|
||||
num_classes = x.size(-1)
|
||||
x = x.reshape(-1, num_classes)
|
||||
# Now x is of shape (N*T, C)
|
||||
|
||||
# We don't want to change target in-place below,
|
||||
# so we make a copy of it here
|
||||
target = target.clone().reshape(-1)
|
||||
|
||||
ignored = target == self.ignore_index
|
||||
|
||||
# See https://github.com/k2-fsa/icefall/issues/240
|
||||
# and https://github.com/k2-fsa/icefall/issues/297
|
||||
# for why we don't use target[ignored] = 0 here
|
||||
target = torch.where(ignored, torch.zeros_like(target), target)
|
||||
|
||||
true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x)
|
||||
|
||||
true_dist = (
|
||||
true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes
|
||||
)
|
||||
|
||||
# Set the value of ignored indexes to 0
|
||||
#
|
||||
# See https://github.com/k2-fsa/icefall/issues/240
|
||||
# and https://github.com/k2-fsa/icefall/issues/297
|
||||
# for why we don't use true_dist[ignored] = 0 here
|
||||
true_dist = torch.where(
|
||||
ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
|
||||
torch.zeros_like(true_dist),
|
||||
true_dist,
|
||||
)
|
||||
|
||||
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
|
||||
if self.reduction == "sum":
|
||||
return loss.sum()
|
||||
elif self.reduction == "mean":
|
||||
return loss.sum() / (~ignored).sum()
|
||||
else:
|
||||
return loss.sum(dim=-1)
|
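
A small worked example of the smoothed target distribution built in `forward()` above (an editorial note, not part of the new file; it simply re-applies the formula `one_hot * (1 - eps) + eps / num_classes`):

```python
import torch

num_classes, eps = 500, 0.1
target = torch.tensor([3])  # true class id
one_hot = torch.nn.functional.one_hot(target, num_classes).float()
true_dist = one_hot * (1 - eps) + eps / num_classes
print(true_dist[0, 3].item())  # ≈ 0.9002: the true class gets 1 - eps + eps / num_classes
print(true_dist[0, 0].item())  # ≈ 0.0002: every other class gets eps / num_classes
```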
@ -34,11 +34,13 @@ class AsrModel(nn.Module):
|
||||
encoder: EncoderInterface,
|
||||
decoder: Optional[nn.Module] = None,
|
||||
joiner: Optional[nn.Module] = None,
|
||||
attention_decoder: Optional[nn.Module] = None,
|
||||
encoder_dim: int = 384,
|
||||
decoder_dim: int = 512,
|
||||
vocab_size: int = 500,
|
||||
use_transducer: bool = True,
|
||||
use_ctc: bool = False,
|
||||
use_attention_decoder: bool = False,
|
||||
):
|
||||
"""A joint CTC & Transducer ASR model.
|
||||
|
||||
@ -70,6 +72,8 @@ class AsrModel(nn.Module):
|
||||
Whether use transducer head. Default: True.
|
||||
use_ctc:
|
||||
Whether use CTC head. Default: False.
|
||||
use_attention_decoder:
|
||||
Whether use attention-decoder head. Default: False.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -111,6 +115,12 @@ class AsrModel(nn.Module):
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
self.use_attention_decoder = use_attention_decoder
|
||||
if use_attention_decoder:
|
||||
self.attention_decoder = attention_decoder
|
||||
else:
|
||||
assert attention_decoder is None
|
||||
|
||||
def forward_encoder(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -286,7 +296,7 @@ class AsrModel(nn.Module):
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
@ -308,7 +318,7 @@ class AsrModel(nn.Module):
|
||||
part
|
||||
Returns:
|
||||
Return the transducer losses and CTC loss,
|
||||
in form of (simple_loss, pruned_loss, ctc_loss)
|
||||
in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
@ -322,6 +332,8 @@ class AsrModel(nn.Module):
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
|
||||
|
||||
device = x.device
|
||||
|
||||
# Compute encoder outputs
|
||||
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
||||
|
||||
@ -333,7 +345,7 @@ class AsrModel(nn.Module):
|
||||
simple_loss, pruned_loss = self.forward_transducer(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
y=y.to(x.device),
|
||||
y=y.to(device),
|
||||
y_lens=y_lens,
|
||||
prune_range=prune_range,
|
||||
am_scale=am_scale,
|
||||
@ -355,4 +367,14 @@ class AsrModel(nn.Module):
|
||||
else:
|
||||
ctc_loss = torch.empty(0)
|
||||
|
||||
return simple_loss, pruned_loss, ctc_loss
|
||||
if self.use_attention_decoder:
|
||||
attention_decoder_loss = self.attention_decoder.calc_att_loss(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
ys=y.to(device),
|
||||
ys_lens=y_lens.to(device),
|
||||
)
|
||||
else:
|
||||
attention_decoder_loss = torch.empty(0)
|
||||
|
||||
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
|
||||
|
@ -81,6 +81,15 @@ Usage of this script:
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(5) attention-decoder-rescoring-no-ngram
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--method attention-decoder-rescoring-no-ngram \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -100,6 +109,7 @@ from train import add_model_arguments, get_model, get_params
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder_no_ngram,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
@ -172,6 +182,8 @@ def get_parser():
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + whole-lattice n-gram LM rescoring.
|
||||
(4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
|
||||
lattice, rescore them with the attention decoder.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -276,6 +288,7 @@ def main():
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.vocab_size = num_tokens(token_table) + 1 # +1 for blank
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.sos_id = params.eos_id = token_table["<sos/eos>"]
|
||||
assert params.blank_id == 0
|
||||
|
||||
logging.info(f"{params}")
|
||||
@ -333,16 +346,13 @@ def main():
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Use CTC decoding")
|
||||
if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
|
||||
max_token_id = params.vocab_size - 1
|
||||
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=ctc_output,
|
||||
decoding_graph=H,
|
||||
@ -354,9 +364,23 @@ def main():
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Use CTC decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
else:
|
||||
logging.info("Use attention decoder rescoring without ngram")
|
||||
best_path_dict = rescore_with_attention_decoder_no_ngram(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
attention_decoder=model.attention_decoder,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
token_ids = get_texts(best_path)
|
||||
hyps = [[token_table[i] for i in ids] for ids in token_ids]
|
||||
elif params.method in [
|
||||
@ -430,7 +454,7 @@ def main():
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
|
||||
s = "\n"
|
||||
if params.method == "ctc-decoding":
|
||||
if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = "".join(hyp)
|
||||
words = words.replace("▁", " ").strip()
|
||||
|
@ -48,6 +48,8 @@ It supports training with:
|
||||
- transducer loss (default), with `--use-transducer True --use-ctc False`
|
||||
- ctc loss (not recommended), with `--use-transducer False --use-ctc True`
|
||||
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
|
||||
- ctc loss & attention decoder loss, no transducer loss,
|
||||
with `--use-transducer False --use-ctc True --use-attention-decoder True`
|
||||
"""
|
||||
|
||||
|
||||
@ -66,6 +68,7 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from attention_decoder import AttentionDecoderModel
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
@ -221,6 +224,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-decoder-dim",
|
||||
type=int,
|
||||
default=512,
|
||||
help="""Dimension used in the attention decoder""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-decoder-num-layers",
|
||||
type=int,
|
||||
default=6,
|
||||
help="""Number of transformer layers used in attention decoder""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-decoder-attention-dim",
|
||||
type=int,
|
||||
default=512,
|
||||
help="""Attention dimension used in attention decoder""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-decoder-num-heads",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Number of attention heads used in attention decoder""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-decoder-feedforward-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="""Feedforward dimension used in attention decoder""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--causal",
|
||||
type=str2bool,
|
||||
@ -259,6 +297,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="If True, use CTC head.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-attention-decoder",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If True, use attention-decoder head.",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -404,6 +449,13 @@ def get_parser():
|
||||
help="Scale for CTC loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-decoder-loss-scale",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="Scale for attention-decoder loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
@ -528,6 +580,9 @@ def get_params() -> AttributeDict:
|
||||
# parameters for zipformer
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4, # not passed in, this is fixed.
|
||||
# parameters for attention-decoder
|
||||
"ignore_id": -1,
|
||||
"label_smoothing": 0.1,
|
||||
"warm_step": 2000,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
@ -600,6 +655,23 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
return joiner
|
||||
|
||||
|
||||
def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = AttentionDecoderModel(
|
||||
vocab_size=params.vocab_size,
|
||||
decoder_dim=params.attention_decoder_dim,
|
||||
num_decoder_layers=params.attention_decoder_num_layers,
|
||||
attention_dim=params.attention_decoder_attention_dim,
|
||||
num_heads=params.attention_decoder_num_heads,
|
||||
feedforward_dim=params.attention_decoder_feedforward_dim,
|
||||
memory_dim=max(_to_int_tuple(params.encoder_dim)),
|
||||
sos_id=params.sos_id,
|
||||
eos_id=params.eos_id,
|
||||
ignore_id=params.ignore_id,
|
||||
label_smoothing=params.label_smoothing,
|
||||
)
|
||||
return decoder
|
||||
|
||||
|
||||
def get_model(params: AttributeDict) -> nn.Module:
|
||||
assert params.use_transducer or params.use_ctc, (
|
||||
f"At least one of them should be True, "
|
||||
@ -617,16 +689,23 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = None
|
||||
joiner = None
|
||||
|
||||
if params.use_attention_decoder:
|
||||
attention_decoder = get_attention_decoder_model(params)
|
||||
else:
|
||||
attention_decoder = None
|
||||
|
||||
model = AsrModel(
|
||||
encoder_embed=encoder_embed,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
attention_decoder=attention_decoder,
|
||||
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
|
||||
decoder_dim=params.decoder_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
use_transducer=params.use_transducer,
|
||||
use_ctc=params.use_ctc,
|
||||
use_attention_decoder=params.use_attention_decoder,
|
||||
)
|
||||
return model
|
||||
|
||||
@ -789,7 +868,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -819,6 +898,9 @@ def compute_loss(
|
||||
if params.use_ctc:
|
||||
loss += params.ctc_loss_scale * ctc_loss
|
||||
|
||||
if params.use_attention_decoder:
|
||||
loss += params.attention_decoder_loss_scale * attention_decoder_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
@ -833,6 +915,8 @@ def compute_loss(
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
if params.use_ctc:
|
||||
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||
if params.use_attention_decoder:
|
||||
info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
@ -1112,10 +1196,16 @@ def run(rank, world_size, args):
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if not params.use_transducer:
|
||||
params.ctc_loss_scale = 1.0
|
||||
if not params.use_attention_decoder:
|
||||
params.ctc_loss_scale = 1.0
|
||||
else:
|
||||
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
|
||||
params.ctc_loss_scale, params.attention_decoder_loss_scale
|
||||
)
|
||||
|
||||
logging.info(params)
|
||||
|
||||
|
@ -43,6 +43,61 @@ Fine-tuned models, training logs, decoding logs, tensorboard and decoding result
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>

### Multi Chinese datasets char-based training results (Streaming) on zipformer large model

#### Streaming (with CTC head)

The training command for the large model (number of parameters: ~160M):

Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/prepare.sh) to prepare fbank features.

```
./zipformer/train.py \
  --world-size 8 \
  --num-epochs 20 \
  --use-fp16 1 \
  --max-duration 1200 \
  --num-workers 8 \
  --use-ctc 1 \
  --exp-dir zipformer/exp-large \
  --causal 1 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 768,1024,1536,2048,1536,768 \
  --encoder-dim 256,384,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192
```

The decoding command for transducer greedy search:

```
./zipformer/decode.py \
  --epoch 999 \
  --avg 1 \
  --causal 1 \
  --use-averaged-model False \
  --chunk-size -1 \
  --left-context-frames -1 \
  --use-ctc 1 \
  --exp-dir zipformer/exp-large \
  --max-duration 1200 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 768,1024,1536,2048,1536,768 \
  --encoder-dim 256,384,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192
```
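
The command above decodes the causal model with full context (`--chunk-size -1 --left-context-frames -1`), which presumably corresponds to the Offline rows in the table below. A simulated-streaming decode only changes those two flags; the chunk size and left-context values here are illustrative assumptions, not necessarily the settings behind the reported streaming CERs:

```
./zipformer/decode.py \
  --epoch 999 \
  --avg 1 \
  --causal 1 \
  --use-averaged-model False \
  --chunk-size 16 \
  --left-context-frames 128 \
  --use-ctc 1 \
  --exp-dir zipformer/exp-large \
  --max-duration 1200 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 768,1024,1536,2048,1536,768 \
  --encoder-dim 256,384,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192
```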

Character Error Rates (CERs) listed below are produced by the checkpoint of the 18th epoch, using a BPE model (2000 tokens, byte fallback enabled).

| Datasets                     | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech  | WenetSpeech |
|------------------------------|------------|------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|--------------|--------------|--------------|-------------|--------------|-------------|
| Zipformer CER (%)            | eval       | test       | dev       | test      | dev       | test      | test      | dev       | test      | dev phase1   | dev phase2   | test         | dev         | test meeting | test net    |
| CTC Greedy Streaming         | 26.50      | 28.10      | 1.71      | 1.97      | 3.89      | 4.06      | 17.23     | 3.69      | 2.87      | 8.14         | 3.61         | 9.51         | 6.11        | 8.13         | 10.62       |
| CTC Greedy Offline           | 23.47      | 25.02      | 1.39      | 1.50      | 3.15      | 3.41      | 15.14     | 3.07      | 2.37      | 6.06         | 2.90         | 7.13         | 5.40        | 6.52         | 9.64        |
| Transducer Greedy Offline    | 23.16      | 24.78      | 1.33      | 1.38      | 3.06      | 3.23      | 15.36     | 2.54      | 2.09      | 5.24         | 2.28         | 6.26         | 4.87        | 6.26         | 7.07        |
| Transducer Greedy Streaming  | 26.83      | 28.74      | 1.75      | 1.91      | 3.84      | 4.12      | 17.83     | 3.23      | 2.71      | 7.31         | 3.16         | 8.69         | 5.71        | 7.91         | 8.54        |

The pre-trained model can be found here: <https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-large>

### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model

@ -1,247 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import glob
import logging
import re
from pathlib import Path
from typing import Dict, List

import lhotse
from lhotse import CutSet, load_manifest_lazy


class MultiDataset:
    def __init__(self, fbank_dir: str):
        """
        Args:
          manifest_dir:
            It is expected to contain the following files:
            - aishell_cuts_train.jsonl.gz
            - aishell2_cuts_train.jsonl.gz
            - aishell4_cuts_train_L.jsonl.gz
            - aishell4_cuts_train_M.jsonl.gz
            - aishell4_cuts_train_S.jsonl.gz
            - alimeeting-far_cuts_train.jsonl.gz
            - magicdata_cuts_train.jsonl.gz
            - primewords_cuts_train.jsonl.gz
            - stcmds_cuts_train.jsonl.gz
            - thchs_30_cuts_train.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
            - wenetspeech/cuts_L_fixed.jsonl.gz
        """
        self.fbank_dir = Path(fbank_dir)

    def train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        # THCHS-30
        logging.info("Loading THCHS-30 in lazy mode")
        thchs_30_cuts = load_manifest_lazy(
            self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
        )

        # AISHELL-1
        logging.info("Loading Aishell-1 in lazy mode")
        aishell_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 in lazy mode")
        aishell_2_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 in lazy mode")
        aishell_4_L_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
        )
        aishell_4_M_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
        )
        aishell_4_S_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
        )

        # ST-CMDS
        logging.info("Loading ST-CMDS in lazy mode")
        stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")

        # Primewords
        logging.info("Loading Primewords in lazy mode")
        primewords_cuts = load_manifest_lazy(
            self.fbank_dir / "primewords_cuts_train.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData in lazy mode")
        magicdata_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting in lazy mode")
        alimeeting_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech in lazy mode")
        wenetspeech_L_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech in lazy mode")
        kespeech_1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
        )
        kespeech_2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
        )

        return CutSet.mux(
            thchs_30_cuts,
            aishell_cuts,
            aishell_2_cuts,
            aishell_4_L_cuts,
            aishell_4_M_cuts,
            aishell_4_S_cuts,
            alimeeting_cuts,
            stcmds_cuts,
            primewords_cuts,
            magicdata_cuts,
            wenetspeech_L_cuts,
            kespeech_1_cuts,
            kespeech_2_cuts,
            weights=[
                len(thchs_30_cuts),
                len(aishell_cuts),
                len(aishell_2_cuts),
                len(aishell_4_L_cuts),
                len(aishell_4_M_cuts),
                len(aishell_4_S_cuts),
                len(alimeeting_cuts),
                len(stcmds_cuts),
                len(primewords_cuts),
                len(magicdata_cuts),
                len(wenetspeech_L_cuts),
                len(kespeech_1_cuts),
                len(kespeech_2_cuts),
            ],
        )

    def dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")

        # WeNetSpeech
        logging.info("Loading WeNetSpeech DEV set in lazy mode")
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
        )

        return wenetspeech_dev_cuts

    def test_cuts(self) -> Dict[str, CutSet]:
        logging.info("About to get multidataset test cuts")

        # AISHELL
        logging.info("Loading Aishell set in lazy mode")
        aishell_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
        )
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
        )
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 TEST set in lazy mode")
        aishell4_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting set in lazy mode")
        alimeeting_test_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
        )
        alimeeting_eval_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData set in lazy mode")
        magicdata_test_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
        )
        magicdata_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech set in lazy mode")
        kespeech_test_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
        )
        kespeech_dev_phase1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
        )
        kespeech_dev_phase2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech set in lazy mode")
        wenetspeech_test_meeting_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
        )
        wenetspeech_test_net_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
        )
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
        )

        return {
            "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
            # "aishell_test": aishell_test_cuts,
            # "aishell_dev": aishell_dev_cuts,
            # "ali-meeting_test": alimeeting_test_cuts,
            # "ali-meeting_eval": alimeeting_eval_cuts,
            # "aishell-4_test": aishell4_test_cuts,
            # "aishell-2_test": aishell2_test_cuts,
            # "aishell-2_dev": aishell2_dev_cuts,
            # "magicdata_test": magicdata_test_cuts,
            # "magicdata_dev": magicdata_dev_cuts,
            # "kespeech-asr_test": kespeech_test_cuts,
            # "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
            # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
            # "wenetspeech-net_test": wenetspeech_test_net_cuts,
            # "wenetspeech_dev": wenetspeech_dev_cuts,
        }
egs/multi_zh-hans/ASR/whisper/multi_dataset.py
Symbolic link
@ -0,0 +1 @@
../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
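The deleted `train_cuts` above mixes the corpora with `lhotse.CutSet.mux`, weighting each source by its number of cuts so that larger corpora are sampled proportionally more often while everything stays lazily streamed from disk. A minimal sketch of that pattern (the manifest paths below are placeholders):

```python
from lhotse import CutSet, load_manifest_lazy

# Hypothetical manifests; in the recipe they live under data/fbank.
aishell = load_manifest_lazy("data/fbank/aishell_cuts_train.jsonl.gz")
thchs30 = load_manifest_lazy("data/fbank/thchs_30_cuts_train.jsonl.gz")

# mux() interleaves the lazy cut sets; with weights proportional to their
# sizes, the mixture behaves like a shuffled concatenation of the corpora
# without materializing them in memory.
train_cuts = CutSet.mux(
    aishell,
    thchs30,
    weights=[len(aishell), len(thchs30)],
)
```
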
@ -46,7 +46,7 @@ import torch.nn as nn
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting

from icefall.checkpoint import (
    average_checkpoints,

@ -367,21 +367,18 @@ def decode_dataset(
        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            batch=batch,
            word_table=word_table,
            G=G,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_words = list(ref_text.replace(" ", ""))
                hyp_words = list("".join(hyp_words))
                this_batch.append((cut_id, ref_words, hyp_words))
                ref_text = normalize_text_alimeeting(ref_text)
                hyp_text = "".join(hyp_words)
                this_batch.append((cut_id, ref_text, hyp_text))

            results[name].extend(this_batch)

@ -583,7 +580,7 @@ def main():
    data_module = AsrDataModule(args)
    multi_dataset = MultiDataset(args.manifest_dir)

    test_sets_cuts = multi_dataset.test_cuts()
    test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()}

    def remove_short_utt(c: Cut):
        T = ((c.num_frames - 7) // 2 + 1) // 2

@ -118,7 +118,8 @@ from beam_search import (
)
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting

from icefall.checkpoint import (
    average_checkpoints,

@ -532,7 +532,6 @@ def decode_dataset(
    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        texts = [list(str(text).replace(" ", "")) for text in texts]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(

@ -548,6 +547,7 @@ def decode_dataset(
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_text = normalize_text_alimeeting(ref_text)
                hyp_text = "".join(hyp_words)
                this_batch.append((cut_id, ref_text, hyp_text))

@ -795,7 +795,7 @@ def main():
            )
        return T > 0

    test_sets_cuts = multi_dataset.test_cuts()
    test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()}

    test_sets = test_sets_cuts.keys()
    test_dl = [
@ -1,316 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import glob
import logging
import re
from pathlib import Path
from typing import Dict, List

import lhotse
from lhotse import CutSet, load_manifest_lazy


class MultiDataset:
    def __init__(self, fbank_dir: str):
        """
        Args:
          manifest_dir:
            It is expected to contain the following files:
            - aidatatang_cuts_train.jsonl.gz
            - aishell_cuts_train.jsonl.gz
            - aishell2_cuts_train.jsonl.gz
            - aishell4_cuts_train_L.jsonl.gz
            - aishell4_cuts_train_M.jsonl.gz
            - aishell4_cuts_train_S.jsonl.gz
            - alimeeting-far_cuts_train.jsonl.gz
            - magicdata_cuts_train.jsonl.gz
            - primewords_cuts_train.jsonl.gz
            - stcmds_cuts_train.jsonl.gz
            - thchs_30_cuts_train.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
            - wenetspeech/cuts_L.jsonl.gz
        """
        self.fbank_dir = Path(fbank_dir)

    def train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        # THCHS-30
        logging.info("Loading THCHS-30 in lazy mode")
        thchs_30_cuts = load_manifest_lazy(
            self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
        )

        # AISHELL-1
        logging.info("Loading Aishell-1 in lazy mode")
        aishell_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 in lazy mode")
        aishell_2_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 in lazy mode")
        aishell_4_L_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
        )
        aishell_4_M_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
        )
        aishell_4_S_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
        )

        # ST-CMDS
        logging.info("Loading ST-CMDS in lazy mode")
        stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")

        # Primewords
        logging.info("Loading Primewords in lazy mode")
        primewords_cuts = load_manifest_lazy(
            self.fbank_dir / "primewords_cuts_train.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData in lazy mode")
        magicdata_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
        )

        # Aidatatang_200zh
        logging.info("Loading Aidatatang_200zh in lazy mode")
        aidatatang_200zh_cuts = load_manifest_lazy(
            self.fbank_dir / "aidatatang_cuts_train.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting in lazy mode")
        alimeeting_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech in lazy mode")
        wenetspeech_L_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech in lazy mode")
        kespeech_1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
        )
        kespeech_2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
        )

        return CutSet.mux(
            thchs_30_cuts,
            aishell_cuts,
            aishell_2_cuts,
            aishell_4_L_cuts,
            aishell_4_M_cuts,
            aishell_4_S_cuts,
            stcmds_cuts,
            primewords_cuts,
            magicdata_cuts,
            aidatatang_200zh_cuts,
            alimeeting_cuts,
            wenetspeech_L_cuts,
            kespeech_1_cuts,
            kespeech_2_cuts,
            weights=[
                len(thchs_30_cuts),
                len(aishell_cuts),
                len(aishell_2_cuts),
                len(aishell_4_L_cuts),
                len(aishell_4_M_cuts),
                len(aishell_4_S_cuts),
                len(stcmds_cuts),
                len(primewords_cuts),
                len(magicdata_cuts),
                len(aidatatang_200zh_cuts),
                len(alimeeting_cuts),
                len(wenetspeech_L_cuts),
                len(kespeech_1_cuts),
                len(kespeech_2_cuts),
            ],
        )

    def dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")

        # Aidatatang_200zh
        logging.info("Loading Aidatatang_200zh DEV set in lazy mode")
        aidatatang_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz"
        )

        # AISHELL
        logging.info("Loading Aishell DEV set in lazy mode")
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 DEV set in lazy mode")
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting DEV set in lazy mode")
        alimeeting_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData DEV set in lazy mode")
        magicdata_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech DEV set in lazy mode")
        kespeech_dev_phase1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
        )
        kespeech_dev_phase2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech DEV set in lazy mode")
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
        )

        return wenetspeech_dev_cuts
        # return [
        #     aidatatang_dev_cuts,
        #     aishell_dev_cuts,
        #     aishell2_dev_cuts,
        #     alimeeting_dev_cuts,
        #     magicdata_dev_cuts,
        #     kespeech_dev_phase1_cuts,
        #     kespeech_dev_phase2_cuts,
        #     wenetspeech_dev_cuts,
        # ]

    def test_cuts(self) -> Dict[str, CutSet]:
        logging.info("About to get multidataset test cuts")

        # Aidatatang_200zh
        logging.info("Loading Aidatatang_200zh set in lazy mode")
        aidatatang_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aidatatang_cuts_test.jsonl.gz"
        )
        aidatatang_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz"
        )

        # AISHELL
        logging.info("Loading Aishell set in lazy mode")
        aishell_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
        )
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
        )
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 TEST set in lazy mode")
        aishell4_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting set in lazy mode")
        alimeeting_test_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
        )
        alimeeting_eval_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData set in lazy mode")
        magicdata_test_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
        )
        magicdata_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech set in lazy mode")
        kespeech_test_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
        )
        kespeech_dev_phase1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
        )
        kespeech_dev_phase2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech set in lazy mode")
        wenetspeech_test_meeting_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
        )
        wenetspeech_test_net_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
        )
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
        )

        return {
            "aidatatang_test": aidatatang_test_cuts,
            "aidatatang_dev": aidatatang_dev_cuts,
            "alimeeting_test": alimeeting_test_cuts,
            "alimeeting_eval": alimeeting_eval_cuts,
            "aishell_test": aishell_test_cuts,
            "aishell_dev": aishell_dev_cuts,
            "aishell-2_test": aishell2_test_cuts,
            "aishell-2_dev": aishell2_dev_cuts,
            "aishell-4": aishell4_test_cuts,
            "magicdata_test": magicdata_test_cuts,
            "magicdata_dev": magicdata_dev_cuts,
            "kespeech-asr_test": kespeech_test_cuts,
            "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
            "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
            "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
            "wenetspeech-net_test": wenetspeech_test_net_cuts,
            "wenetspeech_dev": wenetspeech_dev_cuts,
        }
egs/multi_zh-hans/ASR/zipformer/multi_dataset.py
Symbolic link
@ -0,0 +1 @@
../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py
@ -539,6 +539,43 @@ def get_params() -> AttributeDict:
    return params


def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
    """
    Text normalization similar to M2MeT challenge baseline.
    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
    """
    if normalize == "none":
        return text
    elif normalize == "m2met":
        import re

        text = text.replace(" ", "")
        text = text.replace("<sil>", "")
        text = text.replace("<%>", "")
        text = text.replace("<->", "")
        text = text.replace("<$>", "")
        text = text.replace("<#>", "")
        text = text.replace("<_>", "")
        text = text.replace("<space>", "")
        text = text.replace("`", "")
        text = text.replace("&", "")
        text = text.replace(",", "")
        if re.search("[a-zA-Z]", text):
            text = text.upper()
        text = text.replace("A", "A")
        text = text.replace("a", "A")
        text = text.replace("b", "B")
        text = text.replace("c", "C")
        text = text.replace("k", "K")
        text = text.replace("t", "T")
        text = text.replace(",", "")
        text = text.replace("丶", "")
        text = text.replace("。", "")
        text = text.replace("、", "")
        text = text.replace("?", "")
        return text


def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))
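For a quick sense of what this normalization does, here is a small usage sketch (it assumes `normalize_text_alimeeting` is imported from the recipe's `train.py`, as the decode scripts above do; the transcript is made up):

```python
from train import normalize_text_alimeeting

raw = "那 我们 <sil> 就 ok 了"
print(normalize_text_alimeeting(raw))
# -> "那我们就OK了": spaces and filler tags such as <sil> are stripped,
#    and Latin letters are upper-cased before the text is BPE-encoded.
```
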
@ -788,6 +825,9 @@ def compute_loss(
    warm_step = params.warm_step

    texts = batch["supervisions"]["text"]
    # remove spaces in texts
    texts = [normalize_text_alimeeting(text) for text in texts]

    y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y)

@ -114,7 +114,8 @@ def extract_hyp_ref_wavname(filename):
        for line in f:
            if "ref" in line:
                ref = line.split("ref=")[1].strip()
                ref = ref[2:-2]
                if ref[0] == "[":
                    ref = ref[2:-2]
                list_elements = ref.split("', '")
                ref = "".join(list_elements)
                refs.append(ref)
@ -1083,6 +1083,238 @@ def rescore_with_attention_decoder(
    return ans


def rescore_with_attention_decoder_with_ngram(
    lattice: k2.Fsa,
    num_paths: int,
    attention_decoder: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    nbest_scale: float = 1.0,
    ngram_lm_scale: Optional[float] = None,
    attention_scale: Optional[float] = None,
    use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
    """This function extracts `num_paths` paths from the given lattice and uses
    an attention decoder to rescore them. The path with the highest score is
    the decoding output.

    Args:
      lattice:
        An FsaVec with axes [utt][state][arc].
      num_paths:
        Number of paths to extract from the given lattice for rescoring.
      attention_decoder:
        A transformer model. See the class "Transformer" in
        conformer_ctc/transformer.py for its interface.
      encoder_out:
        The encoder memory of the given model. It is the output of
        the last torch.nn.TransformerEncoder layer in the given model.
        Its shape is `(N, T, C)`.
      encoder_out_lens:
        Length of encoder outputs, with shape of `(N,)`.
      nbest_scale:
        It's the scale applied to `lattice.scores`. A smaller value
        leads to more unique paths at the risk of missing the correct path.
      ngram_lm_scale:
        Optional. It specifies the scale for n-gram LM scores.
      attention_scale:
        Optional. It specifies the scale for attention decoder scores.
    Returns:
      A dict of FsaVec, whose key contains a string
      ngram_lm_scale_attention_scale and the value is the
      best decoding path for each utterance in the lattice.
    """
    max_loop_count = 10
    loop_count = 0
    while loop_count <= max_loop_count:
        try:
            nbest = Nbest.from_lattice(
                lattice=lattice,
                num_paths=num_paths,
                use_double_scores=use_double_scores,
                nbest_scale=nbest_scale,
            )
            # nbest.fsa.scores are all 0s at this point
            nbest = nbest.intersect(lattice)
            break
        except RuntimeError as e:
            logging.info(f"Caught exception:\n{e}\n")
            logging.info(f"num_paths before decreasing: {num_paths}")
            num_paths = int(num_paths / 2)
            if loop_count >= max_loop_count or num_paths <= 0:
                logging.info("Return None as the resulting lattice is too large.")
                return None
            logging.info(
                "This OOM is not an error. You can ignore it. "
                "If your model does not converge well, or --max-duration "
                "is too large, or the input sound file is difficult to "
                "decode, you will meet this exception."
            )
            logging.info(f"num_paths after decreasing: {num_paths}")
        loop_count += 1

    # Now nbest.fsa has its scores set.
    # Also, nbest.fsa inherits the attributes from `lattice`.
    assert hasattr(nbest.fsa, "lm_scores")

    am_scores = nbest.compute_am_scores()
    ngram_lm_scores = nbest.compute_lm_scores()

    # The `tokens` attribute is set inside `compile_hlg.py`
    assert hasattr(nbest.fsa, "tokens")
    assert isinstance(nbest.fsa.tokens, torch.Tensor)

    path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
    # the shape of memory is (T, N, C), so we use axis=1 here
    expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
    expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)

    # remove axis corresponding to states.
    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
    tokens = tokens.remove_values_leq(0)
    token_ids = tokens.tolist()

    nll = attention_decoder.nll(
        encoder_out=expanded_encoder_out,
        encoder_out_lens=expanded_encoder_out_lens,
        token_ids=token_ids,
    )
    assert nll.ndim == 2
    assert nll.shape[0] == len(token_ids)

    attention_scores = -nll.sum(dim=1)

    if ngram_lm_scale is None:
        ngram_lm_scale_list = [0.01, 0.05, 0.08]
        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
        ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
    else:
        ngram_lm_scale_list = [ngram_lm_scale]

    if attention_scale is None:
        attention_scale_list = [0.01, 0.05, 0.08]
        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
    else:
        attention_scale_list = [attention_scale]

    ans = dict()
    for n_scale in ngram_lm_scale_list:
        for a_scale in attention_scale_list:
            tot_scores = (
                am_scores.values
                + n_scale * ngram_lm_scores.values
                + a_scale * attention_scores
            )
            ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
            max_indexes = ragged_tot_scores.argmax()
            best_path = k2.index_fsa(nbest.fsa, max_indexes)

            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
            ans[key] = best_path
    return ans


def rescore_with_attention_decoder_no_ngram(
    lattice: k2.Fsa,
    num_paths: int,
    attention_decoder: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    nbest_scale: float = 1.0,
    attention_scale: Optional[float] = None,
    use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
    """This function extracts `num_paths` paths from the given lattice and uses
    an attention decoder to rescore them. The path with the highest score is
    the decoding output.

    Args:
      lattice:
        An FsaVec with axes [utt][state][arc].
      num_paths:
        Number of paths to extract from the given lattice for rescoring.
      attention_decoder:
        A transformer model. See the class "Transformer" in
        conformer_ctc/transformer.py for its interface.
      encoder_out:
        The encoder memory of the given model. It is the output of
        the last torch.nn.TransformerEncoder layer in the given model.
        Its shape is `(N, T, C)`.
      encoder_out_lens:
        Length of encoder outputs, with shape of `(N,)`.
      nbest_scale:
        It's the scale applied to `lattice.scores`. A smaller value
        leads to more unique paths at the risk of missing the correct path.
      attention_scale:
        Optional. It specifies the scale for attention decoder scores.

    Returns:
      A dict of FsaVec, whose key contains a string
      ngram_lm_scale_attention_scale and the value is the
      best decoding path for each utterance in the lattice.
    """
    # path is a ragged tensor with dtype torch.int32.
    # It has three axes [utt][path][arc_pos]
    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
    # Note that labels, aux_labels and scores contains 0s and -1s.
    # The last entry in each sublist is -1.
    # The axes are [path][token_id]
    labels = k2.ragged.index(lattice.labels.contiguous(), path).remove_axis(0)
    aux_labels = k2.ragged.index(lattice.aux_labels.contiguous(), path).remove_axis(0)
    scores = k2.ragged.index(lattice.scores.contiguous(), path).remove_axis(0)

    # Remove -1 from labels as we will use it to construct a linear FSA
    labels = labels.remove_values_eq(-1)
    fsa = k2.linear_fsa(labels)
    fsa.aux_labels = aux_labels.values

    # utt_to_path_shape has axes [utt][path]
    utt_to_path_shape = path.shape.get_layer(0)
    scores = k2.RaggedTensor(utt_to_path_shape, scores.sum())

    path_to_utt_map = utt_to_path_shape.row_ids(1).to(torch.long)
    # the shape of memory is (N, T, C), so we use axis=0 here
    expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
    expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)

    token_ids = aux_labels.remove_values_leq(0).tolist()

    nll = attention_decoder.nll(
        encoder_out=expanded_encoder_out,
        encoder_out_lens=expanded_encoder_out_lens,
        token_ids=token_ids,
    )
    assert nll.ndim == 2
    assert nll.shape[0] == len(token_ids)

    attention_scores = -nll.sum(dim=1)

    if attention_scale is None:
        attention_scale_list = [0.01, 0.05, 0.08]
        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
        attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0]
    else:
        attention_scale_list = [attention_scale]

    ans = dict()

    for a_scale in attention_scale_list:
        tot_scores = scores.values + a_scale * attention_scores
        ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores)
        max_indexes = ragged_tot_scores.argmax()
        best_path = k2.index_fsa(fsa, max_indexes)

        key = f"attention_scale_{a_scale}"
        ans[key] = best_path
    return ans


def rescore_with_rnn_lm(
    lattice: k2.Fsa,
    num_paths: int,
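In both functions above, every candidate path gets a total score that is a weighted sum of its component scores (acoustic/lattice score, optionally a scaled n-gram LM score, and the scaled attention-decoder log-likelihood `-nll.sum(dim=1)`), and the best path per utterance is picked by an argmax over a ragged [utt][path] score tensor. A toy sketch of just that selection step (the scores below are made up):

```python
import k2
import torch

# Two utterances, each with two candidate paths: ragged shape [utt][path].
shape = k2.RaggedShape("[[x x] [x x]]")

am = torch.tensor([-10.0, -12.0, -7.0, -9.0])   # acoustic score per path
ngram = torch.tensor([-3.0, -1.0, -2.0, -2.5])  # n-gram LM score per path
attn = torch.tensor([-4.0, -2.0, -3.0, -1.0])   # attention-decoder score per path

n_scale, a_scale = 0.5, 1.0
tot = am + n_scale * ngram + a_scale * attn     # [-15.5, -14.5, -11.0, -11.25]

ragged_tot = k2.RaggedTensor(shape, tot)
best = ragged_tot.argmax()  # linear index of the best path within each utterance
print(best)                 # tensor([1, 2], dtype=torch.int32)
```
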
@ -28,5 +28,6 @@ multi_quantization
onnx
onnxmltools
onnxruntime
onnxconverter_common
kaldifst
kaldi-decoder