mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
WIP: Begin to add BPE decoding
This commit is contained in:
parent
d3101fb005
commit
4ccae509d3
2
.gitignore
vendored
2
.gitignore
vendored
@ -2,3 +2,5 @@ data
|
||||
__pycache__
|
||||
path.sh
|
||||
exp
|
||||
exp*/
|
||||
*.pt
|
||||
|
0
egs/librispeech/ASR/conformer_ctc/__init__.py
Normal file
0
egs/librispeech/ASR/conformer_ctc/__init__.py
Normal file
914
egs/librispeech/ASR/conformer_ctc/conformer.py
Normal file
914
egs/librispeech/ASR/conformer_ctc/conformer.py
Normal file
@ -0,0 +1,914 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
# Apache 2.0
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
"""
|
||||
Args:
|
||||
num_features (int): Number of input features
|
||||
num_classes (int): Number of output classes
|
||||
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
||||
d_model (int): attention dimension
|
||||
nhead (int): number of head
|
||||
dim_feedforward (int): feedforward dimention
|
||||
num_encoder_layers (int): number of encoder layers
|
||||
num_decoder_layers (int): number of decoder layers
|
||||
dropout (float): dropout rate
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
normalize_before (bool): whether to use layer_norm before the first block.
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
num_classes: int,
|
||||
subsampling_factor: int = 4,
|
||||
d_model: int = 256,
|
||||
nhead: int = 4,
|
||||
dim_feedforward: int = 2048,
|
||||
num_encoder_layers: int = 12,
|
||||
num_decoder_layers: int = 6,
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
is_espnet_structure: bool = False,
|
||||
mmi_loss: bool = True,
|
||||
use_feat_batchnorm: bool = False,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
num_features=num_features,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=subsampling_factor,
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
num_decoder_layers=num_decoder_layers,
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
vgg_frontend=vgg_frontend,
|
||||
mmi_loss=mmi_loss,
|
||||
use_feat_batchnorm=use_feat_batchnorm,
|
||||
)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
encoder_layer = ConformerEncoderLayer(
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward,
|
||||
dropout,
|
||||
cnn_module_kernel,
|
||||
normalize_before,
|
||||
is_espnet_structure,
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
self.normalize_before = normalize_before
|
||||
self.is_espnet_structure = is_espnet_structure
|
||||
if self.normalize_before and self.is_espnet_structure:
|
||||
self.after_norm = nn.LayerNorm(d_model)
|
||||
else:
|
||||
# Note: TorchScript detects that self.after_norm could be used inside forward()
|
||||
# and throws an error without this change.
|
||||
self.after_norm = identity
|
||||
|
||||
def encode(
|
||||
self, x: Tensor, supervisions: Optional[Supervisions] = None
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x: Tensor of dimension (batch_size, num_features, input_length).
|
||||
supervisions : Supervison in lhotse format, i.e., batch['supervisions']
|
||||
|
||||
Returns:
|
||||
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
||||
Tensor: Mask tensor of dimension (batch_size, input_length)
|
||||
"""
|
||||
x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F)
|
||||
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
mask = encoder_padding_mask(x.size(0), supervisions)
|
||||
if mask is not None:
|
||||
mask = mask.to(x.device)
|
||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
|
||||
|
||||
if self.normalize_before and self.is_espnet_structure:
|
||||
x = self.after_norm(x)
|
||||
|
||||
return x, mask
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
|
||||
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
|
||||
|
||||
Args:
|
||||
d_model: the number of expected features in the input (required).
|
||||
nhead: the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
normalize_before: whether to use layer_norm before the first block.
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> pos_emb = torch.rand(32, 19, 512)
|
||||
>>> out = encoder_layer(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
nhead: int,
|
||||
dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
is_espnet_structure: bool = False,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
|
||||
)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
Swish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
Swish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
|
||||
self.norm_ff_macaron = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the macaron style FNN module
|
||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
|
||||
|
||||
self.ff_scale = 0.5
|
||||
|
||||
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
|
||||
self.norm_final = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the final output of the block
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
|
||||
# macaron style feed forward module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
src = residual + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(src)
|
||||
)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
|
||||
# multi-headed self-attention module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
)[0]
|
||||
src = residual + self.dropout(src_att)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
|
||||
# convolution module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
src = residual + self.dropout(self.conv_module(src))
|
||||
if not self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
|
||||
# feed forward module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff(src)
|
||||
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff(src)
|
||||
|
||||
if self.normalize_before:
|
||||
src = self.norm_final(src)
|
||||
|
||||
return src
|
||||
|
||||
|
||||
class ConformerEncoder(nn.TransformerEncoder):
|
||||
r"""ConformerEncoder is a stack of N encoder layers
|
||||
|
||||
Args:
|
||||
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
|
||||
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||
norm: the layer normalization component (optional).
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> pos_emb = torch.rand(32, 19, 512)
|
||||
>>> out = conformer_encoder(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
|
||||
) -> None:
|
||||
super(ConformerEncoder, self).__init__(
|
||||
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
|
||||
"""
|
||||
output = src
|
||||
|
||||
for mod in self.layers:
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
|
||||
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
|
||||
|
||||
Args:
|
||||
d_model: Embedding dimension.
|
||||
dropout_rate: Dropout rate.
|
||||
max_len: Maximum input length.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
||||
) -> None:
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||
|
||||
# Reserve the order of positive indices and concat both positive and
|
||||
# negative indices. This is used to support the shifting trick
|
||||
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
+ 1 : self.pe.size(1) // 2
|
||||
+ x.size(1),
|
||||
]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
||||
class RelPositionMultiheadAttention(nn.Module):
|
||||
r"""Multi-Head Attention layer with relative position encoding
|
||||
|
||||
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.0,
|
||||
is_espnet_structure: bool = False,
|
||||
) -> None:
|
||||
super(RelPositionMultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
|
||||
# linear transformation for positional encoding.
|
||||
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.is_espnet_structure = is_espnet_structure
|
||||
|
||||
def _reset_parameters(self) -> None:
|
||||
nn.init.xavier_uniform_(self.in_proj.weight)
|
||||
nn.init.constant_(self.in_proj.bias, 0.0)
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
|
||||
nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
pos_emb: Positional embedding tensor
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. When given a binary mask and a value is True,
|
||||
the corresponding value on the attention layer will be ignored. When given
|
||||
a byte mask and a value is non-zero, the corresponding value on the attention
|
||||
layer will be ignored
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
|
||||
Shape:
|
||||
- Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
||||
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||
is provided, it will be added to the attention weight.
|
||||
|
||||
- Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
"""
|
||||
return self.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
pos_emb,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj.weight,
|
||||
self.in_proj.bias,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, time1, time2)
|
||||
(note: time2 has the same value as time1, but it is for
|
||||
the key, while time1 is for the query).
|
||||
"""
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
assert n == 2 * time1 - 1
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time1),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
|
||||
def multi_head_attention_forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
pos_emb: Tensor,
|
||||
embed_dim_to_check: int,
|
||||
num_heads: int,
|
||||
in_proj_weight: Tensor,
|
||||
in_proj_bias: Tensor,
|
||||
dropout_p: float,
|
||||
out_proj_weight: Tensor,
|
||||
out_proj_bias: Tensor,
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
pos_emb: Positional embedding tensor
|
||||
embed_dim_to_check: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||
dropout_p: probability of an element to be zeroed.
|
||||
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||
training: apply dropout if is ``True``.
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. This is an binary mask. When the value is True,
|
||||
the corresponding value on the attention layer will be filled with -inf.
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
|
||||
Shape:
|
||||
Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
|
||||
length, N is the batch size, E is the embedding dimension.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
||||
will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||
is provided, it will be added to the attention weight.
|
||||
|
||||
Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
"""
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == embed_dim_to_check
|
||||
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||
|
||||
head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
head_dim * num_heads == embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
scaling = float(head_dim) ** -0.5
|
||||
|
||||
if torch.equal(query, key) and torch.equal(key, value):
|
||||
# self-attention
|
||||
q, k, v = nn.functional.linear(
|
||||
query, in_proj_weight, in_proj_bias
|
||||
).chunk(3, dim=-1)
|
||||
|
||||
elif torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
||||
|
||||
else:
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = embed_dim * 2
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
k = nn.functional.linear(key, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim * 2
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
v = nn.functional.linear(value, _w, _b)
|
||||
|
||||
if not self.is_espnet_structure:
|
||||
q = q * scaling
|
||||
|
||||
if attn_mask is not None:
|
||||
assert (
|
||||
attn_mask.dtype == torch.float32
|
||||
or attn_mask.dtype == torch.float64
|
||||
or attn_mask.dtype == torch.float16
|
||||
or attn_mask.dtype == torch.uint8
|
||||
or attn_mask.dtype == torch.bool
|
||||
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
||||
attn_mask.dtype
|
||||
)
|
||||
if attn_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
attn_mask = attn_mask.to(torch.bool)
|
||||
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||
raise RuntimeError(
|
||||
"The size of the 2D attn_mask is not correct."
|
||||
)
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [
|
||||
bsz * num_heads,
|
||||
query.size(0),
|
||||
key.size(0),
|
||||
]:
|
||||
raise RuntimeError(
|
||||
"The size of the 3D attn_mask is not correct."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"attn_mask's dimension {} is not supported".format(
|
||||
attn_mask.dim()
|
||||
)
|
||||
)
|
||||
# attn_mask's dim is 3 now.
|
||||
|
||||
# convert ByteTensor key_padding_mask to bool
|
||||
if (
|
||||
key_padding_mask is not None
|
||||
and key_padding_mask.dtype == torch.uint8
|
||||
):
|
||||
warnings.warn(
|
||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||
|
||||
q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
|
||||
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
|
||||
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
|
||||
src_len = k.size(0)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
||||
key_padding_mask.size(0), bsz
|
||||
)
|
||||
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
|
||||
key_padding_mask.size(1), src_len
|
||||
)
|
||||
|
||||
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
||||
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||
matrix_ac = torch.matmul(
|
||||
q_with_bias_u, k
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p.transpose(-2, -1)
|
||||
) # (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
if not self.is_espnet_structure:
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
) # (batch, head, time1, time2)
|
||||
else:
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
) * scaling # (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, -1
|
||||
)
|
||||
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
]
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
|
||||
else:
|
||||
attn_output_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float("-inf"),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||
attn_output_weights = nn.functional.dropout(
|
||||
attn_output_weights, p=dropout_p, training=training
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(
|
||||
attn_output, out_proj_weight, out_proj_bias
|
||||
)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
||||
else:
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernerl size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.pointwise_conv1 = nn.Conv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = nn.BatchNorm1d(channels)
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.activation = Swish()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
x: Input tensor (#time, batch, channels).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (#time, batch, channels).
|
||||
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
|
||||
# GLU mechanism
|
||||
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.activation(self.norm(x))
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
return x.permute(2, 0, 1)
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
"""Construct an Swish object."""
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return Swich activation function."""
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def identity(x):
|
||||
return x
|
129
egs/librispeech/ASR/conformer_ctc/decode.py
Executable file
129
egs/librispeech/ASR/conformer_ctc/decode.py
Executable file
@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
|
||||
|
||||
# (still working in progress)
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"exp_dir": Path("conformer_ctc/exp"),
|
||||
"lang_dir": Path("data/lang/bpe"),
|
||||
"lm_dir": Path("data/lm"),
|
||||
"feature_dim": 80,
|
||||
"nhead": 8,
|
||||
"attention_dim": 512,
|
||||
"num_classes": 5000,
|
||||
"subsampling_factor": 4,
|
||||
"num_decoder_layers": 6,
|
||||
"vgg_frontend": False,
|
||||
"is_espnet_structure": True,
|
||||
"mmi_loss": False,
|
||||
"use_feat_batchnorm": True,
|
||||
"search_beam": 20,
|
||||
"output_beam": 5,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
# Possible values for method:
|
||||
# - 1best
|
||||
# - nbest
|
||||
# - nbest-rescoring
|
||||
# - whole-lattice-rescoring
|
||||
"method": "whole-lattice-rescoring",
|
||||
# num_paths is used when method is "nbest" and "nbest-rescoring"
|
||||
"num_paths": 30,
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=9,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-decode")
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
d_model=params.attention_dim,
|
||||
num_classes=params.num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
is_espnet_structure=params.is_espnet_structure,
|
||||
mmi_loss=params.mmi_loss,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames))
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
token_ids_with_blank = list(range(params.num_classes))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1100
egs/librispeech/ASR/conformer_ctc/transformer.py
Normal file
1100
egs/librispeech/ASR/conformer_ctc/transformer.py
Normal file
File diff suppressed because it is too large
Load Diff
0
egs/librispeech/ASR/local/__init__.py
Normal file
0
egs/librispeech/ASR/local/__init__.py
Normal file
@ -3,28 +3,48 @@
|
||||
"""
|
||||
This script compiles HLG from
|
||||
|
||||
- H, the ctc topology, built from phones contained in data/lang/lexicon.txt
|
||||
- L, the lexicon, built from data/lang/L_disambig.pt
|
||||
- H, the ctc topology, built from phones contained in lexicon.txt
|
||||
- L, the lexicon, built from L_disambig.pt
|
||||
|
||||
Caution: We use a lexicon that contains disambiguation symbols
|
||||
|
||||
- G, the LM, built from data/lm/G_3_gram.fst.txt
|
||||
|
||||
The generated HLG is saved in data/lm/HLG.pt
|
||||
The generated HLG is saved in data/lm/HLG.pt (phone based)
|
||||
or data/lm/HLG_bpe.pt (BPE based)
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
def main():
|
||||
lexicon = Lexicon("data/lang")
|
||||
def compile_HLG(lang_dir: str) -> k2.Fsa:
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
The language directory, e.g., data/lang or data/lang/bpe.
|
||||
|
||||
Return:
|
||||
An FSA representing HLG.
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
print(f"building ctc_top. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load("data/lang/L_disambig.pt"))
|
||||
with open("data/lm/G_3_gram.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
if Path("data/lm/G_3_gram.pt").is_file():
|
||||
print("Loading pre-compiled G_3_gram")
|
||||
d = torch.load("data/lm/G_3_gram.pt")
|
||||
G = k2.Fsa.from_dict(d).to(device)
|
||||
else:
|
||||
print("Loading G_3_gram.fst.txt")
|
||||
with open("data/lm/G_3_gram.fst.txt") as f:
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
torch.save(G.as_dict(), "G_3_gram.pt")
|
||||
|
||||
first_token_disambig_id = lexicon.phones["#0"]
|
||||
first_word_disambig_id = lexicon.words["#0"]
|
||||
@ -74,9 +94,34 @@ def main():
|
||||
print("Arc sorting LG")
|
||||
HLG = k2.arc_sort(HLG)
|
||||
|
||||
return HLG
|
||||
|
||||
|
||||
def phone_based_HLG():
|
||||
if Path("data/lm/HLG.pt").is_file():
|
||||
return
|
||||
|
||||
print("Compiling phone based HLG")
|
||||
HLG = compile_HLG("data/lang")
|
||||
|
||||
print("Saving HLG.pt to data/lm")
|
||||
torch.save(HLG.as_dict(), "data/lm/HLG.pt")
|
||||
|
||||
|
||||
def bpe_based_HLG():
|
||||
if Path("data/lm/HLG_bpe.pt").is_file():
|
||||
return
|
||||
|
||||
print("Compiling BPE based HLG")
|
||||
HLG = compile_HLG("data/lang/bpe")
|
||||
print("Saving HLG.pt to data/lm")
|
||||
torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
|
||||
|
||||
|
||||
def main():
|
||||
phone_based_HLG()
|
||||
bpe_based_HLG()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -48,7 +48,7 @@ def read_lexicon(filename: str) -> Lexicon:
|
||||
"""
|
||||
ans = []
|
||||
|
||||
with open(filename, "r", encoding="latin-1") as f:
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
whitespace = re.compile("[ \t]+")
|
||||
for line in f:
|
||||
a = whitespace.split(line.strip(" \t\r\n"))
|
||||
@ -80,7 +80,7 @@ def write_lexicon(filename: str, lexicon: Lexicon) -> None:
|
||||
lexicon:
|
||||
It can be the return value of :func:`read_lexicon`.
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
for word, prons in lexicon:
|
||||
f.write(f"{word} {' '.join(prons)}\n")
|
||||
|
||||
@ -100,7 +100,7 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
for sym, i in sym2id.items():
|
||||
f.write(f"{sym} {i}\n")
|
||||
|
||||
@ -151,7 +151,8 @@ def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
|
||||
Return a tuple with two elements:
|
||||
|
||||
- The output lexicon with disambiguation symbols
|
||||
- The ID of the max disambiguation symbols
|
||||
- The ID of the max disambiguation symbols that appears
|
||||
in the lexicon
|
||||
"""
|
||||
|
||||
# (1) Work out the count of each phone-sequence in the
|
||||
|
196
egs/librispeech/ASR/local/prepare_lang_bpe.py
Executable file
196
egs/librispeech/ASR/local/prepare_lang_bpe.py
Executable file
@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script takes as inputs the following files:
|
||||
- data/lang/bpe/bpe.model,
|
||||
- data/lang/bpe/tokens.txt (will remove it),
|
||||
- data/lang/bpe/words.txt
|
||||
|
||||
and generates the following files in the directory data/lang/bpe:
|
||||
|
||||
- lexicon.txt
|
||||
- lexicon_disambig.txt
|
||||
- L.pt
|
||||
- L_disambig.pt
|
||||
- phones.txt
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
add_self_loops,
|
||||
write_lexicon,
|
||||
)
|
||||
|
||||
|
||||
def lexicon_to_fst_no_sil(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format).
|
||||
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
loop_state = 0 # words enter and leave from here
|
||||
next_state = 1 # the next un-allocated state, will be incremented as we go.
|
||||
|
||||
arcs = []
|
||||
|
||||
assert token2id["<blank>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
for word, prons in lexicon:
|
||||
assert len(prons) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
prons = [token2id[i] for i in prons]
|
||||
|
||||
for i in range(len(prons) - 1):
|
||||
if i == 0:
|
||||
arcs.append([cur_state, next_state, prons[i], word, 0])
|
||||
else:
|
||||
arcs.append([cur_state, next_state, prons[i], eps, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last phone of this word
|
||||
i = len(prons) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, prons[i], w, 0])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_phone = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_phone=disambig_phone,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
|
||||
|
||||
def generate_lexicon(model_file: str, words: List[str]) -> Lexicon:
|
||||
"""Generate a lexicon from a BPE model.
|
||||
|
||||
Args:
|
||||
model_file:
|
||||
Path to a sentencepiece model.
|
||||
words:
|
||||
A list of strings representing words.
|
||||
Returns:
|
||||
Return a dict whose keys are words and values are the corresponding
|
||||
word pieces.
|
||||
"""
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(str(model_file))
|
||||
|
||||
words_pieces: List[List[str]] = sp.encode(words, out_type=str)
|
||||
|
||||
lexicon = []
|
||||
for word, pieces in zip(words, words_pieces):
|
||||
lexicon.append((word, pieces))
|
||||
|
||||
lexicon.append(("<UNK>", ["<UNK>"]))
|
||||
return lexicon
|
||||
|
||||
|
||||
def main():
|
||||
lang_dir = Path("data/lang/bpe")
|
||||
model_file = lang_dir / "bpe.model"
|
||||
|
||||
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
words = word_sym_table.symbols
|
||||
|
||||
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
|
||||
for w in excluded:
|
||||
if w in words:
|
||||
words.remove(w)
|
||||
|
||||
lexicon = generate_lexicon(model_file, words)
|
||||
|
||||
# TODO(fangjun): Remove tokens.txt and generate it from the model directly.
|
||||
#
|
||||
# We are using it since the IDs we are using in tokens.txt is
|
||||
# different from the one contained in the model
|
||||
token_sym_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in token_sym_table
|
||||
token_sym_table.add(f"#{i}")
|
||||
|
||||
word_sym_table.add("#0")
|
||||
word_sym_table.add("<s>")
|
||||
word_sym_table.add("</s>")
|
||||
|
||||
token_sym_table.to_file(lang_dir / "phones.txt")
|
||||
|
||||
write_lexicon(lang_dir / "lexicon.txt", lexicon)
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst_no_sil(
|
||||
lexicon,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst_no_sil(
|
||||
lexicon_disambig,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
if False:
|
||||
# Just for debugging, will remove it
|
||||
L.labels_sym = k2.SymbolTable.from_file(lang_dir / "phones.txt")
|
||||
L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
L_disambig.labels_sym = L.labels_sym
|
||||
L_disambig.aux_labels_sym = L.aux_labels_sym
|
||||
L.draw(lang_dir / "L.svg", title="L")
|
||||
L_disambig.draw(lang_dir / "L_disambig.svg", title="L_disambig")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -89,7 +89,17 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
echo "Stage 6: Prepare G"
|
||||
echo "State 6: Prepare BPE based lang"
|
||||
mkdir -p data/lang/bpe
|
||||
cp data/lang/words.txt data/lang/bpe/
|
||||
|
||||
if [ ! -f data/lang/bpe/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
echo "Stage 7: Prepare G"
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
|
||||
@ -112,9 +122,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
echo "Stage 7: Compile HLG"
|
||||
if [ ! -f data/lm/HLG.pt ]; then
|
||||
python3 ./local/compile_hlg.py
|
||||
fi
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
echo "Stage 8: Compile HLG"
|
||||
python3 ./local/compile_hlg.py
|
||||
fi
|
||||
|
3
requirements.txt
Normal file
3
requirements.txt
Normal file
@ -0,0 +1,3 @@
|
||||
kaldilm
|
||||
kaldialign
|
||||
sentencepiece>=0.1.96
|
Loading…
x
Reference in New Issue
Block a user