WIP: Begin to add BPE decoding

This commit is contained in:
Fangjun Kuang 2021-07-26 20:06:58 +08:00
parent d3101fb005
commit 4ccae509d3
11 changed files with 2416 additions and 18 deletions

2
.gitignore vendored
View File

@ -2,3 +2,5 @@ data
__pycache__
path.sh
exp
exp*/
*.pt

View File

@ -0,0 +1,914 @@
#!/usr/bin/env python3
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
# Apache 2.0
import math
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):
"""
Args:
num_features (int): Number of input features
num_classes (int): Number of output classes
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
dim_feedforward (int): feedforward dimention
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
num_decoder_layers: int = 6,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
is_espnet_structure: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
num_classes=num_classes,
subsampling_factor=subsampling_factor,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
mmi_loss=mmi_loss,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
cnn_module_kernel,
normalize_before,
is_espnet_structure,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure
if self.normalize_before and self.is_espnet_structure:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
# and throws an error without this change.
self.after_norm = identity
def encode(
self, x: Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x: Tensor of dimension (batch_size, num_features, input_length).
supervisions : Supervison in lhotse format, i.e., batch['supervisions']
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
mask = encoder_padding_mask(x.size(0), supervisions)
if mask is not None:
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
if self.normalize_before and self.is_espnet_structure:
x = self.after_norm(x)
return x, mask
class ConformerEncoderLayer(nn.Module):
"""
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
normalize_before: whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = encoder_layer(src, pos_emb)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
is_espnet_structure: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout)
self.normalize_before = normalize_before
def forward(
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src
class ConformerEncoder(nn.TransformerEncoder):
r"""ConformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = conformer_encoder(src, pos_emb)
"""
def __init__(
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
) -> None:
super(ConformerEncoder, self).__init__(
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
)
def forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
output = src
for mod in self.layers:
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
)
if self.norm is not None:
output = self.norm(output)
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
Examples::
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_espnet_structure: bool = False,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
self.is_espnet_structure = is_espnet_structure
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
return self.multi_head_attention_forward(
query,
key,
value,
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.weight,
self.in_proj.bias,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
)
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is an binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
length, N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if not self.is_espnet_structure:
q = q * scaling
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
key_padding_mask.size(1), src_len
)
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(
1, 2
) # (batch, head, time1, d_k)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
if not self.is_espnet_structure:
attn_output_weights = (
matrix_ac + matrix_bd
) # (batch, head, time1, time2)
else:
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
def identity(x):
return x

View File

@ -0,0 +1,129 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
# (still working in progress)
import argparse
import logging
from pathlib import Path
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang/bpe"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"num_classes": 5000,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
"search_beam": 20,
"output_beam": 5,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
# Possible values for method:
# - 1best
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
"method": "whole-lattice-rescoring",
# num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 30,
}
)
return params
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=9,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
return parser
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
token_ids_with_blank = list(range(params.num_classes))
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

View File

@ -3,28 +3,48 @@
"""
This script compiles HLG from
- H, the ctc topology, built from phones contained in data/lang/lexicon.txt
- L, the lexicon, built from data/lang/L_disambig.pt
- H, the ctc topology, built from phones contained in lexicon.txt
- L, the lexicon, built from L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
The generated HLG is saved in data/lm/HLG.pt
The generated HLG is saved in data/lm/HLG.pt (phone based)
or data/lm/HLG_bpe.pt (BPE based)
"""
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def main():
lexicon = Lexicon("data/lang")
def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang or data/lang/bpe.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
print(f"building ctc_top. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load("data/lang/L_disambig.pt"))
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
print("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
G = k2.Fsa.from_dict(d).to(device)
else:
print("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "G_3_gram.pt")
first_token_disambig_id = lexicon.phones["#0"]
first_word_disambig_id = lexicon.words["#0"]
@ -74,9 +94,34 @@ def main():
print("Arc sorting LG")
HLG = k2.arc_sort(HLG)
return HLG
def phone_based_HLG():
if Path("data/lm/HLG.pt").is_file():
return
print("Compiling phone based HLG")
HLG = compile_HLG("data/lang")
print("Saving HLG.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG.pt")
def bpe_based_HLG():
if Path("data/lm/HLG_bpe.pt").is_file():
return
print("Compiling BPE based HLG")
HLG = compile_HLG("data/lang/bpe")
print("Saving HLG.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
def main():
phone_based_HLG()
bpe_based_HLG()
if __name__ == "__main__":
main()

View File

@ -48,7 +48,7 @@ def read_lexicon(filename: str) -> Lexicon:
"""
ans = []
with open(filename, "r", encoding="latin-1") as f:
with open(filename, "r", encoding="utf-8") as f:
whitespace = re.compile("[ \t]+")
for line in f:
a = whitespace.split(line.strip(" \t\r\n"))
@ -80,7 +80,7 @@ def write_lexicon(filename: str, lexicon: Lexicon) -> None:
lexicon:
It can be the return value of :func:`read_lexicon`.
"""
with open(filename, "w") as f:
with open(filename, "w", encoding="utf-8") as f:
for word, prons in lexicon:
f.write(f"{word} {' '.join(prons)}\n")
@ -100,7 +100,7 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
Returns:
Return None.
"""
with open(filename, "w") as f:
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
@ -151,7 +151,8 @@ def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbols
- The ID of the max disambiguation symbols that appears
in the lexicon
"""
# (1) Work out the count of each phone-sequence in the

View File

@ -0,0 +1,196 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as inputs the following files:
- data/lang/bpe/bpe.model,
- data/lang/bpe/tokens.txt (will remove it),
- data/lang/bpe/words.txt
and generates the following files in the directory data/lang/bpe:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- phones.txt
"""
from pathlib import Path
from typing import Dict, List
import k2
import sentencepiece as spm
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
)
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go.
arcs = []
assert token2id["<blank>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, prons in lexicon:
assert len(prons) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
prons = [token2id[i] for i in prons]
for i in range(len(prons) - 1):
if i == 0:
arcs.append([cur_state, next_state, prons[i], word, 0])
else:
arcs.append([cur_state, next_state, prons[i], eps, 0])
cur_state = next_state
next_state += 1
# now for the last phone of this word
i = len(prons) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, prons[i], w, 0])
if need_self_loops:
disambig_phone = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_phone=disambig_phone,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def generate_lexicon(model_file: str, words: List[str]) -> Lexicon:
"""Generate a lexicon from a BPE model.
Args:
model_file:
Path to a sentencepiece model.
words:
A list of strings representing words.
Returns:
Return a dict whose keys are words and values are the corresponding
word pieces.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(model_file))
words_pieces: List[List[str]] = sp.encode(words, out_type=str)
lexicon = []
for word, pieces in zip(words, words_pieces):
lexicon.append((word, pieces))
lexicon.append(("<UNK>", ["<UNK>"]))
return lexicon
def main():
lang_dir = Path("data/lang/bpe")
model_file = lang_dir / "bpe.model"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
lexicon = generate_lexicon(model_file, words)
# TODO(fangjun): Remove tokens.txt and generate it from the model directly.
#
# We are using it since the IDs we are using in tokens.txt is
# different from the one contained in the model
token_sym_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table.add(f"#{i}")
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
token_sym_table.to_file(lang_dir / "phones.txt")
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(lang_dir / "phones.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(lang_dir / "L.svg", title="L")
L_disambig.draw(lang_dir / "L_disambig.svg", title="L_disambig")
if __name__ == "__main__":
main()

View File

@ -89,7 +89,17 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Stage 6: Prepare G"
echo "State 6: Prepare BPE based lang"
mkdir -p data/lang/bpe
cp data/lang/words.txt data/lang/bpe/
if [ ! -f data/lang/bpe/L_disambig.pt ]; then
./local/prepare_lang_bpe.py
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "Stage 7: Prepare G"
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
@ -112,9 +122,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "Stage 7: Compile HLG"
if [ ! -f data/lm/HLG.pt ]; then
python3 ./local/compile_hlg.py
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Stage 8: Compile HLG"
python3 ./local/compile_hlg.py
fi

3
requirements.txt Normal file
View File

@ -0,0 +1,3 @@
kaldilm
kaldialign
sentencepiece>=0.1.96