Add MMI training with word pieces.

This commit is contained in:
Fangjun Kuang 2021-08-07 16:41:16 +08:00
parent f03c991781
commit 897307f445
17 changed files with 4968 additions and 7 deletions

View File

@ -0,0 +1,918 @@
#!/usr/bin/env python3
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
# Apache 2.0
import math
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):
"""
Args:
num_features (int): Number of input features
num_classes (int): Number of output classes
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
dim_feedforward (int): feedforward dimention
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
num_decoder_layers: int = 6,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
is_espnet_structure: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
num_classes=num_classes,
subsampling_factor=subsampling_factor,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
cnn_module_kernel,
normalize_before,
is_espnet_structure,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure
if self.normalize_before and self.is_espnet_structure:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
# and throws an error without this change.
self.after_norm = identity
def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x:
The model input. Its shape is [N, T, C].
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
mask = encoder_padding_mask(x.size(0), supervisions)
if mask is not None:
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
if self.normalize_before and self.is_espnet_structure:
x = self.after_norm(x)
return x, mask
class ConformerEncoderLayer(nn.Module):
"""
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
normalize_before: whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = encoder_layer(src, pos_emb)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
is_espnet_structure: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout)
self.normalize_before = normalize_before
def forward(
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src
class ConformerEncoder(nn.TransformerEncoder):
r"""ConformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = conformer_encoder(src, pos_emb)
"""
def __init__(
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
) -> None:
super(ConformerEncoder, self).__init__(
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
)
def forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
output = src
for mod in self.layers:
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
)
if self.norm is not None:
output = self.norm(output)
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reserve the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
Examples::
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_espnet_structure: bool = False,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
self.is_espnet_structure = is_espnet_structure
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
return self.multi_head_attention_forward(
query,
key,
value,
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.weight,
self.in_proj.bias,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
)
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is an binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
length, N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if not self.is_espnet_structure:
q = q * scaling
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
key_padding_mask.size(1), src_len
)
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(
1, 2
) # (batch, head, time1, d_k)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
if not self.is_espnet_structure:
attn_output_weights = (
matrix_ac + matrix_bd
) # (batch, head, time1, time2)
else:
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
def identity(x):
return x

View File

@ -0,0 +1,507 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
# (still working in progress)
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from conformer import Conformer
from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
get_lattice,
nbest_decoding,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=9,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_mmi/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"is_espnet_structure": True,
"use_feat_batchnorm": True,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
# Possible values for method:
# - 1best
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
# - attention-decoder
# "method": "whole-lattice-rescoring",
"method": "1best",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
"num_paths": 100,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
model:
The neural model.
HLG:
The decoding graph.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = HLG.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is [N, T, C]
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
]
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph.
lexicon:
It contains word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
tot_num_cuts = len(dl.dataset.cuts)
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
logging.info(
f"batch {batch_idx}, cuts processed until now is "
f"{num_cuts}/{tot_num_cuts} "
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeMmiTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt")
G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
use_feat_batchnorm=params.use_feat_batchnorm,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,144 @@
import torch
import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, idim: int, odim: int) -> None:
"""
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
# On entry, x is [N, T, idim]
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
x = self.conv(x)
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
return x
class VggSubsampling(nn.Module):
"""Trying to follow the setup described in the following paper:
https://arxiv.org/pdf/1910.09799.pdf
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
def __init__(self, idim: int, odim: int) -> None:
"""Construct a VggSubsampling object.
This uses 2 VGG blocks with 2 Conv2d layers each,
subsampling its input by a factor of 4 in the time dimensions.
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
super().__init__()
cur_channels = 1
layers = []
block_dims = [32, 64]
# The decision to use padding=1 for the 1st convolution, then padding=0
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
# a back-compatibility concern so that the number of frames at the
# output would be equal to:
# (((T-1)//2)-1)//2.
# We can consider changing this by using padding=1 on the
# 2nd convolution, so the num-frames at the output would be T//4.
for block_dim in block_dims:
layers.append(
torch.nn.Conv2d(
in_channels=cur_channels,
out_channels=block_dim,
kernel_size=3,
padding=1,
stride=1,
)
)
layers.append(torch.nn.ReLU())
layers.append(
torch.nn.Conv2d(
in_channels=block_dim,
out_channels=block_dim,
kernel_size=3,
padding=0,
stride=1,
)
)
layers.append(
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
x = x.unsqueeze(1)
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python3
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch
def test_conv2d_subsampling():
N = 3
odim = 2
for T in range(7, 19):
for idim in range(7, 20):
model = Conv2dSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim
def test_vgg_subsampling():
N = 3
odim = 2
for T in range(7, 19):
for idim in range(7, 20):
model = VggSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim

View File

@ -0,0 +1,89 @@
#!/usr/bin/env python3
import torch
from transformer import (
Transformer,
encoder_padding_mask,
generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
)
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask():
supervisions = {
"sequence_idx": torch.tensor([0, 1, 2]),
"start_frame": torch.tensor([0, 0, 0]),
"num_frames": torch.tensor([18, 7, 13]),
}
max_len = ((18 - 1) // 2 - 1) // 2
mask = encoder_padding_mask(max_len, supervisions)
expected_mask = torch.tensor(
[
[False, False, False], # ((18 - 1)//2 - 1)//2 = 3,
[False, True, True], # ((7 - 1)//2 - 1)//2 = 1,
[False, False, True], # ((13 - 1)//2 - 1)//2 = 2,
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_transformer():
num_features = 40
num_classes = 87
model = Transformer(num_features=num_features, num_classes=num_classes)
N = 31
for T in range(7, 30):
x = torch.rand(N, T, num_features)
y, _, _ = model(x)
assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
def test_generate_square_subsequent_mask():
s = 5
mask = generate_square_subsequent_mask(s)
inf = float("inf")
expected_mask = torch.tensor(
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0],
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_decoder_padding_mask():
x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
y = pad_sequence(x, batch_first=True, padding_value=-1)
mask = decoder_padding_mask(y, ignore_id=-1)
expected_mask = torch.tensor(
[
[False, False, True],
[False, True, True],
[False, False, False],
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_add_sos():
x = [[1, 2], [3], [2, 5, 8]]
y = add_sos(x, sos_id=0)
expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
assert y == expected_y
def test_add_eos():
x = [[1, 2], [3], [2, 5, 8]]
y = add_eos(x, eos_id=0)
expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
assert y == expected_y

View File

@ -0,0 +1,688 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional
import k2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.mmi import LFMMILoss
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
# TODO: add extra arguments and support DDP training.
# Currently, only single GPU training is implemented. Will add
# DDP training once single GPU training is finished.
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
is saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- valid_interval: Run validation if batch_idx % valid_interval` is 0
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_mmi/exp"),
"lang_dir": Path("data/lang_bpe"),
"feature_dim": 80,
"weight_decay": 1e-6,
"subsampling_factor": 4,
"start_epoch": 0,
"num_epochs": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
# It takes about 10 minutes (1 GPU, max_duration=200)
# to run a validation process.
# For the 100 h subset, there are 85617 batches.
# For the 960 h dataset, there are 843723 batches
"valid_interval": 8000,
"use_pruned_intersect": False,
"den_scale": 1.0,
#
"att_rate": 0.7,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
"is_espnet_structure": True,
"use_feat_batchnorm": True,
"lr_factor": 5.0,
"warm_step": 80000,
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: BpeMmiTrainingGraphCompiler,
is_training: bool,
):
"""
Compute MMI loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build num_graphs and den_graphs.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is [N, T, C]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in LFMMILoss
#
# TODO: If params.use_pruned_intersect is True, there is no
# need to call encode_supervisions
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
loss_fn = LFMMILoss(
graph_compiler=graph_compiler,
den_scale=params.den_scale,
use_pruned_intersect=params.use_pruned_intersect,
)
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
if params.att_rate != 0.0:
token_ids = graph_compiler.texts_to_ids(texts)
with torch.set_grad_enabled(is_training):
if hasattr(model, "module"):
att_loss = model.module.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
else:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
else:
loss = mmi_loss
att_loss = torch.tensor([0])
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
assert loss.requires_grad == is_training
return loss, mmi_loss.detach(), att_loss.detach()
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: BpeMmiTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_mmi_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl):
loss, mmi_loss, att_loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
assert mmi_loss.requires_grad is False
assert att_loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_mmi_loss += mmi_loss.detach().cpu().item()
tot_att_loss += att_loss.detach().cpu().item()
tot_frames += params.valid_frames
if world_size > 1:
s = torch.tensor(
[tot_loss, tot_mmi_loss, tot_att_loss, tot_frames],
device=loss.device,
)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_mmi_loss = s[1]
tot_att_loss = s[2]
tot_frames = s[3]
params.valid_loss = tot_loss / tot_frames
params.valid_mmi_loss = tot_mmi_loss / tot_frames
params.valid_att_loss = tot_att_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: BpeMmiTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_mmi_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, mmi_loss, att_loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_cpu = loss.detach().cpu().item()
mmi_loss_cpu = mmi_loss.detach().cpu().item()
att_loss_cpu = att_loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_mmi_loss += mmi_loss_cpu
tot_att_loss += att_loss_cpu
tot_avg_loss = tot_loss / tot_frames
tot_avg_mmi_loss = tot_mmi_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, "
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg mmi loss: {tot_avg_mmi_loss:.4f}, "
f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_mmi_loss",
mmi_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_att_loss",
att_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_mmi_loss",
tot_avg_mmi_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_att_loss",
tot_avg_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, "
f"valid mmi loss {params.valid_mmi_loss:.4f}, "
f"valid att loss {params.valid_att_loss:.4f}, "
f"valid loss {params.valid_loss:.4f}, "
f"best valid loss: {params.best_valid_loss:.4f}, "
f"best valid epoch: {params.best_valid_epoch}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_mmi_loss",
params.valid_mmi_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_att_loss",
params.valid_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
)
params.train_loss = tot_loss / tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
graph_compiler = BpeMmiTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
is_espnet_structure=params.is_espnet_structure,
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,976 @@
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
# Apache 2.0
import math
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import get_texts
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]
class Transformer(nn.Module):
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
num_decoder_layers: int = 6,
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
num_features:
The input dimension of the model.
num_classes:
The output dimension of the model.
subsampling_factor:
Number of output frames is num_in_frames // subsampling_factor.
Currently, subsampling_factor MUST be 4.
d_model:
Attention dimension.
nhead:
Number of heads in multi-head attention.
Must satisfy d_model // nhead == 0.
dim_feedforward:
The output dimension of the feedforward layers in encoder/decoder.
num_encoder_layers:
Number of encoder layers.
num_decoder_layers:
Number of decoder layers.
dropout:
Dropout in encoder/decoder.
normalize_before:
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.num_classes = num_classes
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape [N, T, num_classes]
# to the shape [N, T//subsampling_factor, d_model].
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_classes -> d_model
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = PositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
encoder_norm = nn.LayerNorm(d_model)
else:
encoder_norm = None
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_encoder_layers,
norm=encoder_norm,
)
self.encoder_output_layer = nn.Linear(d_model, num_classes)
if num_decoder_layers > 0:
self.decoder_num_class = self.num_classes
self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model
)
self.decoder_pos = PositionalEncoding(d_model, dropout)
decoder_layer = TransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
decoder_norm = nn.LayerNorm(d_model)
else:
decoder_norm = None
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=num_decoder_layers,
norm=decoder_norm,
)
self.decoder_output_layer = torch.nn.Linear(
d_model, self.decoder_num_class
)
self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
else:
self.decoder_criterion = None
def forward(
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is [N, T, C].
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling)
Returns:
Return a tuple containing 3 tensors:
- CTC output for ctc decoding. Its shape is [N, T, C]
- Encoder output with shape [T, N, C]. It can be used as key and
value for the decoder.
- Encoder output padding mask. It can be used as
memory_key_padding_mask for the decoder. Its shape is [N, T].
It is None if `supervision` is None.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
def run_encoder(
self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Run the transformer encoder.
Args:
x:
The model input. Its shape is [N, T, C].
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute the encoder padding mask, which is used as memory key
padding mask for the decoder.
Returns:
Return a tuple with two tensors:
- The encoder output, with shape [T, N, C]
- encoder padding mask, with shape [N, T].
The mask is None if `supervisions` is None.
It is used as memory key padding mask in the decoder.
"""
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(x.size(0), supervisions)
mask = mask.to(x.device) if mask is not None else None
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
return x, mask
def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
The output tensor from the transformer encoder.
Its shape is [T, N, C]
Returns:
Return a tensor that can be used for CTC decoding.
Its shape is [N, T, C]
"""
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
return x
def decoder_forward(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
token_ids: List[List[int]],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
It's the output of the encoder with shape [T, N, C]
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
A list-of-list IDs. Each sublist contains IDs for an utterance.
The IDs can be either phone IDs or word piece IDs.
sos_id:
sos token id
eos_id:
eos token id
Returns:
A scalar, the **sum** of label smoothing loss over utterances
in the batch without any normalization.
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
# We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
tgt = self.decoder_pos(tgt)
tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, N, C)
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
return decoder_loss
def decoder_nll(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
token_ids: List[List[int]],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
It's the output of the encoder with shape [T, N, C]
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
A list-of-list IDs (e.g., word piece IDs).
Each sublist represents an utterance.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
Returns:
A 2-D tensor of shape (len(token_ids), max_token_length)
representing the cross entropy loss (i.e., negative log-likelihood).
"""
# The common part between this function and decoder_forward could be
# extracted as a separate function.
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
tgt = self.decoder_pos(tgt)
tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, B, F)
pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F)
pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
pred_pad.view(-1, self.decoder_num_class),
ys_out_pad.view(-1),
ignore_index=-1,
reduction="none",
)
nll = nll.view(pred_pad.shape[0], -1)
return nll
class TransformerEncoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerEncoderLayer.
Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model:
the number of expected features in the input (required).
nhead:
the number of heads in the multiheadattention models (required).
dim_feedforward:
the dimension of the feedforward network model (default=2048).
dropout:
the dropout value (default=0.1).
activation:
the activation function of intermediate layer, relu or
gelu (default=relu).
normalize_before:
whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerEncoderLayer, self).__setstate__(state)
def forward(
self,
src: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional)
Shape:
src: (S, N, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length,
N is the batch size, E is the feature number
"""
residual = src
if self.normalize_before:
src = self.norm1(src)
src2 = self.self_attn(
src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout1(src2)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src2)
if not self.normalize_before:
src = self.norm2(src)
return src
class TransformerDecoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerDecoderLayer.
Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model:
the number of expected features in the input (required).
nhead:
the number of heads in the multiheadattention models (required).
dim_feedforward:
the dimension of the feedforward network model (default=2048).
dropout:
the dropout value (default=0.1).
activation:
the activation function of intermediate layer, relu or
gelu (default=relu).
Examples::
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = decoder_layer(tgt, memory)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerDecoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerDecoderLayer, self).__setstate__(state)
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: Optional[torch.Tensor] = None,
memory_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt:
the sequence to the decoder layer (required).
memory:
the sequence from the last layer of the encoder (required).
tgt_mask:
the mask for the tgt sequence (optional).
memory_mask:
the mask for the memory sequence (optional).
tgt_key_padding_mask:
the mask for the tgt keys per batch (optional).
memory_key_padding_mask:
the mask for the memory keys per batch (optional).
Shape:
tgt: (T, N, E).
memory: (S, N, E).
tgt_mask: (T, T).
memory_mask: (T, S).
tgt_key_padding_mask: (N, T).
memory_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length,
N is the batch size, E is the feature number
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt2 = self.self_attn(
tgt,
tgt,
tgt,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
)[0]
tgt = residual + self.dropout1(tgt2)
if not self.normalize_before:
tgt = self.norm1(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
tgt2 = self.src_attn(
tgt,
memory,
memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = residual + self.dropout2(tgt2)
if not self.normalize_before:
tgt = self.norm2(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = residual + self.dropout3(tgt2)
if not self.normalize_before:
tgt = self.norm3(tgt)
return tgt
def _get_activation_fn(activation: str):
if activation == "relu":
return nn.functional.relu
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module):
"""This class implements the positional encoding
proposed in the following paper:
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
PE(pos, 2i) = sin(pos / (10000^(2i/d_modle))
PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle))
Note::
1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
= exp(-1* 2i / d_model * log(100000))
= exp(2i * -(log(10000) / d_model))
"""
def __init__(self, d_model: int, dropout: float = 0.1) -> None:
"""
Args:
d_model:
Embedding dimension.
dropout:
Dropout probability to be applied to the output of this module.
"""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
self.pe = None
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
is [N, T, d_model]. If T > T1, then we change the shape of self.pe
to [N, T, d_model]. Otherwise, nothing is done.
Args:
x:
It is a tensor of shape [N, T, C].
Returns:
Return None.
"""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
# Now pe is of shape [1, T, d_model], where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Add positional encoding.
Args:
x:
Its shape is [N, T, C]
Returns:
Return a tensor of shape [N, T, C]
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]
return self.dropout(x)
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* min(step ** (-0.5), step * self.warmup ** (-1.5))
)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
class LabelSmoothingLoss(nn.Module):
"""
Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w)
and p_{prob. computed by model}(w) is minimized.
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
Args:
size: the number of class
padding_idx: padding_idx: ignored class id
smoothing: smoothing rate (0.0 means the conventional CE)
normalize_length: normalize loss by sequence length if True
criterion: loss function to be smoothed
"""
def __init__(
self,
size: int,
padding_idx: int = -1,
smoothing: float = 0.1,
normalize_length: bool = False,
criterion: nn.Module = nn.KLDivLoss(reduction="none"),
) -> None:
"""Construct an LabelSmoothingLoss object."""
super(LabelSmoothingLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
assert 0.0 < smoothing <= 1.0
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.true_dist = None
self.normalize_length = normalize_length
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.padding_id of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.size(2) == self.size
# batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
true_dist = x.clone()
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
# denom = total if self.normalize_length else batch_size
denom = total if self.normalize_length else 1
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
def encoder_padding_mask(
max_len: int, supervisions: Optional[Supervisions] = None
) -> Optional[torch.Tensor]:
"""Make mask tensor containing indexes of padded part.
TODO::
This function **assumes** that the model uses
a subsampling factor of 4. We should remove that
assumption later.
Args:
max_len:
Maximum length of input features.
CAUTION: It is the length after subsampling.
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling)
Returns:
Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices.
"""
if supervisions is None:
return None
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"],
supervisions["num_frames"],
),
1,
).to(torch.int32)
lengths = [
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item()
start_frame = supervision_segments[idx, 1].item()
num_frames = supervision_segments[idx, 2].item()
lengths[sequence_idx] = start_frame + num_frames
lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
bs = int(len(lengths))
seq_range = torch.arange(0, max_len, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
# Note: TorchScript doesn't implement Tensor.new()
seq_length_expand = torch.tensor(
lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
def decoder_padding_mask(
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
"""Generate a length mask for input.
The masked position are filled with True,
Unmasked positions are filled with False.
Args:
ys_pad:
padded tensor of dimension (batch_size, input_length).
ignore_id:
the ignored number (the padding number) in ys_pad
Returns:
Tensor:
a bool tensor of the same shape as the input tensor.
"""
ys_mask = ys_pad == ignore_id
return ys_mask
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
"""Generate a square mask for the sequence. The masked positions are
filled with float('-inf'). Unmasked positions are filled with float(0.0).
The mask can be used for masked self-attention.
For instance, if sz is 3, it returns::
tensor([[0., -inf, -inf],
[0., 0., -inf],
[0., 0., 0]])
Args:
sz: mask size
Returns:
A square mask of dimension (sz, sz)
"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = (
mask.float()
.masked_fill(mask == 0, float("-inf"))
.masked_fill(mask == 1, float(0.0))
)
return mask
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
"""Prepend sos_id to each utterance.
Args:
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
sos_id:
The ID of the SOS token.
Return:
Return a new list-of-list, where each sublist starts
with SOS ID.
"""
ans = []
for utt in token_ids:
ans.append([sos_id] + utt)
return ans
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
"""Append eos_id to each utterance.
Args:
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
eos_id:
The ID of the EOS token.
Return:
Return a new list-of-list, where each sublist ends
with EOS ID.
"""
ans = []
for utt in token_ids:
ans.append(utt + [eos_id])
return ans

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
"""
Convert a transcript file containing words to a corpus file containing tokens
for LM training with the help of a lexicon.
If the lexicon contains phones, the resulting LM will be a phone LM; If the
lexicon contains word pieces, the resulting LM will be a word piece LM.
If a word has multiple pronunciations, the one that appears last in the lexicon
is used.
If the input transcript is:
hello zoo world hello
world zoo
foo zoo world hellO
and if the lexicon is
<UNK> SPN
hello h e l l o
hello h e l l o 2
world w o r l d
zoo z o o
Then the output is
h e l l o 2 z o o w o r l d h e l l o 2
w o r l d z o o
SPN z o o w o r l d SPN
"""
from pathlib import Path
from typing import Dict
import argparse
from icefall.lexicon import read_lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transcript",
type=str,
help="The input transcript file."
"We assume that the transcript file consists of "
"lines. Each line consists of space separated words.",
)
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
parser.add_argument(
"--oov", type=str, default="<UNK>", help="The OOV word."
)
return parser.parse_args()
def process_line(lexicon: Dict[str, str], line: str, oov_token: str) -> None:
"""
Args:
lexicon:
A dict containing pronunciations. Its keys are words and values
are pronunciations (i.e., tokens).
line:
A line of transcript consisting of space(s) separated words.
oov_token:
The pronunciation of the oov word if a word in `line` is not present
in the lexicon.
Returns:
Return None.
"""
s = ""
words = line.strip().split()
for i, w in enumerate(words):
tokens = lexicon.get(w, oov_token)
s += " ".join(tokens)
s += " "
print(s.strip())
def main():
args = get_args()
assert Path(args.lexicon).is_file()
assert Path(args.transcript).is_file()
assert len(args.oov) > 0
lexicon = dict(read_lexicon(args.lexicon))
assert args.oov in lexicon
oov_token = lexicon[args.oov]
with open(args.transcript) as f:
for line in f:
process_line(lexicon=lexicon, line=line, oov_token=oov_token)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,627 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang)
# Apache 2.0.
# This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
# in the same way as SRILM.
################################################
# Useful links/References:
################################################
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2124
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
# https://github.com/sfischer13/python-arpa
################################################
# How to use:
################################################
# python3 ngram_entropy_pruning.py -threshold $threshold -lm $input_lm -write-lm $pruned_lm
################################################
# SRILM commands:
################################################
# to_prune_lm=egs/swbd/s5c/data/local/lm/sw1.o3g.kn.gz
# vocab=egs/swbd/s5c/data/local/lm/wordlist
# order=3
# oov_symbol="<unk>"
# threshold=4.7e-5
# pruned_lm=temp.${threshold}.gz
# ngram -unk -map-unk "$oov_symbol" -vocab $vocab -order $order -prune ${threshold} -lm ${to_prune_lm} -write-lm ${pruned_lm}
#
# lm=
# ngram -unk -lm $lm -ppl heldout
# ngram -unk -lm $lm -ppl heldout -debug 3
import argparse
import logging
import math
import gzip
from io import StringIO
from collections import OrderedDict
from collections import defaultdict
from enum import Enum, unique
import re
parser = argparse.ArgumentParser(description="""
Prune an n-gram language model based on the relative entropy
between the original and the pruned model, based on Andreas Stolcke's paper.
An n-gram entry is removed, if the removal causes (training set) perplexity
of the model to increase by less than threshold relative.
The command takes an arpa file and a pruning threshold as input,
and outputs a pruned arpa file.
""")
parser.add_argument("-threshold",
type=float,
default=1e-6,
help="Order of n-gram")
parser.add_argument("-lm",
type=str,
default=None,
help="Path to the input arpa file")
parser.add_argument("-write-lm",
type=str,
default=None,
help="Path to output arpa file after pruning")
parser.add_argument("-minorder",
type=int,
default=1,
help="The minorder parameter limits pruning to "
"ngrams of that length and above.")
parser.add_argument("-encoding",
type=str,
default="utf-8",
help="Encoding of the arpa file")
parser.add_argument("-verbose",
type=int,
default=2,
choices=[0, 1, 2, 3, 4, 5],
help="Verbose level, where "
"0 is most noisy; "
"5 is most silent")
args = parser.parse_args()
default_encoding = args.encoding
logging.basicConfig(
format=
"%(asctime)s%(levelname)s%(funcName)s:%(lineno)d%(message)s",
level=args.verbose * 10)
class Context(dict):
"""
This class stores data for a context h.
It behaves like a python dict object, except that it has several
additional attributes.
"""
def __init__(self):
super().__init__()
self.log_bo = None
class Arpa:
"""
This is a class that implement the data structure of an APRA LM.
It (as well as some other classes) is modified based on the library
by Stefan Fischer:
https://github.com/sfischer13/python-arpa
"""
UNK = '<unk>'
SOS = '<s>'
EOS = '</s>'
FLOAT_NDIGITS = 7
base = 10
@staticmethod
def _check_input(my_input):
if not my_input:
raise ValueError
elif isinstance(my_input, tuple):
return my_input
elif isinstance(my_input, list):
return tuple(my_input)
elif isinstance(my_input, str):
return tuple(my_input.strip().split(' '))
else:
raise ValueError
@staticmethod
def _check_word(input_word):
if not isinstance(input_word, str):
raise ValueError
if ' ' in input_word:
raise ValueError
def _replace_unks(self, words):
return tuple((w if w in self else self._unk) for w in words)
def __init__(self, path=None, encoding=None, unk=None):
self._counts = OrderedDict()
self._ngrams = OrderedDict(
) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w)
self._vocabulary = set()
if unk is None:
self._unk = self.UNK
if path is not None:
self.loadf(path, encoding)
def __contains__(self, ngram):
h = ngram[:-1] # h is a tuple
w = ngram[-1] # w is a string/word
return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]
def contains_word(self, word):
self._check_word(word)
return word in self._vocabulary
def add_count(self, order, count):
self._counts[order] = count
self._ngrams[order - 1] = defaultdict(Context)
def update_counts(self):
for order in range(1, self.order() + 1):
count = sum(
[len(wlist) for _, wlist in self._ngrams[order - 1].items()])
if count > 0:
self._counts[order] = count
def add_entry(self, ngram, p, bo=None, order=None):
# Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3")
h = ngram[:-1] # h is a tuple
w = ngram[-1] # w is a string/word
# Note that p and bo here are in fact in the log domain (self.base = 10)
h_context = self._ngrams[len(h)][h]
h_context[w] = p
if bo is not None:
self._ngrams[len(ngram)][ngram].log_bo = bo
for word in ngram:
self._vocabulary.add(word)
def counts(self):
return sorted(self._counts.items())
def order(self):
return max(self._counts.keys(), default=None)
def vocabulary(self, sort=True):
if sort:
return sorted(self._vocabulary)
else:
return self._vocabulary
def _entries(self, order):
return (self._entry(h, w)
for h, wlist in self._ngrams[order - 1].items() for w in wlist)
def _entry(self, h, w):
# return the entry for the ngram (h, w)
ngram = h + (w, )
log_p = self._ngrams[len(h)][h][w]
log_bo = self._log_bo(ngram)
if log_bo is not None:
return round(log_p, self.FLOAT_NDIGITS), ngram, round(
log_bo, self.FLOAT_NDIGITS)
else:
return round(log_p, self.FLOAT_NDIGITS), ngram
def _log_bo(self, ngram):
if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]:
return self._ngrams[len(ngram)][ngram].log_bo
else:
return None
def _log_p(self, ngram):
h = ngram[:-1] # h is a tuple
w = ngram[-1] # w is a string/word
if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]:
return self._ngrams[len(h)][h][w]
else:
return None
def log_p_raw(self, ngram):
log_p = self._log_p(ngram)
if log_p is not None:
return log_p
else:
if len(ngram) == 1:
raise KeyError
else:
log_bo = self._log_bo(ngram[:-1])
if log_bo is None:
log_bo = 0
return log_bo + self.log_p_raw(ngram[1:])
def log_joint_prob(self, sequence):
# Compute the joint prob of the sequence based on the chain rule
# Note that sequence should be a tuple of strings
#
# Reference:
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
log_joint_p = 0
seq = sequence
while len(seq) > 0:
log_joint_p += self.log_p_raw(seq)
seq = seq[:-1]
# If we're computing the marginal probability of the unigram
# <s> context we have to look up </s> instead since the former
# has prob = 0.
if len(seq) == 1 and seq[0] == self.SOS:
seq = (self.EOS, )
return log_joint_p
def set_new_context(self, h):
old_context = self._ngrams[len(h)][h]
self._ngrams[len(h)][h] = Context()
return old_context
def log_p(self, ngram):
words = self._check_input(ngram)
if self._unk:
words = self._replace_unks(words)
return self.log_p_raw(words)
def log_s(self, sentence, sos=SOS, eos=EOS):
words = self._check_input(sentence)
if self._unk:
words = self._replace_unks(words)
if sos:
words = (sos, ) + words
if eos:
words = words + (eos, )
result = sum(
self.log_p_raw(words[:i]) for i in range(1,
len(words) + 1))
if sos:
result = result - self.log_p_raw(words[:1])
return result
def p(self, ngram):
return self.base**self.log_p(ngram)
def s(self, sentence):
return self.base**self.log_s(sentence)
def write(self, fp):
fp.write('\n\\data\\\n')
for order, count in self.counts():
fp.write('ngram {}={}\n'.format(order, count))
fp.write('\n')
for order, _ in self.counts():
fp.write('\\{}-grams:\n'.format(order))
for e in self._entries(order):
prob = e[0]
ngram = ' '.join(e[1])
if len(e) == 2:
fp.write('{}\t{}\n'.format(prob, ngram))
elif len(e) == 3:
backoff = e[2]
fp.write('{}\t{}\t{}\n'.format(prob, ngram, backoff))
else:
raise ValueError
fp.write('\n')
fp.write('\\end\\\n')
class ArpaParser:
"""
This is a class that implement a parser of an arpa file
"""
@unique
class State(Enum):
DATA = 1
COUNT = 2
HEADER = 3
ENTRY = 4
re_count = re.compile(r'^ngram (\d+)=(\d+)$')
re_header = re.compile(r'^\\(\d+)-grams:$')
re_entry = re.compile('^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)'
'\t'
'(\\S+( \\S+)*)'
'(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$')
def _parse(self, fp):
self._result = []
self._state = self.State.DATA
self._tmp_model = None
self._tmp_order = None
for line in fp:
line = line.strip()
if self._state == self.State.DATA:
self._data(line)
elif self._state == self.State.COUNT:
self._count(line)
elif self._state == self.State.HEADER:
self._header(line)
elif self._state == self.State.ENTRY:
self._entry(line)
if self._state != self.State.DATA:
raise Exception(line)
return self._result
def _data(self, line):
if line == '\\data\\':
self._state = self.State.COUNT
self._tmp_model = Arpa()
else:
pass # skip comment line
def _count(self, line):
match = self.re_count.match(line)
if match:
order = match.group(1)
count = match.group(2)
self._tmp_model.add_count(int(order), int(count))
elif not line:
self._state = self.State.HEADER # there are no counts
else:
raise Exception(line)
def _header(self, line):
match = self.re_header.match(line)
if match:
self._state = self.State.ENTRY
self._tmp_order = int(match.group(1))
elif line == '\\end\\':
self._result.append(self._tmp_model)
self._state = self.State.DATA
self._tmp_model = None
self._tmp_order = None
elif not line:
pass # skip empty line
else:
raise Exception(line)
def _entry(self, line):
match = self.re_entry.match(line)
if match:
p = self._float_or_int(match.group(1))
ngram = tuple(match.group(4).split(' '))
bo_match = match.group(7)
bo = self._float_or_int(bo_match) if bo_match else None
self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
elif not line:
self._state = self.State.HEADER # last entry
else:
raise Exception(line)
@staticmethod
def _float_or_int(s):
f = float(s)
i = int(f)
if str(i) == s: # don't drop trailing ".0"
return i
else:
return f
def load(self, fp):
"""Deserialize fp (a file-like object) to a Python object."""
return self._parse(fp)
def loadf(self, path, encoding=None):
"""Deserialize path (.arpa, .gz) to a Python object."""
path = str(path)
if path.endswith('.gz'):
with gzip.open(path, mode='rt', encoding=encoding) as f:
return self.load(f)
else:
with open(path, mode='rt', encoding=encoding) as f:
return self.load(f)
def loads(self, s):
"""Deserialize s (a str) to a Python object."""
with StringIO(s) as f:
return self.load(f)
def dump(self, obj, fp):
"""Serialize obj to fp (a file-like object) in ARPA format."""
obj.write(fp)
def dumpf(self, obj, path, encoding=None):
"""Serialize obj to path in ARPA format (.arpa, .gz)."""
path = str(path)
if path.endswith('.gz'):
with gzip.open(path, mode='wt', encoding=encoding) as f:
return self.dump(obj, f)
else:
with open(path, mode='wt', encoding=encoding) as f:
self.dump(obj, f)
def dumps(self, obj):
"""Serialize obj to an ARPA formatted str."""
with StringIO() as f:
self.dump(obj, f)
return f.getvalue()
def add_log_p(prev_log_sum, log_p, base):
return math.log(base**log_p + base**prev_log_sum, base)
def compute_numerator_denominator(lm, h):
log_sum_seen_h = -math.inf
log_sum_seen_h_lower = -math.inf
base = lm.base
for w, log_p in lm._ngrams[len(h)][h].items():
log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)
ngram = h + (w, )
log_p_lower = lm.log_p_raw(ngram[1:])
log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower,
base)
numerator = 1.0 - base**log_sum_seen_h
denominator = 1.0 - base**log_sum_seen_h_lower
return numerator, denominator
def prune(lm, threshold, minorder):
# Reference:
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
for i in range(lm.order(), max(minorder - 1, 1),
-1): # i is the order of the ngram (h, w)
logging.info("processing %d-grams ..." % i)
count_pruned_ngrams = 0
h_dict = lm._ngrams[i - 1]
for h in list(h_dict.keys()):
# old backoff weight, BOW(h)
log_bow = lm._log_bo(h)
if log_bow is None:
log_bow = 0
# Compute numerator and denominator of the backoff weight,
# so that we can quickly compute the BOW adjustment due to
# leaving out one prob.
numerator, denominator = compute_numerator_denominator(lm, h)
# assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5
# Compute the marginal probability of the context, P(h)
h_log_p = lm.log_joint_prob(h)
all_pruned = True
pruned_w_set = set()
for w, log_p in h_dict[h].items():
ngram = h + (w, )
# lower-order estimate for ngramProb, P(w|h')
backoff_prob = lm.log_p_raw(ngram[1:])
# Compute BOW after removing ngram, BOW'(h)
new_log_bow = math.log(numerator + lm.base ** log_p, lm.base) - \
math.log(denominator + lm.base ** backoff_prob, lm.base)
# Compute change in entropy due to removal of ngram
delta_prob = backoff_prob + new_log_bow - log_p
delta_entropy = - (lm.base ** h_log_p) * \
((lm.base ** log_p) * delta_prob +
numerator * (new_log_bow - log_bow))
# compute relative change in model (training set) perplexity
perp_change = lm.base**delta_entropy - 1.0
pruned = threshold > 0 and perp_change < threshold
# Make sure we don't prune ngrams whose backoff nodes are needed
if pruned and \
len(ngram) in lm._ngrams and \
len(lm._ngrams[len(ngram)][ngram]) > 0:
pruned = False
logging.debug("CONTEXT " + str(h) + " WORD " + w +
" CONTEXTPROB %f " % h_log_p +
" OLDPROB %f " % log_p + " NEWPROB %f " %
(backoff_prob + new_log_bow) +
" DELTA-H %f " % delta_entropy +
" DELTA-LOGP %f " % delta_prob +
" PPL-CHANGE %f " % perp_change + " PRUNED " +
str(pruned))
if pruned:
pruned_w_set.add(w)
count_pruned_ngrams += 1
else:
all_pruned = False
# If we removed all ngrams for this context we can
# remove the context itself, but only if the present
# context is not a prefix to a longer one.
if all_pruned and len(pruned_w_set) == len(h_dict[h]):
del h_dict[
h] # this context h is no longer needed, as its ngram prob is stored at its own context h'
elif len(pruned_w_set) > 0:
# The pruning for this context h is actually done here
old_context = lm.set_new_context(h)
for w, p_w in old_context.items():
if w not in pruned_w_set:
lm.add_entry(
h + (w, ),
p_w) # the entry hw is stored at the context h
# We need to recompute the back-off weight, but
# this can only be done after completing the pruning
# of the lower-order ngrams.
# Reference:
# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))
# recompute backoff weights
for i in range(max(minorder - 1, 1) + 1,
lm.order() +
1): # be careful of this order: from low- to high-order
for h in lm._ngrams[i - 1]:
numerator, denominator = compute_numerator_denominator(lm, h)
new_log_bow = math.log(numerator, lm.base) - math.log(
denominator, lm.base)
lm._ngrams[len(h)][h].log_bo = new_log_bow
# update counts
lm.update_counts()
return
def check_h_is_valid(lm, h):
sum_under_h = sum(
[lm.base**lm.log_p_raw(h + (w, )) for w in lm.vocabulary(sort=False)])
if abs(sum_under_h - 1.0) > 1e-6:
logging.info("warning: %s %f" % (str(h), sum_under_h))
return False
else:
return True
def validate_lm(lm):
# sanity check if the conditional probability sums to one under each context h
for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w)
logging.info("validating %d-grams ..." % i)
h_dict = lm._ngrams[i - 1]
for h in h_dict.keys():
check_h_is_valid(lm, h)
def compare_two_apras(path1, path2):
pass
if __name__ == '__main__':
# load an arpa file
logging.info("Loading the arpa file from %s" % args.lm)
parser = ArpaParser()
models = parser.loadf(args.lm, encoding=default_encoding)
lm = models[0] # ARPA files may contain several models.
logging.info("Stats before pruning:")
for i, cnt in lm.counts():
logging.info("ngram %d=%d" % (i, cnt))
# prune it, the language model will be modified in-place
logging.info("Start pruning the model with threshold=%.3E..." %
args.threshold)
prune(lm, args.threshold, args.minorder)
# validate_lm(lm)
# write the arpa language model to a file
logging.info("Stats after pruning:")
for i, cnt in lm.counts():
logging.info("ngram %d=%d" % (i, cnt))
logging.info("Saving the pruned arpa file to %s" % args.write_lm)
parser.dumpf(lm, args.write_lm, encoding=default_encoding)
logging.info("Done.")

View File

@ -143,7 +143,71 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare G"
log "Stage 7: Prepare bigram P"
if [ ! -f data/lang_bpe/corpus.txt ]; then
./local/convert_transcript_to_corpus.py \
--lexicon data/lang_bpe/lexicon.txt \
--transcript data/lang_bpe/train.txt \
--oov "<UNK>" \
> data/lang_bpe/corpus.txt
fi
if [ ! -f data/lang_bpe/P.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 2 \
-text data/lang_bpe/corpus.txt \
-lm data/lang_bpe/P.arpa
fi
# TODO: Use egs/wsj/s5/utils/lang/ngram_entropy_pruning.py
# from kaldi to prune P if it causes OOM later
if [ ! -f data/lang_bpe/P-no-prune.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="data/lang_bpe/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
data/lang_bpe/P.arpa > data/lang_bpe/P-no-prune.fst.txt
fi
thresholds=(
1e-6
1e-7
)
for threshold in ${thresholds[@]}; do
if [ ! -f data/lang_bpe/P-pruned.${threshold}.arpa ]; then
python3 ./local/ngram_entropy_pruning.py \
-threshold $threshold \
-lm data/lang_bpe/P.arpa \
-write-lm data/lang_bpe/P-pruned.${threshold}.arpa
fi
if [ ! -f data/lang_bpe/P-pruned.${threshold}.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="data/lang_bpe/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
data/lang_bpe/P-pruned.${threshold}.arpa > data/lang_bpe/P-pruned.${threshold}.fst.txt
fi
done
if [ ! -f data/lang_bpe/P-uni.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="data/lang_bpe/tokens.txt" \
--disambig-symbol='#0' \
--max-order=1 \
data/lang_bpe/P.arpa > data/lang_bpe/P-uni.fst.txt
fi
( cd data/lang_bpe;
# ln -sfv P-pruned.1e-6.fst.txt P.fst.txt
ln -sfv P-no-prune.fst.txt P.fst.txt
)
rm -fv data/lang_bpe/P.pt data/lang_bpe/ctc_topo_P.pt
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Prepare G"
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
@ -167,7 +231,7 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
fi
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Compile HLG"
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compile HLG"
python3 ./local/compile_hlg.py
fi

View File

@ -17,10 +17,14 @@ class BpeCtcTrainingGraphCompiler(object):
"""
Args:
lang_dir:
This directory is expected to contain the following files:
This directory is expected to contain the following files::
- bpe.model
- words.txt
The above files are produced by the script `prepare.sh`. You
should have run that before running the training code.
device:
It indicates CPU or CUDA.
sos_token:
@ -57,7 +61,9 @@ class BpeCtcTrainingGraphCompiler(object):
return self.sp.encode(texts, out_type=int)
def compile(
self, piece_ids: List[List[int]], modified: bool = False,
self,
piece_ids: List[List[int]],
modified: bool = False,
) -> k2.Fsa:
"""Build a ctc graph from a list-of-list piece IDs.

View File

@ -0,0 +1,178 @@
import logging
from pathlib import Path
from typing import List, Tuple, Union
import k2
import sentencepiece as spm
import torch
from icefall.lexicon import Lexicon
class BpeMmiTrainingGraphCompiler(object):
def __init__(
self,
lang_dir: Path,
device: Union[str, torch.device] = "cpu",
sos_token: str = "<sos/eos>",
eos_token: str = "<sos/eos>",
) -> None:
"""
Args:
lang_dir:
Path to the lang directory. It is expected to contain the
following files::
- tokens.txt
- words.txt
- bpe.model
- P.fst.txt
The above files are generated by the script `prepare.sh`. You
should have run it before running the training code.
device:
It indicates CPU or CUDA.
sos_token:
The word piece that represents sos.
eos_token:
The word piece that represents eos.
"""
self.lang_dir = Path(lang_dir)
self.lexicon = Lexicon(lang_dir)
self.device = device
self.load_sentence_piece_model()
self.build_ctc_topo_P()
self.sos_id = self.sp.piece_to_id(sos_token)
self.eos_id = self.sp.piece_to_id(eos_token)
assert self.sos_id != self.sp.unk_id()
assert self.eos_id != self.sp.unk_id()
def load_sentence_piece_model(self) -> None:
"""Load the pre-trained sentencepiece model
from self.lang_dir/bpe.model.
"""
model_file = self.lang_dir / "bpe.model"
sp = spm.SentencePieceProcessor()
sp.load(str(model_file))
self.sp = sp
def build_ctc_topo_P(self):
"""Built ctc_topo_P, the composition result of
ctc_topo and P, where P is a pre-trained bigram
word piece LM.
"""
# Note: there is no need to save a pre-compiled P and ctc_topo
# as it is very fast to generate them.
logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
with open(self.lang_dir / "P.fst.txt") as f:
# P is not an acceptor because there is
# a back-off state, whose incoming arcs
# have label #0 and aux_label 0 (i.e., <eps>).
P = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_token_disambig_id = self.lexicon.token_table["#0"]
# P.aux_labels is not needed in later computations, so
# remove it here.
del P.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
P.labels[P.labels >= first_token_disambig_id] = 0
P = k2.remove_epsilon(P)
P = k2.arc_sort(P)
P = P.to(self.device)
# Add epsilon self-loops to P because we want the
# following operation "k2.intersect" to run on GPU.
P_with_self_loops = k2.add_epsilon_self_loops(P)
max_token_id = max(self.lexicon.tokens)
logging.info(
f"Building modified ctc_topo. max_token_id: {max_token_id}"
)
# CAUTION: We have to use a modifed version of CTC topo.
# Otherwise, the resulting ctc_topo_P is so large that it gets
# stuck in k2.intersect_dense_pruned() or it gets OOM in
# k2.intersect_dense()
ctc_topo = k2.ctc_topo(max_token_id, modified=True, device=self.device)
ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
logging.info("Building ctc_topo_P")
ctc_topo_P = k2.intersect(
ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
).invert()
self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of piece IDs.
Args:
texts:
A list of transcripts. Within a transcript words are
separated by spaces. An example input is::
['HELLO ICEFALL', 'HELLO k2']
Returns:
Return a list-of-list of piece IDs.
"""
return self.sp.encode(texts, out_type=int)
def compile(
self, texts: List[str], replicate_den: bool = True
) -> Tuple[k2.Fsa, k2.Fsa]:
"""Create numerator and denominator graphs from transcripts.
Args:
texts:
A list of transcripts. Within a transcript words are
separated by spaces. An example input is::
["HELLO icefall", "HALLO WELT"]
replicate_den:
If True, the returned den_graph is replicated to match the number
of FSAs in the returned num_graph; if False, the returned den_graph
contains only a single FSA
Returns:
A tuple (num_graphs, den_graphs), where
- `num_graphs` is the numerator graph. It is an FsaVec with
shape `(len(texts), None, None)`.
- `den_graphs` is the denominator graph. It is an FsaVec with the
same shape of the `num_graph` if replicate_den is True;
otherwise, it is an FsaVec containing only a single FSA.
"""
token_ids = self.texts_to_ids(texts)
token_fsas = k2.linear_fsa(token_ids, device=self.device)
token_fsas_with_self_loops = k2.add_epsilon_self_loops(token_fsas)
# NOTE: Use treat_epsilons_specially=False so that k2.compose
# can be run on GPU
num_graphs = k2.compose(
self.ctc_topo_P,
token_fsas_with_self_loops,
treat_epsilons_specially=False,
)
# num_graphs may not be connected and
# not be topologically sorted after k2.compose
num_graphs = k2.connect(num_graphs)
num_graphs = k2.top_sort(num_graphs)
ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P.detach()])
if replicate_den:
indexes = torch.zeros(
len(texts), dtype=torch.int32, device=self.device
)
den_graphs = k2.index_fsa(ctc_topo_P_vec, indexes)
else:
den_graphs = ctc_topo_P_vec
return num_graphs, den_graphs

View File

@ -78,11 +78,13 @@ class Lexicon(object):
"""
Args:
lang_dir:
Path to the lang director. It is expected to contain the following
files:
Path to the lang directory. It is expected to contain the following
files::
- tokens.txt
- words.txt
- L.pt
The above files are produced by the script `prepare.sh`. You
should have run that before running the training code.
disambig_pattern:

222
icefall/mmi.py Normal file
View File

@ -0,0 +1,222 @@
from typing import List
import k2
import torch
from torch import nn
from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
def _compute_mmi_loss_exact_optimized(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: BpeMmiTrainingGraphCompiler,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
The function name contains `exact`, which means it uses a version of
intersection without pruning.
`optimized` in the function name means this function is optimized
in that it calls k2.intersect_dense only once
Note:
It is faster at the cost of using more memory.
Args:
dense_fsa_vec:
It contains the neural network output.
texts:
The transcript. Each element consists of space(s) separated words.
graph_compiler:
Used to build num_graphs and den_graphs
den_scale:
The scale applied to the denominator tot_scores.
Returns:
Return a scalar loss. It is the sum over utterances in a batch,
without normalization.
"""
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
device = num_graphs.device
num_fsas = num_graphs.shape[0]
assert dense_fsa_vec.dim0() == num_fsas
assert den_graphs.shape[0] == 1
# The motivation to concatenate num_graphs and den_graphs
# is to reduce the number of calls to k2.intersect_dense.
num_den_graphs = k2.cat([num_graphs, den_graphs])
# NOTE: The a_to_b_map in k2.intersect_dense must be sorted
# so the following reorders num_den_graphs.
#
# The following code computes a_to_b_map
# [0, 1, 2, ... ]
num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)
# [num_fsas, num_fsas, num_fsas, ... ]
den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)
# [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
num_den_graphs_indexes = (
torch.stack([num_graphs_indexes, den_graphs_indexes])
.t()
.reshape(-1)
.to(device)
)
num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
# [[0, 1, 2, ...]]
a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
# [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)
num_den_lats = k2.intersect_dense(
num_den_reordered_graphs,
dense_fsa_vec,
output_beam=10.0,
a_to_b_map=a_to_b_map,
)
num_den_tot_scores = num_den_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
num_tot_scores = num_den_tot_scores[::2]
den_tot_scores = num_den_tot_scores[1::2]
tot_scores = num_tot_scores - den_scale * den_tot_scores
loss = -1 * tot_scores.sum()
return loss
def _compute_mmi_loss_exact_non_optimized(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: BpeMmiTrainingGraphCompiler,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
See :func:`_compute_mmi_loss_exact_optimized` for the meaning
of the arguments.
It's more readable, though it invokes k2.intersect_dense twice.
Note:
It uses less memory at the cost of speed. It is slower.
"""
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
# TODO: pass output_beam as function argument
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=10.0)
num_tot_scores = num_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
den_tot_scores = den_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
tot_scores = num_tot_scores - den_scale * den_tot_scores
loss = -1 * tot_scores.sum()
return loss
def _compute_mmi_loss_pruned(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: BpeMmiTrainingGraphCompiler,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
See :func:`_compute_mmi_loss_exact_optimized` for the meaning
of the arguments.
`pruned` means it uses k2.intersect_dense_pruned
Note:
It uses the least amount of memory, but the loss is not exact due
to pruning.
"""
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
# the values for search_beam/output_beam/min_active_states/max_active_states
# are not tuned. You may want to tune them.
den_lats = k2.intersect_dense_pruned(
den_graphs,
dense_fsa_vec,
search_beam=20.0,
output_beam=8.0,
min_active_states=30,
max_active_states=10000,
)
num_tot_scores = num_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
den_tot_scores = den_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
tot_scores = num_tot_scores - den_scale * den_tot_scores
loss = -1 * tot_scores.sum()
return loss
class LFMMILoss(nn.Module):
"""
Computes Lattice-Free Maximum Mutual Information (LFMMI) loss.
TODO: more detailed description
"""
def __init__(
self,
graph_compiler: BpeMmiTrainingGraphCompiler,
use_pruned_intersect: bool = False,
den_scale: float = 1.0,
):
super().__init__()
self.graph_compiler = graph_compiler
self.den_scale = den_scale
self.use_pruned_intersect = use_pruned_intersect
def forward(
self,
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
) -> torch.Tensor:
"""
Args:
dense_fsa_vec:
It contains the neural network output.
texts:
A list of strings. Each string contains space(s) separated words.
Returns:
Return a scalar loss. It is the sum over utterances in a batch,
without normalization.
"""
if self.use_pruned_intersect:
func = _compute_mmi_loss_pruned
else:
func = _compute_mmi_loss_exact_non_optimized
# func = _compute_mmi_loss_exact_optimized
return func(
dense_fsa_vec=dense_fsa_vec,
texts=texts,
graph_compiler=self.graph_compiler,
den_scale=self.den_scale,
)

377
icefall/shared/make_kn_lm.py Executable file
View File

@ -0,0 +1,377 @@
#!/usr/bin/env python3
# Copyright 2016 Johns Hopkins University (Author: Daniel Povey)
# 2018 Ruizhe Huang
# Apache 2.0.
# This is an implementation of computing Kneser-Ney smoothed language model
# in the same way as srilm. This is a back-off, unmodified version of
# Kneser-Ney smoothing, which produces the same results as the following
# command (as an example) of srilm:
#
# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \
# -text corpus.txt -lm lm.arpa
#
# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
import sys
import os
import re
import io
import math
import argparse
from collections import Counter, defaultdict
parser = argparse.ArgumentParser(description="""
Generate kneser-ney language model as arpa format. By default,
it will read the corpus from standard input, and output to standard output.
""")
parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
args = parser.parse_args()
default_encoding = "latin-1" # For encoding-agnostic scripts, we assume byte stream as input.
# Need to be very careful about the use of strip() and split()
# in this case, because there is a latin-1 whitespace character
# (nbsp) which is part of the unicode encoding range.
# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
strip_chars = " \t\r\n"
whitespace = re.compile("[ \t]+")
class CountsForHistory:
# This class (which is more like a struct) stores the counts seen in a
# particular history-state. It is used inside class NgramCounts.
# It really does the job of a dict from int to float, but it also
# keeps track of the total count.
def __init__(self):
# The 'lambda: defaultdict(float)' is an anonymous function taking no
# arguments that returns a new defaultdict(float).
self.word_to_count = defaultdict(int)
self.word_to_context = defaultdict(set) # using a set to count the number of unique contexts
self.word_to_f = dict() # discounted probability
self.word_to_bow = dict() # back-off weight
self.total_count = 0
def words(self):
return self.word_to_count.keys()
def __str__(self):
# e.g. returns ' total=12: 3->4, 4->6, -1->2'
return ' total={0}: {1}'.format(
str(self.total_count),
', '.join(['{0} -> {1}'.format(word, count)
for word, count in self.word_to_count.items()]))
def add_count(self, predicted_word, context_word, count):
assert count >= 0
self.total_count += count
self.word_to_count[predicted_word] += count
if context_word is not None:
self.word_to_context[predicted_word].add(context_word)
class NgramCounts:
# A note on data-structure. Firstly, all words are represented as
# integers. We store n-gram counts as an array, indexed by (history-length
# == n-gram order minus one) (note: python calls arrays "lists") of dicts
# from histories to counts, where histories are arrays of integers and
# "counts" are dicts from integer to float. For instance, when
# accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
# do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
# array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
assert ngram_order >= 2
self.ngram_order = ngram_order
self.bos_symbol = bos_symbol
self.eos_symbol = eos_symbol
self.counts = []
for n in range(ngram_order):
self.counts.append(defaultdict(lambda: CountsForHistory()))
self.d = [] # list of discounting factor for each order of ngram
# adds a raw count (called while processing input data).
# Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history'
# would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
# 1.
def add_count(self, history, predicted_word, context_word, count):
self.counts[len(history)][history].add_count(predicted_word, context_word, count)
# 'line' is a string containing a sequence of integer word-ids.
# This function adds the un-smoothed counts from this line of text.
def add_raw_counts_from_line(self, line):
if line == '':
words = [self.bos_symbol, self.eos_symbol]
else:
words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
for i in range(len(words)):
for n in range(1, self.ngram_order+1):
if i + n > len(words):
break
ngram = words[i: i + n]
predicted_word = ngram[-1]
history = tuple(ngram[: -1])
if i == 0 or n == self.ngram_order:
context_word = None
else:
context_word = words[i-1]
self.add_count(history, predicted_word, context_word, 1)
def add_raw_counts_from_standard_input(self):
lines_processed = 0
infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input
for line in infile:
line = line.strip(strip_chars)
self.add_raw_counts_from_line(line)
lines_processed += 1
if lines_processed == 0 or args.verbose > 0:
print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
def add_raw_counts_from_file(self, filename):
lines_processed = 0
with open(filename, encoding=default_encoding) as fp:
for line in fp:
line = line.strip(strip_chars)
self.add_raw_counts_from_line(line)
lines_processed += 1
if lines_processed == 0 or args.verbose > 0:
print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
def cal_discounting_constants(self):
# For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
# where n1_N is the number of unique N-grams with count = 1 (counts-of-counts).
# This constant is used similarly to absolute discounting.
# Return value: d is a list of floats, where d[N+1] = D_N
self.d = [0] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
# This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
# but perhaps this is not the case for some other scenarios.
for n in range(1, self.ngram_order):
this_order_counts = self.counts[n]
n1 = 0
n2 = 0
for hist, counts_for_hist in this_order_counts.items():
stat = Counter(counts_for_hist.word_to_count.values())
n1 += stat[1]
n2 += stat[2]
assert n1 + 2 * n2 > 0
self.d.append(n1 * 1.0 / (n1 + 2 * n2))
def cal_f(self):
# f(a_z) is a probability distribution of word sequence a_z.
# Typically f(a_z) is discounted to be less than the ML estimate so we have
# some leftover probability for the z words unseen in the context (a_).
#
# f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams
# f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams
# highest order N-grams
n = self.ngram_order - 1
this_order_counts = self.counts[n]
for hist, counts_for_hist in this_order_counts.items():
for w, c in counts_for_hist.word_to_count.items():
counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
# lower order N-grams
for n in range(0, self.ngram_order - 1):
this_order_counts = self.counts[n]
for hist, counts_for_hist in this_order_counts.items():
n_star_star = 0
for w in counts_for_hist.word_to_count.keys():
n_star_star += len(counts_for_hist.word_to_context[w])
if n_star_star != 0:
for w in counts_for_hist.word_to_count.keys():
n_star_z = len(counts_for_hist.word_to_context[w])
counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
else: # patterns begin with <s>, they do not have "modified count", so use raw count instead
for w in counts_for_hist.word_to_count.keys():
n_star_z = counts_for_hist.word_to_count[w]
counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
def cal_bow(self):
# Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
# Thus, two sorts of ngrams do not have a bow:
# 1) highest order ngram
# 2) ngrams ending in </s>
#
# bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z))
# Note that Z1 is the set of all words with c(a_z) > 0
# highest order N-grams
n = self.ngram_order - 1
this_order_counts = self.counts[n]
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
counts_for_hist.word_to_bow[w] = None
# lower order N-grams
for n in range(0, self.ngram_order - 1):
this_order_counts = self.counts[n]
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
if w == self.eos_symbol:
counts_for_hist.word_to_bow[w] = None
else:
a_ = hist + (w,)
assert len(a_) < self.ngram_order
assert a_ in self.counts[len(a_)].keys()
a_counts_for_hist = self.counts[len(a_)][a_]
sum_z1_f_a_z = 0
for u in a_counts_for_hist.word_to_count.keys():
sum_z1_f_a_z += a_counts_for_hist.word_to_f[u]
sum_z1_f_z = 0
_ = a_[1:]
_counts_for_hist = self.counts[len(_)][_]
for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1
sum_z1_f_z += _counts_for_hist.word_to_f[u]
counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
def print_raw_counts(self, info_string):
# these are useful for debug.
print(info_string)
res = []
for this_order_counts in self.counts:
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
ngram = " ".join(hist) + " " + w
ngram = ngram.strip(strip_chars)
res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
res.sort(reverse=True)
for r in res:
print(r)
def print_modified_counts(self, info_string):
# these are useful for debug.
print(info_string)
res = []
for this_order_counts in self.counts:
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
ngram = " ".join(hist) + " " + w
ngram = ngram.strip(strip_chars)
modified_count = len(counts_for_hist.word_to_context[w])
raw_count = counts_for_hist.word_to_count[w]
if modified_count == 0:
res.append("{0}\t{1}".format(ngram, raw_count))
else:
res.append("{0}\t{1}".format(ngram, modified_count))
res.sort(reverse=True)
for r in res:
print(r)
def print_f(self, info_string):
# these are useful for debug.
print(info_string)
res = []
for this_order_counts in self.counts:
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
ngram = " ".join(hist) + " " + w
ngram = ngram.strip(strip_chars)
f = counts_for_hist.word_to_f[w]
if f == 0: # f(<s>) is always 0
f = 1e-99
res.append("{0}\t{1}".format(ngram, math.log(f, 10)))
res.sort(reverse=True)
for r in res:
print(r)
def print_f_and_bow(self, info_string):
# these are useful for debug.
print(info_string)
res = []
for this_order_counts in self.counts:
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
ngram = " ".join(hist) + " " + w
ngram = ngram.strip(strip_chars)
f = counts_for_hist.word_to_f[w]
if f == 0: # f(<s>) is always 0
f = 1e-99
bow = counts_for_hist.word_to_bow[w]
if bow is None:
res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
else:
res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
res.sort(reverse=True)
for r in res:
print(r)
def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
# print as ARPA format.
print('\\data\\', file=fout)
for hist_len in range(self.ngram_order):
# print the number of n-grams.
print('ngram {0}={1}'.format(
hist_len + 1,
sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
file=fout
)
print('', file=fout)
for hist_len in range(self.ngram_order):
print('\\{0}-grams:'.format(hist_len + 1), file=fout)
this_order_counts = self.counts[hist_len]
for hist, counts_for_hist in this_order_counts.items():
for word in counts_for_hist.word_to_count.keys():
ngram = hist + (word,)
prob = counts_for_hist.word_to_f[word]
bow = counts_for_hist.word_to_bow[word]
if prob == 0: # f(<s>) is always 0
prob = 1e-99
line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
if bow is not None:
line += '\t{0}'.format('%.7f' % math.log10(bow))
print(line, file=fout)
print('', file=fout)
print('\\end\\', file=fout)
if __name__ == "__main__":
ngram_counts = NgramCounts(args.ngram_order)
if args.text is None:
ngram_counts.add_raw_counts_from_standard_input()
else:
assert os.path.isfile(args.text)
ngram_counts.add_raw_counts_from_file(args.text)
ngram_counts.cal_discounting_constants()
ngram_counts.cal_f()
ngram_counts.cal_bow()
if args.lm is None:
ngram_counts.print_as_arpa()
else:
with open(args.lm, 'w', encoding=default_encoding) as f:
ngram_counts.print_as_arpa(fout=f)

View File

@ -0,0 +1,30 @@
#!/usr/bin/env python3
import copy
import logging
from pathlib import Path
import k2
import torch
from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
def test_bpe_mmi_graph_compiler():
lang_dir = Path("data/lang_bpe")
if lang_dir.is_dir() is False:
return
device = torch.device("cpu")
compiler = BpeMmiTrainingGraphCompiler(lang_dir, device=device)
texts = ["HELLO WORLD", "MMI TRAINING"]
num_graphs, den_graphs = compiler.compile(texts)
num_graphs.labels_sym = compiler.lexicon.token_table
num_graphs.aux_labels_sym = copy.deepcopy(compiler.lexicon.token_table)
num_graphs.aux_labels_sym._id2sym[0] = "<eps>"
num_graphs[0].draw("num_graphs_0.svg", title="HELLO WORLD")
num_graphs[1].draw("num_graphs_1.svg", title="HELLO WORLD")
print(den_graphs.shape)
print(den_graphs[0].shape)
print(den_graphs[0].num_arcs)