Merge d4440b421c3700982c7f9e20bea37e48fb144b7e into fee1f84b20a5a704428c5eac80de2ac4033e1b27

Fangjun Kuang 2021-10-15 20:39:33 +08:00 committed by GitHub
commit 8ae39529c6
32 changed files with 5722 additions and 375 deletions

View File

@ -4,8 +4,10 @@ statistics=true
 max-line-length = 80
 per-file-ignores =
     # line too long
-    egs/librispeech/ASR/conformer_ctc/conformer.py: E501,
+    egs/librispeech/ASR/*/conformer.py: E501,
 exclude =
   .git,
-  **/data/**
+  **/data/**,
+  icefall/shared/make_kn_lm.py,
+  egs/librispeech/ASR/conformer_mmi_phone/embedding.py

View File

@ -90,7 +90,7 @@ def get_parser():
     "--start-epoch",
     type=int,
     default=0,
-    help="""Resume training from from this epoch.
+    help="""Resume training from this epoch.
     If it is positive, it will load checkpoint from
     conformer_ctc/exp/epoch-{start_epoch-1}.pt
     """,

View File

@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py

View File

@ -0,0 +1,933 @@
#!/usr/bin/env python3
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
class Conformer(Transformer):
"""
Args:
num_features (int): Number of input features
num_classes (int): Number of output classes
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of attention heads
dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
num_decoder_layers: int = 6,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
is_espnet_structure: bool = False,
is_bpe: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
num_classes=num_classes,
subsampling_factor=subsampling_factor,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
is_bpe=is_bpe,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
cnn_module_kernel,
normalize_before,
is_espnet_structure,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure
if self.normalize_before and self.is_espnet_structure:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
# and throws an error without this change.
self.after_norm = identity
def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x:
The model input. Its shape is [N, T, C].
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
mask = encoder_padding_mask(x.size(0), supervisions)
if mask is not None:
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
if self.normalize_before and self.is_espnet_structure:
x = self.after_norm(x)
return x, mask
class ConformerEncoderLayer(nn.Module):
"""
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
normalize_before: whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = encoder_layer(src, pos_emb)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
is_espnet_structure: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure
)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.norm_ff_macaron = nn.LayerNorm(
d_model
) # for the macaron style FFN module
self.norm_ff = nn.LayerNorm(d_model) # for the FFN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(
d_model
) # for the final output of the block
self.dropout = nn.Dropout(dropout)
self.normalize_before = normalize_before
def forward(
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src
class ConformerEncoder(nn.TransformerEncoder):
r"""ConformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = conformer_encoder(src, pos_emb)
"""
def __init__(
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
) -> None:
super(ConformerEncoder, self).__init__(
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
)
def forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
output = src
for mod in self.layers:
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
)
if self.norm is not None:
output = self.norm(output)
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` denotes the position of the query vector and `j` the
# position of the key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
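The positional encoding above stores both positive and negative relative positions, so the pos_emb returned by forward() has length 2*time - 1. A minimal shape check, as a sketch (the `conformer` import path is a hypothetical assumption, not part of this commit):

import torch
from conformer import RelPositionalEncoding  # hypothetical import path

pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.rand(8, 100, 256)  # (batch, time, d_model)
y, pos_emb = pos_enc(x)
assert y.shape == (8, 100, 256)  # input scaled by sqrt(d_model)
assert pos_emb.shape == (1, 2 * 100 - 1, 256)  # positive + negative positions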
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
Examples::
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_espnet_structure: bool = False,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
self.is_espnet_structure = is_espnet_structure
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero
positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
return self.multi_head_attention_forward(
query,
key,
value,
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.weight,
self.in_proj.bias,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
)
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
length, N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if not self.is_espnet_structure:
q = q * scaling
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError(
"The size of the 2D attn_mask is not correct."
)
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError(
"The size of the 3D attn_mask is not correct."
)
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(
attn_mask.dim()
)
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if (
key_padding_mask is not None
and key_padding_mask.dtype == torch.uint8
):
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
key_padding_mask.size(1), src_len
)
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(
1, 2
) # (batch, head, time1, d_k)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
if not self.is_espnet_structure:
attn_output_weights = (
matrix_ac + matrix_bd
) # (batch, head, time1, time2)
else:
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
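rel_shift() above converts scores over relative positions into a (time1, time2) attention score matrix via as_strided, avoiding an explicit gather. A small shape sketch (hypothetical usage, assuming the class is importable from this file):

import torch
from conformer import RelPositionMultiheadAttention  # hypothetical import path

attn = RelPositionMultiheadAttention(embed_dim=256, num_heads=4)
# One score per (query position, relative position) pair:
matrix_bd = torch.rand(8, 4, 100, 2 * 100 - 1)  # (batch, head, time1, 2*time1-1)
shifted = attn.rel_shift(matrix_bd)
assert shifted.shape == (8, 4, 100, 100)  # (batch, head, time1, time2)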
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
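In the GLU step above, pointwise_conv1 doubles the channel dimension and nn.functional.glu halves it again by gating the first half with the sigmoid of the second. A standalone sketch of just that mechanism:

import torch

x = torch.rand(4, 2 * 256, 100)  # (batch, 2*channels, time)
a, b = x.chunk(2, dim=1)
y = torch.nn.functional.glu(x, dim=1)  # (batch, channels, time)
assert torch.allclose(y, a * torch.sigmoid(b))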
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
def identity(x):
return x
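Putting the pieces together, a single ConformerEncoderLayer can be exercised on random data, mirroring run_encoder()'s layout handling (a sketch; the `conformer` import path is an assumption):

import torch
from conformer import ConformerEncoderLayer, RelPositionalEncoding  # hypothetical

layer = ConformerEncoderLayer(d_model=256, nhead=4)
pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.0)

x = torch.rand(8, 100, 256)  # (batch, time, d_model)
x, pos_emb = pos_enc(x)
x = x.permute(1, 0, 2)  # (time, batch, d_model), as in run_encoder
out = layer(x, pos_emb)
assert out.shape == (100, 8, 256)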

View File

@ -0,0 +1,537 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
# (still a work in progress)
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=9,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
)
parser.add_argument(
"--num-paths",
type=int,
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=0.5,
help="The scale to be applied to `lattice.scores`."
"It's needed if you use any kinds of n-best based rescoring. "
"Currently, it is used when the decoding method is: nbest, "
"nbest-rescoring, attention-decoder, and nbest-oracle. "
"A smaller value results in more unique paths.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_mmi_phone/exp"),
"lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
# Possible values for method:
# - 1best
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
# - attention-decoder
# - nbest-oracle
# "method": "nbest",
# "method": "nbest-rescoring",
# "method": "whole-lattice-rescoring",
# "method": "attention-decoder",
# "method": "nbest-oracle",
# num_paths is used when method is "nbest", "nbest-rescoring",
# attention-decoder, and nbest-oracle
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- If params.method is "1best", it uses 1best decoding without LM rescoring.
- If params.method is "nbest", it uses nbest decoding without LM rescoring.
- If params.method is "nbest-rescoring", it uses nbest LM rescoring.
- If params.method is "whole-lattice-rescoring", it uses whole-lattice LM
rescoring.
model:
The neural model.
HLG:
The decoding graph.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains the word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = HLG.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is [N, T, C]
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is slightly worse than that of rescored lattices.
return nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
lexicon=lexicon,
scale=params.lattice_score_scale,
)
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
scale=params.lattice_score_scale,
)
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
scale=params.lattice_score_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
scale=params.lattice_score_scale,
sos_id=params.sos_id,
eos_id=params.eos_id,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph.
lexicon:
It contains the word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no_rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
G=G,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
is_bpe=False,
use_feat_batchnorm=params.use_feat_batchnorm,
)
assert model.decoder_num_class == num_classes + 1
params.sos_id = num_classes
params.eos_id = num_classes
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
lexicon=lexicon,
G=G,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
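The --lattice-score-scale help text says a smaller value yields more unique paths: n-best paths are sampled with probabilities derived from lattice.scores, and scaling the scores down flattens that distribution. A toy illustration of the effect:

import torch

scores = torch.tensor([4.0, 3.0, 1.0])  # toy path scores
for scale in (1.0, 0.5):
    print(scale, torch.softmax(scores * scale, dim=0))
# scale=1.0 -> approximately [0.705, 0.259, 0.035]
# scale=0.5 -> approximately [0.547, 0.332, 0.122]  (flatter, so sampling
# paths from the lattice tends to produce more unique paths)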

View File

@ -0,0 +1,221 @@
# This file is copied & modified from pytorch/torch/nn/modules/sparse.py
# It modifies nn.Embedding
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
class Embedding(nn.Module):
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices.
The input to the module is a list of indices, and the output is the corresponding
word embeddings.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
i.e. it remains as a fixed "pad". For a newly constructed Embedding,
the embedding vector at :attr:`padding_idx` will default to all zeros,
but can be updated to another value to be used as the padding vector.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
See Notes for more details regarding sparse gradients.
Attributes:
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
initialized from :math:`\mathcal{N}(0, 1)`
Shape:
- Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
.. note::
Keep in mind that only a limited number of optimizers support
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
.. note::
When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
:attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
modified in-place, performing a differentiable operation on ``Embedding.weight`` before
calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
:attr:`max_norm` is not ``None``. For example::
n, d, m = 3, 5, 7
embedding = nn.Embedding(n, d, max_norm=True)
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])
a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable
b = embedding(idx) @ W.t() # modifies weight in-place
out = (a.unsqueeze(0) + b.unsqueeze(1))
loss = out.sigmoid().prod()
loss.backward()
Examples::
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]],
[[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]]])
>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000, 0.0000, 0.0000],
[ 0.1535, -2.0309, 0.9315],
[ 0.0000, 0.0000, 0.0000],
[-0.1655, 0.9897, 0.0635]]])
>>> # example of changing `pad` vector
>>> padding_idx = 0
>>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
>>> embedding.weight
Parameter containing:
tensor([[ 0.0000, 0.0000, 0.0000],
[-0.7895, -0.7089, -0.0364],
[ 0.6778, 0.5803, 0.2678]], requires_grad=True)
>>> with torch.no_grad():
... embedding.weight[padding_idx] = torch.ones(3)
>>> embedding.weight
Parameter containing:
tensor([[ 1.0000, 1.0000, 1.0000],
[-0.7895, -0.7089, -0.0364],
[ 0.6778, 0.5803, 0.2678]], requires_grad=True)
"""
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
'norm_type', 'scale_grad_by_freq', 'sparse']
num_embeddings: int
embedding_dim: int
padding_idx: Optional[int]
max_norm: Optional[float]
norm_type: float
scale_grad_by_freq: bool
weight: Tensor
sparse: bool
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
super(Embedding, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
elif padding_idx < 0:
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.embedding_scale = math.sqrt(self.embedding_dim)
if _weight is None:
self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
self.reset_parameters()
else:
assert list(_weight.shape) == [num_embeddings, embedding_dim], \
'Shape of weight does not match num_embeddings and embedding_dim'
self.weight = Parameter(_weight)
self.sparse = sparse
def reset_parameters(self) -> None:
std = 1 / self.embedding_scale
nn.init.normal_(self.weight, std=std)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input: Tensor) -> Tensor:
return F.embedding(
input, self.weight, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse) * self.embedding_scale
def extra_repr(self) -> str:
s = '{num_embeddings}, {embedding_dim}'
if self.padding_idx is not None:
s += ', padding_idx={padding_idx}'
if self.max_norm is not None:
s += ', max_norm={max_norm}'
if self.norm_type != 2:
s += ', norm_type={norm_type}'
if self.scale_grad_by_freq is not False:
s += ', scale_grad_by_freq={scale_grad_by_freq}'
if self.sparse is not False:
s += ', sparse=True'
return s.format(**self.__dict__)
@classmethod
def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
max_norm=None, norm_type=2., scale_grad_by_freq=False,
sparse=False):
r"""Creates Embedding instance from given 2-dimensional FloatTensor.
Args:
embeddings (Tensor): FloatTensor containing weights for the Embedding.
First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
i.e. it remains as a fixed "pad".
max_norm (float, optional): See module initialization documentation.
norm_type (float, optional): See module initialization documentation. Default ``2``.
scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
sparse (bool, optional): See module initialization documentation.
Examples::
>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([1])
>>> embedding(input)
tensor([[ 4.0000, 5.1000, 6.3000]])
"""
assert embeddings.dim() == 2, \
'Embeddings parameter is expected to be 2-dimensional'
rows, cols = embeddings.shape
embedding = cls(
num_embeddings=rows,
embedding_dim=cols,
_weight=embeddings,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
embedding.weight.requires_grad = not freeze
return embedding
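The main departures from torch.nn.Embedding above are the sqrt(embedding_dim) scaling of the lookup output and the matching 1/sqrt(embedding_dim) init std. A brief check, assuming this file is importable as `embedding` (an assumption for illustration):

import torch
from embedding import Embedding  # hypothetical import path

emb = Embedding(num_embeddings=10, embedding_dim=64)
idx = torch.tensor([1, 2, 5])
out = emb(idx)
# The lookup result is multiplied by embedding_scale == sqrt(64) == 8,
# compensating for weights initialized with std = 1/8.
assert torch.allclose(out, emb.weight[idx] * 8.0)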

View File

@ -0,0 +1,350 @@
#!/usr/bin/env python3
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method.
Possible values are:
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
(3) attention-decoder - Extract n paths from the rescored
lattice and use the transformer attention decoder for
rescoring.
We call it HLG decoding + n-gram LM rescoring + attention
decoder rescoring.
""",
)
parser.add_argument(
"--G",
type=str,
help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring or attention-decoder.
It's usually a 4-gram LM.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""
Used only when method is attention-decoder.
It specifies the size of n-best list.""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=1.3,
help="""
Used only when method is whole-lattice-rescoring and attention-decoder.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--attention-decoder-scale",
type=float,
default=1.2,
help="""
Used only when method is attention-decoder.
It specifies the scale for attention decoder scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=0.5,
help="""
Used only when method is attention-decoder.
It specifies the scale for lattice.scores when
extracting n-best lists. A smaller value results in
more unique number of paths with the risk of missing
the best path.
""",
)
parser.add_argument(
"--sos-id",
type=int,
default=1,
help="""
Used only when method is attention-decoder.
It specifies ID for the SOS token.
""",
)
parser.add_argument(
"--eos-id",
type=int,
default=1,
help="""
Used only when method is attention-decoder.
It specifies ID for the EOS token.
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"nhead": 8,
"num_classes": 5000,
"sample_rate": 16000,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
G = G.to(device)
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G.lm_scores = G.scores.clone()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
nnet_output, memory, memory_key_padding_mask = model(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
elif params.method == "whole-lattice-rescoring":
logging.info("Use HLG decoding + LM rescoring")
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=[params.ngram_lm_scale],
)
best_path = next(iter(best_path_dict.values()))
elif params.method == "attention-decoder":
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
scale=params.lattice_score_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
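For reference, here is a minimal standalone sketch of the feature-extraction path used by this script, assuming kaldifeat is installed and "test.wav" is a hypothetical 16 kHz recording:

import math

import kaldifeat
import torchaudio
from torch.nn.utils.rnn import pad_sequence

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

wave, sample_rate = torchaudio.load("test.wav")  # hypothetical input file
assert sample_rate == 16000
features = fbank([wave[0]])  # a list of 2-D [T, 80] tensors, one per utterance
# Pad to a common length; log(1e-10) matches the padding value used above.
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(features.shape)  # [1, T, 80]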

View File

@ -0,0 +1,144 @@
import torch
import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, idim: int, odim: int) -> None:
"""
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
# On entry, x is [N, T, idim]
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
x = self.conv(x)
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape [N, ((T-1)//2 - 1)//2, odim]
return x
class VggSubsampling(nn.Module):
"""Trying to follow the setup described in the following paper:
https://arxiv.org/pdf/1910.09799.pdf
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
def __init__(self, idim: int, odim: int) -> None:
"""Construct a VggSubsampling object.
This uses 2 VGG blocks with 2 Conv2d layers each,
subsampling its input by a factor of 4 in the time dimension.
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
super().__init__()
cur_channels = 1
layers = []
block_dims = [32, 64]
# The decision to use padding=1 for the 1st convolution, then padding=0
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
# a back-compatibility concern so that the number of frames at the
# output would be equal to:
# (((T-1)//2)-1)//2.
# We can consider changing this by using padding=1 on the
# 2nd convolution, so the num-frames at the output would be T//4.
for block_dim in block_dims:
layers.append(
torch.nn.Conv2d(
in_channels=cur_channels,
out_channels=block_dim,
kernel_size=3,
padding=1,
stride=1,
)
)
layers.append(torch.nn.ReLU())
layers.append(
torch.nn.Conv2d(
in_channels=block_dim,
out_channels=block_dim,
kernel_size=3,
padding=0,
stride=1,
)
)
layers.append(
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
x = x.unsqueeze(1)
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x
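For intuition: each stride-2 convolution with kernel size 3 and no padding maps T frames to (T - 1) // 2, so applying it twice gives ((T - 1) // 2 - 1) // 2. A worked example with T = 100:

(100 - 1) // 2 = 49  # after the first stride-2 stage
(49 - 1) // 2 = 24   # after the second stage; plain T // 4 would give 25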

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python3
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch
def test_conv2d_subsampling():
N = 3
odim = 2
for T in range(7, 19):
for idim in range(7, 20):
model = Conv2dSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim
def test_vgg_subsampling():
N = 3
odim = 2
for T in range(7, 19):
for idim in range(7, 20):
model = VggSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim

View File

@ -0,0 +1,89 @@
#!/usr/bin/env python3
import torch
from transformer import (
Transformer,
encoder_padding_mask,
generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
)
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask():
supervisions = {
"sequence_idx": torch.tensor([0, 1, 2]),
"start_frame": torch.tensor([0, 0, 0]),
"num_frames": torch.tensor([18, 7, 13]),
}
max_len = ((18 - 1) // 2 - 1) // 2
mask = encoder_padding_mask(max_len, supervisions)
expected_mask = torch.tensor(
[
[False, False, False], # ((18 - 1)//2 - 1)//2 = 3,
[False, True, True], # ((7 - 1)//2 - 1)//2 = 1,
[False, False, True], # ((13 - 1)//2 - 1)//2 = 2,
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_transformer():
num_features = 40
num_classes = 87
model = Transformer(num_features=num_features, num_classes=num_classes)
N = 31
for T in range(7, 30):
x = torch.rand(N, T, num_features)
y, _, _ = model(x)
assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
def test_generate_square_subsequent_mask():
s = 5
mask = generate_square_subsequent_mask(s)
inf = float("inf")
expected_mask = torch.tensor(
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0],
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_decoder_padding_mask():
x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
y = pad_sequence(x, batch_first=True, padding_value=-1)
mask = decoder_padding_mask(y, ignore_id=-1)
expected_mask = torch.tensor(
[
[False, False, True],
[False, True, True],
[False, False, False],
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_add_sos():
x = [[1, 2], [3], [2, 5, 8]]
y = add_sos(x, sos_id=0)
expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
assert y == expected_y
def test_add_eos():
x = [[1, 2], [3], [2, 5, 8]]
y = add_eos(x, eos_id=0)
expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
assert y == expected_y

View File

@ -0,0 +1,734 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional
import k2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.mmi import LFMMILoss
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
encode_supervisions,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=50,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from this epoch.
If it is positive, it will load checkpoint from
conformer_ctc/exp/epoch-{start_epoch-1}.pt
""",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval:  Print training loss if `batch_idx % log_interval` is 0
- valid_interval:  Run validation if `batch_idx % valid_interval` is 0
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_mmi_phone/exp"),
"lang_dir": Path("data/lang_phone"),
"feature_dim": 80,
"weight_decay": 1e-6,
"subsampling_factor": 4,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000,
"beam_size": 10,
"use_pruned_intersect": False,
"den_scale": 1.0,
#
"att_rate": 0.7, # If not zero, use attention decoder
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
"is_espnet_structure": True,
"use_feat_batchnorm": True,
"lr_factor": 5.0,
"warm_step": 80000,
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> Optional[dict]:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return the saved checkpoint dict if a checkpoint is loaded; otherwise return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: MmiTrainingGraphCompiler,
is_training: bool,
):
"""
Compute MMI loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build num_graphs and den_graphs.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is [N, T, C]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in LFMMILoss
#
# TODO: If params.use_pruned_intersect is True, there is no
# need to call encode_supervisions
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
loss_fn = LFMMILoss(
graph_compiler=graph_compiler,
output_beam=params.beam_size,
den_scale=params.den_scale,
use_pruned_intersect=params.use_pruned_intersect,
)
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
if params.att_rate != 0.0:
token_ids = graph_compiler.texts_to_ids(texts)
with torch.set_grad_enabled(is_training):
if hasattr(model, "module"):
att_loss = model.module.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=params.sos_id,
eos_id=params.eos_id,
)
else:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=params.sos_id,
eos_id=params.eos_id,
)
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
else:
loss = mmi_loss
att_loss = torch.tensor([0])
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
assert loss.requires_grad == is_training
return loss, mmi_loss.detach(), att_loss.detach()
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: MmiTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_mmi_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl):
loss, mmi_loss, att_loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
assert mmi_loss.requires_grad is False
assert att_loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_mmi_loss += mmi_loss.detach().cpu().item()
tot_att_loss += att_loss.detach().cpu().item()
tot_frames += params.valid_frames
if world_size > 1:
s = torch.tensor(
[tot_loss, tot_mmi_loss, tot_att_loss, tot_frames],
device=loss.device,
)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_mmi_loss = s[1]
tot_att_loss = s[2]
tot_frames = s[3]
params.valid_loss = tot_loss / tot_frames
params.valid_mmi_loss = tot_mmi_loss / tot_frames
params.valid_att_loss = tot_att_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: MmiTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_mmi_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, mmi_loss, att_loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
mmi_loss_cpu = mmi_loss.detach().cpu().item()
att_loss_cpu = att_loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_mmi_loss += mmi_loss_cpu
tot_att_loss += att_loss_cpu
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
tot_avg_mmi_loss = tot_mmi_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames
if batch_idx % params.log_interval == 0:
total_duration = 0.0
max_cut_duration = 0.0
for cut in batch["supervisions"]["cut"]:
total_duration += cut.duration
max_cut_duration = max(max_cut_duration, cut.duration)
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, "
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg mmi loss: {tot_avg_mmi_loss:.4f}, "
f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}, "
f"total duration: {total_duration:.3f} s, "
f"max cut duration: {max_cut_duration:.3f} s"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_mmi_loss",
mmi_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_att_loss",
att_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_mmi_loss",
tot_avg_mmi_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_att_loss",
tot_avg_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0.0 # sum of losses over all batches
tot_mmi_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, "
f"valid mmi loss {params.valid_mmi_loss:.4f}, "
f"valid att loss {params.valid_att_loss:.4f}, "
f"valid loss {params.valid_loss:.4f}, "
f"best valid loss: {params.best_valid_loss:.4f}, "
f"best valid epoch: {params.best_valid_epoch}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_mmi_loss",
params.valid_mmi_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_att_loss",
params.valid_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
)
params.train_loss = params.tot_loss / params.tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
graph_compiler = MmiTrainingGraphCompiler(
params.lang_dir,
device=device,
oov="<UNK>",
)
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
is_bpe=False,
is_espnet_structure=params.is_espnet_structure,
use_feat_batchnorm=params.use_feat_batchnorm,
)
assert model.decoder_num_class == num_classes + 1
params.sos_id = num_classes
params.eos_id = num_classes
checkpoints = load_checkpoint_if_available(params=params, model=model)
logging.info(params)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and checkpoints["optimizer"]:
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
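The Noam optimizer above increases the learning rate for roughly warm_step steps and then decays it. A self-contained sketch of the schedule, assuming the standard Noam formula (the recipe's transformer.Noam may differ in details):

def noam_lr(step: int, model_size: int = 512, factor: float = 5.0, warm_step: int = 80000) -> float:
    # lr = factor * model_size**-0.5 * min(step**-0.5, step * warm_step**-1.5)
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)

for step in [1000, 80000, 320000]:
    print(step, noam_lr(step))  # rises, peaks near warm_step, then decays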

File diff suppressed because it is too large

egs/librispeech/ASR/local/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
tmp_lang

View File

@ -0,0 +1,110 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
'''
Add silence with a given probability after each word in the transcript.
If the input transcript contains:
hello world
foo bar koo
zoo
Then the output transcript **may** look like the following:
!SIL hello !SIL world !SIL
foo bar !SIL koo !SIL
!SIL zoo !SIL
(Assume !SIL represents silence.)
'''
from pathlib import Path
import argparse
import random
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--transcript',
type=str,
help='The input transcript file.'
'We assume that the transcript file consists of '
'lines. Each line consists of space separated words.')
parser.add_argument('--sil-word',
type=str,
default='!SIL',
help='The word that represents silence.')
parser.add_argument('--sil-prob',
type=float,
default=0.5,
help='The probability for adding a '
'silence after each word.')
parser.add_argument('--seed',
type=int,
default=None,
help='The seed for random number generators.')
return parser.parse_args()
def need_silence(sil_prob: float) -> bool:
'''
Args:
sil_prob:
The probability to add a silence.
Returns:
Return True if a silence is needed.
Return False otherwise.
'''
return random.uniform(0, 1) <= sil_prob
def process_line(line: str, sil_word: str, sil_prob: float) -> None:
'''Process a single line from the transcript.
Args:
line:
A str containing space separated words.
sil_word:
The symbol indicating silence.
sil_prob:
The probability for adding a silence after each word.
Returns:
Return None.
'''
words = line.strip().split()
for i, word in enumerate(words):
if i == 0:
# beginning of the line
if need_silence(sil_prob):
print(sil_word, end=' ')
print(word, end=' ')
if need_silence(sil_prob):
print(sil_word, end=' ')
# end of the line, print a new line
if i == len(words) - 1:
print()
def main():
args = get_args()
random.seed(args.seed)
assert Path(args.transcript).is_file()
assert len(args.sil_word) > 0
assert 0 < args.sil_prob < 1
with open(args.transcript) as f:
for line in f:
process_line(line=line,
sil_word=args.sil_word,
sil_prob=args.sil_prob)
if __name__ == '__main__':
main()
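Inlined, the per-line logic above amounts to the following self-contained sketch (the seed is arbitrary; with sil_prob = 0.5 about half the slots receive !SIL):

import random

random.seed(0)
sil_word, sil_prob = "!SIL", 0.5
out = []
if random.uniform(0, 1) <= sil_prob:  # maybe a silence at the beginning
    out.append(sil_word)
for word in "hello world".split():
    out.append(word)
    if random.uniform(0, 1) <= sil_prob:  # maybe a silence after each word
        out.append(sil_word)
print(" ".join(out))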

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
"""
Convert a transcript file containing words to a corpus file containing tokens
for LM training with the help of a lexicon.
If the lexicon contains phones, the resulting LM will be a phone LM; if the
lexicon contains word pieces, the resulting LM will be a word piece LM.
If a word has multiple pronunciations, the one that appears first in the lexicon
is kept; others are removed.
If the input transcript is:
hello zoo world hello
world zoo
foo zoo world hellO
and if the lexicon is
<UNK> SPN
hello h e l l o 2
hello h e l l o
world w o r l d
zoo z o o
Then the output is
h e l l o 2 z o o w o r l d h e l l o 2
w o r l d z o o
SPN z o o w o r l d SPN
"""
import argparse
from pathlib import Path
from typing import Dict, List
from generate_unique_lexicon import filter_multiple_pronunications
from icefall.lexicon import read_lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transcript",
type=str,
help="The input transcript file."
"We assume that the transcript file consists of "
"lines. Each line consists of space separated words.",
)
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
parser.add_argument(
"--oov", type=str, default="<UNK>", help="The OOV word."
)
return parser.parse_args()
def process_line(
lexicon: Dict[str, List[str]], line: str, oov_token: str
) -> None:
"""
Args:
lexicon:
A dict containing pronunciations. Its keys are words and values
are pronunciations (i.e., tokens).
line:
A line of transcript consisting of space(s) separated words.
oov_token:
The pronunciation of the oov word if a word in `line` is not present
in the lexicon.
Returns:
Return None.
"""
s = ""
words = line.strip().split()
for i, w in enumerate(words):
tokens = lexicon.get(w, oov_token)
s += " ".join(tokens)
s += " "
print(s.strip())
def main():
args = get_args()
assert Path(args.lexicon).is_file()
assert Path(args.transcript).is_file()
assert len(args.oov) > 0
# Only the first pronunciation of a word is kept
lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
lexicon = dict(lexicon)
assert args.oov in lexicon
oov_token = lexicon[args.oov]
with open(args.transcript) as f:
for line in f:
process_line(lexicon=lexicon, line=line, oov_token=oov_token)
if __name__ == "__main__":
main()
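A minimal in-memory sketch of what process_line does, using a toy lexicon instead of the files the script expects:

lexicon = {
    "hello": "h e l l o 2".split(),
    "world": "w o r l d".split(),
    "zoo": "z o o".split(),
}
oov_token = ["SPN"]  # the pronunciation of <UNK>
line = "hello zoo world foo"
s = ""
for w in line.strip().split():
    s += " ".join(lexicon.get(w, oov_token)) + " "
print(s.strip())  # h e l l o 2 z o o w o r l d SPN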

View File

@ -25,7 +25,7 @@ This file downloads the following LibriSpeech LM files:
- librispeech-lexicon.txt
from http://www.openslr.org/resources/11
and save them in the user provided directory.
and saves them in the user provided directory.
Files are not re-downloaded if they already exist.

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file takes as input a lexicon.txt and outputs a new lexicon,
in which each word has a unique pronunciation.
The way to do this is to keep only the first pronunciation of a word
in lexicon.txt.
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
from icefall.lexicon import read_lexicon, write_lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain a file lexicon.txt.
This script will generate a new file uniq_lexicon.txt
in it.
""",
)
return parser.parse_args()
def filter_multiple_pronunications(
lexicon: List[Tuple[str, List[str]]]
) -> List[Tuple[str, List[str]]]:
"""Remove multiple pronunciations of words from a lexicon.
If a word has more than one pronunciation in the lexicon, only
the first one is kept, while other pronunciations are removed
from the lexicon.
Args:
lexicon:
The input lexicon, containing a list of (word, [p1, p2, ..., pn]),
where "p1, p2, ..., pn" are the pronunciations of the "word".
Returns:
Return a new lexicon where each word has a unique pronunciation.
"""
seen = set()
ans = []
for word, tokens in lexicon:
if word in seen:
continue
seen.add(word)
ans.append((word, tokens))
return ans
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
in_lexicon = read_lexicon(lexicon_filename)
out_lexicon = filter_multiple_pronunications(in_lexicon)
write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon)
logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}")
logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
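For example, filter_multiple_pronunications keeps only the first entry for a duplicated word; a quick check of the shown logic:

lexicon = [
    ("fo", ["f", "o"]),
    ("fo", ["f", "o", "o"]),  # dropped: "fo" was already seen
    ("food", ["f", "o", "o", "d"]),
]
seen, ans = set(), []
for word, tokens in lexicon:
    if word not in seen:
        seen.add(word)
        ans.append((word, tokens))
print(ans)  # [('fo', ['f', 'o']), ('food', ['f', 'o', 'o', 'd'])]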

View File

@ -17,8 +17,9 @@
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
This script takes as input a `lang_dir`, which is expected to contain
a lexicon file "lexicon.txt" consisting of words and tokens (i.e., phones)
and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
@ -32,12 +33,16 @@ consisting of words and tokens (i.e., phones) and does the following:
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
The generated files are saved into `lang_dir`.
"""
import math
import argparse
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
from icefall.utils import str2bool
import k2
import torch
@ -46,6 +51,34 @@ from icefall.lexicon import read_lexicon, write_lexicon
Lexicon = List[Tuple[str, List[str]]]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain a file lexicon.txt.
Generated files by this script are saved into this directory.
""",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
See "local/test_prepare_lang.sh" for usage.
""",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
@ -69,6 +102,10 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
If pronunciations are phones, then tokens are phones.
If pronunciations are word pieces, then tokens are word pieces.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
@ -192,6 +229,9 @@ def add_self_loops(
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Caution:
Don't be confused with :func:`k2.add_epsilon_self_loops`.
Args:
arcs:
A list-of-list. The sublist contains
@ -221,12 +261,9 @@ def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
@ -235,11 +272,6 @@ def lexicon_to_fst(
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
@ -247,50 +279,43 @@ def lexicon_to_fst(
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
assert token2id["<eps>"] == 0
if "<blk>" in token2id:
# For BPE based lexicon
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
else:
# For phone based lexicon in the CTC topo,
# 0 on the left side (i.e., as label) indicates a blank.
# 0 on the right side (i.e., as aux_label) represents an epsilon
assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [token2id[i] for i in tokens]
pieces = [token2id[i] for i in pieces]
for i in range(len(tokens) - 1):
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
@ -315,10 +340,10 @@ def lexicon_to_fst(
def main():
out_dir = Path("data/lang_phone")
lexicon_filename = out_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
args = get_args()
lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
lexicon = read_lexicon(lexicon_filename)
tokens = get_tokens(lexicon)
@ -344,37 +369,36 @@ def main():
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / "tokens.txt", token2id)
write_mapping(out_dir / "words.txt", word2id)
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
write_mapping(lang_dir / "tokens.txt", token2id)
write_mapping(lang_dir / "words.txt", word2id)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / "L.pt")
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__":
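To make the arc format concrete, here is a hedged sketch of the FST that the new lexicon_to_fst would build (without self-loops) for a one-word lexicon {fo: f o}, using toy IDs {<eps>: 0, f: 1, o: 2} for tokens and {<eps>: 0, fo: 1} for words:

import k2

# Each arc line is: src dst token_id word_id score; the last line is the
# final state. State 0 is the loop state; the -1 arc enters the final state.
arcs = """
0 1 1 1 0
0 2 -1 -1 0
1 0 2 0 0
2
""".strip()
L = k2.Fsa.from_str(arcs, acceptor=False)
print(L)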

View File

@ -44,86 +44,12 @@ import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
lexicon_to_fst,
write_lexicon,
write_mapping,
)
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def generate_lexicon(
model_file: str, words: List[str]
) -> Tuple[Lexicon, Dict[str, int]]:
@ -206,13 +132,13 @@ def main():
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
L = lexicon_to_fst(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,

View File

@ -1,106 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import os
import tempfile
import k2
from prepare_lang import (
add_disambig_symbols,
generate_id_map,
get_phones,
get_words,
lexicon_to_fst,
read_lexicon,
write_lexicon,
write_mapping,
)
def generate_lexicon_file() -> str:
fd, filename = tempfile.mkstemp()
os.close(fd)
s = """
!SIL SIL
<SPOKEN_NOISE> SPN
<UNK> SPN
f f
a a
foo f o o
bar b a r
bark b a r k
food f o o d
food2 f o o d
fo f o
""".strip()
with open(filename, "w") as f:
f.write(s)
return filename
def test_read_lexicon(filename: str):
lexicon = read_lexicon(filename)
phones = get_phones(lexicon)
words = get_words(lexicon)
print(lexicon)
print(phones)
print(words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
print(lexicon_disambig)
print("max disambig:", f"#{max_disambig}")
phones = ["<eps>", "SIL", "SPN"] + phones
for i in range(max_disambig + 1):
phones.append(f"#{i}")
words = ["<eps>"] + words
phone2id = generate_id_map(phones)
word2id = generate_id_map(words)
print(phone2id)
print(word2id)
write_mapping("phones.txt", phone2id)
write_mapping("words.txt", word2id)
write_lexicon("a.txt", lexicon)
write_lexicon("a_disambig.txt", lexicon_disambig)
fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id)
fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
def main():
filename = generate_lexicon_file()
test_read_lexicon(filename)
os.remove(filename)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,37 @@
#!/usr/bin/env bash
lang_dir=tmp_lang
mkdir -p $lang_dir
cat <<EOF > $lang_dir/lexicon.txt
<UNK> SPN
f f
a a
foo f o o
bar b a r
bark b a r k
food f o o d
food2 f o o d
fo f o
fo f o o
EOF
./prepare_lang.py --lang-dir $lang_dir --debug 1
./generate_unique_lexicon.py --lang-dir $lang_dir
cat <<EOF > $lang_dir/transcript_words.txt
foo bar bark food food2 fo f a foo bar
bar food2 fo bark
EOF
./convert_transcript_words_to_tokens.py \
--lexicon $lang_dir/uniq_lexicon.txt \
--transcript $lang_dir/transcript_words.txt \
--oov "<UNK>" \
> $lang_dir/transcript_tokens.txt
../shared/make_kn_lm.py \
-ngram-order 2 \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/P.arpa
echo "Please delete the directory '$lang_dir' manually"

View File

@ -38,7 +38,7 @@ def get_args():
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the training corpus: train.txt.
It should contain the training corpus: transcript_words.txt.
The generated bpe.model is saved to this directory.
""",
)
@ -59,7 +59,7 @@ def main():
model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
train_text = f"{lang_dir}/train.txt"
train_text = f"{lang_dir}/transcript_words.txt"
character_coverage = 1.0
input_sentence_size = 100000000

View File

@ -114,14 +114,53 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
mkdir -p data/lang_phone
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
echo '<UNK> SPN' |
cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > data/lang_phone/lexicon.txt
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f data/lang_phone/L_disambig.pt ]; then
./local/prepare_lang.py
./local/generate_unique_lexicon.py --lang-dir $lang_dir
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
# Train a bigram P for MMI training
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data to train phone based bigram P"
files=$(
find -L "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
find -L "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
find -L "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $lang_dir/transcript_words.txt
fi
if [ ! -f $lang_dir/transcript_tokens.txt ]; then
./local/convert_transcript_words_to_tokens.py \
--lexicon $lang_dir/uniq_lexicon.txt \
--transcript $lang_dir/transcript_words.txt \
--oov "<UNK>" \
> $lang_dir/transcript_tokens.txt
fi
if [ ! -f $lang_dir/P.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 2 \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/P.arpa
fi
if [ ! -f $lang_dir/P.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
$lang_dir/P.arpa > $lang_dir/P.fst.txt
fi
fi
@ -135,7 +174,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
# so that the two can share G.pt later.
cp data/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/train.txt ]; then
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
files=$(
find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
@ -144,7 +183,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > $lang_dir/train.txt
done > $lang_dir/transcript_words.txt
fi
./local/train_bpe_model.py \
@ -159,7 +198,7 @@ fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare G"
# We assume you have install kaldilm, if not, please install
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
mkdir -p data/lm
@ -192,4 +231,4 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
done
fi
cd data && ln -sfv lang_bpe_5000 lang_bpe
# cd data && ln -sfv lang_bpe_5000 lang_bpe
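Once Stage 5 has produced P.fst.txt, it can be loaded back into k2 for MMI training. A hedged sketch (the path and the acceptor flag are assumptions based on how the recipe consumes it):

import k2

with open("data/lang_phone/P.fst.txt") as f:
    P = k2.Fsa.from_openfst(f.read(), acceptor=True)
print(P)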

View File

@ -84,6 +84,68 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
f.write(f"{word} {' '.join(tokens)}\n")
def convert_lexicon_to_ragged(
filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
) -> k2.RaggedTensor:
"""Read a lexicon and convert it to a ragged tensor.
Caution:
We assume that each word has a unique pronunciation.
Args:
filename:
Filename of the lexicon. It has a format that can be read
by :func:`read_lexicon`.
word_table:
The word symbol table.
token_table:
The token symbol table.
Returns:
A k2 ragged tensor with two axes [word_id][token_id]
"""
disambig_id = word_table["#0"]
# We reuse the same words.txt from the phone based lexicon
# so that we can share the same G.fst. Here, we have to
# exclude some words present only in the phone based lexicon.
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
# epsilon is not a word, but it occupies a position
#
row_splits = [0]
token_ids_list = []
lexicon_tmp = read_lexicon(filename)
lexicon = dict(lexicon_tmp)
if len(lexicon_tmp) != len(lexicon):
raise RuntimeError(
"It's assumed that each word has a unique pronunciation"
)
for i in range(disambig_id):
w = word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
tokens = lexicon[w]
token_ids = [token_table[k] for k in tokens]
row_splits.append(row_splits[-1] + len(token_ids))
token_ids_list.extend(token_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
# row_splits=row_splits, cached_tot_size=cached_tot_size
row_splits,
None,
cached_tot_size,
)
values = torch.tensor(token_ids_list, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
class Lexicon(object):
"""Phone based lexicon."""
@ -95,7 +157,7 @@ class Lexicon(object):
"""
Args:
lang_dir:
Path to the lang director. It is expected to contain the following
Path to the lang directory. It is expected to contain the following
files:
- tokens.txt
- words.txt
@ -119,7 +181,7 @@ class Lexicon(object):
torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
# We save L_inv instead of L because it will be used to intersect with
# transcript, both of whose labels are word IDs.
# transcript FSAs, both of whose labels are word IDs.
self.L_inv = L_inv
self.disambig_pattern = disambig_pattern
@ -142,70 +204,61 @@ class Lexicon(object):
return ans
class BpeLexicon(Lexicon):
class UniqLexicon(Lexicon):
def __init__(
self,
lang_dir: Path,
uniq_filename: str = "uniq_lexicon.txt",
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Refer to the help information in Lexicon.__init__.
uniq_filename: It is assumed to be inside the given `lang_dir`.
Each word in the lexicon is assumed to have a unique pronunciation.
"""
lang_dir = Path(lang_dir)
super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)
self.ragged_lexicon = self.convert_lexicon_to_ragged(
lang_dir / "lexicon.txt"
self.ragged_lexicon = convert_lexicon_to_ragged(
filename=lang_dir / uniq_filename,
word_table=self.word_table,
token_table=self.token_table,
)
# TODO: should we move it to a certain device ?
def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
"""Read a BPE lexicon from file and convert it to a
k2 ragged tensor.
def texts_to_token_ids(
self, texts: List[str], oov: str = "<UNK>"
) -> k2.RaggedTensor:
"""
Args:
filename:
Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt
texts:
A list of transcripts. Each transcript contains space(s)
separated words. An example texts is::
['HELLO k2', 'HELLO icefall']
oov:
The OOV word. If a word in `texts` is not in the lexicon, it is
replaced with `oov`.
Returns:
A k2 ragged tensor with two axes [word_id]
Return a ragged int tensor with 2 axes [utterance][token_id]
"""
disambig_id = self.word_table["#0"]
# We reuse the same words.txt from the phone based lexicon
# so that we can share the same G.fst. Here, we have to
# exclude some words present only in the phone based lexicon.
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
oov_id = self.word_table[oov]
# epsilon is not a word, but it occupies a position
#
row_splits = [0]
token_ids = []
word_ids_list = []
for text in texts:
word_ids = []
for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
word_ids.append(oov_id)
word_ids_list.append(word_ids)
ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
return self.ragged_lexicon.index(ragged_indexes, remove_axis=True)
lexicon = read_lexicon(filename)
lexicon = dict(lexicon)
for i in range(disambig_id):
w = self.word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
pieces = lexicon[w]
piece_ids = [self.token_table[k] for k in pieces]
row_splits.append(row_splits[-1] + len(piece_ids))
token_ids.extend(piece_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=cached_tot_size
)
values = torch.tensor(token_ids, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor contained
word piece IDs.
"""
def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
"""Convert a list of words to a ragged tensor containing token IDs."""
word_ids = [self.word_table[w] for w in words]
word_ids = torch.tensor(word_ids, dtype=torch.int32)
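The ragged indexing used by texts_to_token_ids can be seen in isolation with a toy lexicon (the IDs below are made up for illustration):

import k2
import torch

# Row i holds the token IDs of word i: word 0 -> [3, 4], word 1 -> [5].
ragged_lexicon = k2.RaggedTensor([[3, 4], [5]], dtype=torch.int32)
# Two "utterances", each a list of word IDs.
word_ids = k2.RaggedTensor([[0, 1], [1, 1, 0]], dtype=torch.int32)
token_ids = ragged_lexicon.index(word_ids, remove_axis=True)
print(token_ids)  # [[3 4 5], [5 5 3 4]]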

icefall/mmi.py Normal file
View File

@ -0,0 +1,233 @@
from typing import List
import k2
import torch
from torch import nn
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
def _compute_mmi_loss_exact_optimized(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: MmiTrainingGraphCompiler,
output_beam: float,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
The function name contains `exact`, which means it uses a version of
intersection without pruning.
`optimized` in the function name means this function is optimized
in that it calls k2.intersect_dense only once
Note:
It is faster at the cost of using more memory.
Args:
dense_fsa_vec:
It contains the neural network output.
texts:
The transcript. Each element consists of space(s) separated words.
graph_compiler:
Used to build num_graphs and den_graphs
den_scale:
The scale applied to the denominator tot_scores.
Returns:
Return a scalar loss. It is the sum over utterances in a batch,
without normalization.
"""
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
device = num_graphs.device
num_fsas = num_graphs.shape[0]
assert dense_fsa_vec.dim0() == num_fsas
assert den_graphs.shape[0] == 1
# The motivation to concatenate num_graphs and den_graphs
# is to reduce the number of calls to k2.intersect_dense.
num_den_graphs = k2.cat([num_graphs, den_graphs])
# NOTE: The a_to_b_map in k2.intersect_dense must be sorted
# so the following reorders num_den_graphs.
#
# The following code computes a_to_b_map
# [0, 1, 2, ... ]
num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)
# [num_fsas, num_fsas, num_fsas, ... ]
den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)
# [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
num_den_graphs_indexes = (
torch.stack([num_graphs_indexes, den_graphs_indexes])
.t()
.reshape(-1)
.to(device)
)
num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
# [[0, 1, 2, ...]]
a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
# [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)
num_den_lats = k2.intersect_dense(
num_den_reordered_graphs,
dense_fsa_vec,
output_beam=output_beam,
a_to_b_map=a_to_b_map,
)
num_den_tot_scores = num_den_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
num_tot_scores = num_den_tot_scores[::2]
den_tot_scores = num_den_tot_scores[1::2]
tot_scores = num_tot_scores - den_scale * den_tot_scores
loss = -1 * tot_scores.sum()
return loss
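The interleaving and a_to_b_map logic above can be sanity-checked in isolation. A minimal sketch with plain torch (num_fsas is a made-up value):

    # Standalone check of the index gymnastics used above (illustrative only).
    import torch

    num_fsas = 3
    num_idx = torch.arange(num_fsas, dtype=torch.int32)               # [0, 1, 2]
    den_idx = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)  # [3, 3, 3]
    order = torch.stack([num_idx, den_idx]).t().reshape(-1)
    assert order.tolist() == [0, 3, 1, 3, 2, 3]  # num/den graphs interleaved

    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1)
    assert a_to_b_map.tolist() == [0, 0, 1, 1, 2, 2]  # sorted, as k2 requires
    # Hence tot_scores[::2] are numerator scores and tot_scores[1::2] are
    # denominator scores, as used above.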
def _compute_mmi_loss_exact_non_optimized(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: MmiTrainingGraphCompiler,
output_beam: float,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
See :func:`_compute_mmi_loss_exact_optimized` for the meaning
of the arguments.
It's more readable, though it invokes k2.intersect_dense twice.
Note:
It uses less memory than the optimized version, but is slower.
"""
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
num_lats = k2.intersect_dense(
num_graphs, dense_fsa_vec, output_beam=output_beam
)
den_lats = k2.intersect_dense(
den_graphs, dense_fsa_vec, output_beam=output_beam
)
num_tot_scores = num_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
den_tot_scores = den_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
tot_scores = num_tot_scores - den_scale * den_tot_scores
loss = -1 * tot_scores.sum()
return loss
def _compute_mmi_loss_pruned(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: MmiTrainingGraphCompiler,
output_beam: float,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
See :func:`_compute_mmi_loss_exact_optimized` for the meaning
of the arguments.
`pruned` means it uses k2.intersect_dense_pruned.
Note:
It uses the least amount of memory, but the loss is not exact due
to pruning.
"""
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
num_lats = k2.intersect_dense(
num_graphs, dense_fsa_vec, output_beam=output_beam
)
    # The values for search_beam/min_active_states/max_active_states below
    # are not tuned. You may want to tune them.
den_lats = k2.intersect_dense_pruned(
den_graphs,
dense_fsa_vec,
search_beam=20.0,
output_beam=output_beam,
min_active_states=30,
max_active_states=10000,
)
num_tot_scores = num_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
den_tot_scores = den_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
)
tot_scores = num_tot_scores - den_scale * den_tot_scores
loss = -1 * tot_scores.sum()
return loss
class LFMMILoss(nn.Module):
"""
    Computes the Lattice-Free Maximum Mutual Information (LFMMI) loss.
TODO: more detailed description
"""
def __init__(
self,
graph_compiler: MmiTrainingGraphCompiler,
output_beam: float,
use_pruned_intersect: bool = False,
den_scale: float = 1.0,
):
super().__init__()
self.graph_compiler = graph_compiler
self.output_beam = output_beam
self.use_pruned_intersect = use_pruned_intersect
self.den_scale = den_scale
def forward(
self,
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
) -> torch.Tensor:
"""
Args:
dense_fsa_vec:
It contains the neural network output.
texts:
A list of strings. Each string contains space(s) separated words.
Returns:
Return a scalar loss. It is the sum over utterances in a batch,
without normalization.
"""
if self.use_pruned_intersect:
func = _compute_mmi_loss_pruned
else:
func = _compute_mmi_loss_exact_non_optimized
# func = _compute_mmi_loss_exact_optimized
return func(
dense_fsa_vec=dense_fsa_vec,
texts=texts,
graph_compiler=self.graph_compiler,
output_beam=self.output_beam,
den_scale=self.den_scale,
)
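A hedged usage sketch of the loss (shapes, paths, and transcripts below are assumptions, not values from this commit):

    import k2
    import torch
    from icefall.mmi import LFMMILoss
    from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

    graph_compiler = MmiTrainingGraphCompiler(lang_dir="data/lang_bpe")
    loss_fn = LFMMILoss(graph_compiler=graph_compiler, output_beam=10.0)

    # (N, T, C) log-probs; C must cover the token set of the lang dir.
    nnet_output = torch.randn(2, 100, 500).log_softmax(dim=-1)
    # Each row is (fsa_idx, start_frame, duration), sorted by decreasing duration.
    supervision_segments = torch.tensor(
        [[0, 0, 100], [1, 0, 90]], dtype=torch.int32
    )
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=["HELLO WORLD", "HELLO K2"])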

207
icefall/mmi_graph_compiler.py Normal file
View File

@ -0,0 +1,207 @@
import logging
from pathlib import Path
from typing import Iterable, List, Tuple, Union
import k2
import torch
from icefall.lexicon import UniqLexicon
class MmiTrainingGraphCompiler(object):
def __init__(
self,
lang_dir: Path,
uniq_filename: str = "uniq_lexicon.txt",
device: Union[str, torch.device] = "cpu",
oov: str = "<UNK>",
):
"""
Args:
lang_dir:
Path to the lang directory. It is expected to contain the
following files::
- tokens.txt
- words.txt
- P.fst.txt
The above files are generated by the script `prepare.sh`. You
should have run it before running the training code.
          uniq_filename:
            File name of the lexicon in which every word has exactly one
            pronunciation. We assume this file is inside the given `lang_dir`.
device:
It indicates CPU or CUDA.
oov:
Out of vocabulary word. When a word in the transcript
does not exist in the lexicon, it is replaced with `oov`.
"""
self.lang_dir = Path(lang_dir)
        self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename)
self.device = torch.device(device)
self.L_inv = self.lexicon.L_inv.to(self.device)
self.oov_id = self.lexicon.word_table[oov]
self.build_ctc_topo_P()
def build_ctc_topo_P(self):
"""Built ctc_topo_P, the composition result of
ctc_topo and P, where P is a pre-trained bigram
word piece LM.
"""
# Note: there is no need to save a pre-compiled P and ctc_topo
# as it is very fast to generate them.
logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
with open(self.lang_dir / "P.fst.txt") as f:
# P is not an acceptor because there is
# a back-off state, whose incoming arcs
# have label #0 and aux_label 0 (i.e., <eps>).
P = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_token_disambig_id = self.lexicon.token_table["#0"]
# P.aux_labels is not needed in later computations, so
# remove it here.
del P.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
P.labels[P.labels >= first_token_disambig_id] = 0
P = k2.remove_epsilon(P)
P = k2.arc_sort(P)
P = P.to(self.device)
# Add epsilon self-loops to P because we want the
# following operation "k2.intersect" to run on GPU.
P_with_self_loops = k2.add_epsilon_self_loops(P)
max_token_id = max(self.lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
ctc_topo = k2.ctc_topo(max_token_id, modified=False, device=self.device)
ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
logging.info("Building ctc_topo_P")
ctc_topo_P = k2.intersect(
ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
).invert()
self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
def compile(
self, texts: Iterable[str], replicate_den: bool = True
) -> Tuple[k2.Fsa, k2.Fsa]:
"""Create numerator and denominator graphs from transcripts
and the bigram phone LM.
Args:
texts:
A list of transcripts. Within a transcript, words are
separated by spaces.
replicate_den:
If True, the returned den_graph is replicated to match the number
of FSAs in the returned num_graph; if False, the returned den_graph
contains only a single FSA
Returns:
A tuple (num_graph, den_graph), where
- `num_graph` is the numerator graph. It is an FsaVec with
shape `(len(texts), None, None)`.
            - `den_graph` is the denominator graph. It is an FsaVec
              with the same shape as `num_graph` if replicate_den is
              True; otherwise, it is an FsaVec containing only a single FSA.
"""
transcript_fsa = self.build_transcript_fsa(texts)
        # Remove word IDs from transcript_fsa since they are not needed.
del transcript_fsa.aux_labels
transcript_fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
transcript_fsa
)
transcript_fsa_with_self_loops = k2.arc_sort(
transcript_fsa_with_self_loops
)
num = k2.compose(
self.ctc_topo_P,
transcript_fsa_with_self_loops,
treat_epsilons_specially=False,
)
# CAUTION: Due to the presence of P,
# the resulting `num` may not be connected
num = k2.connect(num)
num = k2.arc_sort(num)
ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
if replicate_den:
indexes = torch.zeros(
len(texts), dtype=torch.int32, device=self.device
)
den = k2.index_fsa(ctc_topo_P_vec, indexes)
else:
den = ctc_topo_P_vec
return num, den
def build_transcript_fsa(self, texts: List[str]) -> k2.Fsa:
"""Convert transcripts to an FsaVec with the help of a lexicon
and word symbol table.
Args:
texts:
Each element is a transcript containing words separated by space(s).
For instance, it may be 'HELLO icefall', which contains
two words.
Returns:
          Return an FST (FsaVec) corresponding to the transcripts.
          Its `labels` are token IDs and its `aux_labels` are word IDs.
"""
word_ids_list = []
for text in texts:
word_ids = []
for word in text.split():
if word in self.lexicon.word_table:
word_ids.append(self.lexicon.word_table[word])
else:
word_ids.append(self.oov_id)
word_ids_list.append(word_ids)
fsa = k2.linear_fsa(word_ids_list, self.device)
fsa = k2.add_epsilon_self_loops(fsa)
# The reason to use `invert_()` at the end is as follows:
#
# (1) The `labels` of L_inv is word IDs and `aux_labels` is token IDs
# (2) `fsa.labels` is word IDs
# (3) after intersection, the `labels` is still word IDs
# (4) after `invert_()`, the `labels` is token IDs
# and `aux_labels` is word IDs
transcript_fsa = k2.intersect(
self.L_inv, fsa, treat_epsilons_specially=False
).invert_()
transcript_fsa = k2.arc_sort(transcript_fsa)
return transcript_fsa
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of piece IDs.
Args:
texts:
It is a list of strings. Each string consists of space(s)
separated words. An example containing two strings is given below:
['HELLO ICEFALL', 'HELLO k2']
Returns:
Return a list-of-list of token IDs.
"""
return self.lexicon.texts_to_token_ids(texts).tolist()
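A short illustrative sketch of the compiler above (the lang directory is an assumption; it must contain tokens.txt, words.txt, and P.fst.txt):

    from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

    compiler = MmiTrainingGraphCompiler(lang_dir="data/lang_bpe", device="cpu")
    num_graphs, den_graphs = compiler.compile(
        ["HELLO WORLD", "HELLO K2"], replicate_den=True
    )
    assert num_graphs.shape[0] == 2
    assert den_graphs.shape[0] == 2  # a single FSA if replicate_den=False
    print(compiler.texts_to_ids(["HELLO K2"]))  # list-of-list of token IDs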

377
icefall/shared/make_kn_lm.py Executable file
View File

@ -0,0 +1,377 @@
#!/usr/bin/env python3
# Copyright 2016 Johns Hopkins University (Author: Daniel Povey)
# 2018 Ruizhe Huang
# Apache 2.0.
# This is an implementation of Kneser-Ney smoothed language model estimation,
# done in the same way as srilm. This is a back-off, unmodified version of
# Kneser-Ney smoothing, which produces the same results as the following
# srilm command (as an example):
#
# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \
# -text corpus.txt -lm lm.arpa
#
# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
import sys
import os
import re
import io
import math
import argparse
from collections import Counter, defaultdict
parser = argparse.ArgumentParser(description="""
Generate a Kneser-Ney language model in ARPA format. By default,
it reads the corpus from standard input and writes to standard output.
""")
parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
args = parser.parse_args()
default_encoding = "latin-1" # For encoding-agnostic scripts, we assume byte stream as input.
# Need to be very careful about the use of strip() and split()
# in this case, because there is a latin-1 whitespace character
# (nbsp) which is part of the unicode encoding range.
# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
strip_chars = " \t\r\n"
whitespace = re.compile("[ \t]+")
class CountsForHistory:
# This class (which is more like a struct) stores the counts seen in a
# particular history-state. It is used inside class NgramCounts.
# It really does the job of a dict from int to float, but it also
# keeps track of the total count.
def __init__(self):
        self.word_to_count = defaultdict(int)  # predicted word -> count
self.word_to_context = defaultdict(set) # using a set to count the number of unique contexts
self.word_to_f = dict() # discounted probability
self.word_to_bow = dict() # back-off weight
self.total_count = 0
def words(self):
return self.word_to_count.keys()
def __str__(self):
# e.g. returns ' total=12: 3->4, 4->6, -1->2'
return ' total={0}: {1}'.format(
str(self.total_count),
', '.join(['{0} -> {1}'.format(word, count)
for word, count in self.word_to_count.items()]))
def add_count(self, predicted_word, context_word, count):
assert count >= 0
self.total_count += count
self.word_to_count[predicted_word] += count
if context_word is not None:
self.word_to_context[predicted_word].add(context_word)
class NgramCounts:
# A note on data-structure. Firstly, all words are represented as
# integers. We store n-gram counts as an array, indexed by (history-length
# == n-gram order minus one) (note: python calls arrays "lists") of dicts
# from histories to counts, where histories are arrays of integers and
# "counts" are dicts from integer to float. For instance, when
# accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
# do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
# array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
assert ngram_order >= 2
self.ngram_order = ngram_order
self.bos_symbol = bos_symbol
self.eos_symbol = eos_symbol
self.counts = []
for n in range(ngram_order):
self.counts.append(defaultdict(lambda: CountsForHistory()))
        self.d = []  # list of discounting factors, one for each n-gram order
# adds a raw count (called while processing input data).
# Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history'
# would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
# 1.
def add_count(self, history, predicted_word, context_word, count):
self.counts[len(history)][history].add_count(predicted_word, context_word, count)
# 'line' is a string containing a sequence of integer word-ids.
# This function adds the un-smoothed counts from this line of text.
def add_raw_counts_from_line(self, line):
if line == '':
words = [self.bos_symbol, self.eos_symbol]
else:
words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
for i in range(len(words)):
for n in range(1, self.ngram_order+1):
if i + n > len(words):
break
ngram = words[i: i + n]
predicted_word = ngram[-1]
history = tuple(ngram[: -1])
if i == 0 or n == self.ngram_order:
context_word = None
else:
context_word = words[i-1]
self.add_count(history, predicted_word, context_word, 1)
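    # Worked example (illustrative): with ngram_order=3 and the line "a b",
    # words becomes ['<s>', 'a', 'b', '</s>'] and the loops above add:
    #   unigrams:  () -> '<s>', () -> 'a', () -> 'b', () -> '</s>'
    #   bigrams:   ('<s>',) -> 'a', ('a',) -> 'b', ('b',) -> '</s>'
    #   trigrams:  ('<s>', 'a') -> 'b', ('a', 'b') -> '</s>'
    # context_word is None for i == 0 and for the highest order, so those
    # counts do not contribute to the "modified counts" used below.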
def add_raw_counts_from_standard_input(self):
lines_processed = 0
infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input
for line in infile:
line = line.strip(strip_chars)
self.add_raw_counts_from_line(line)
lines_processed += 1
if lines_processed == 0 or args.verbose > 0:
print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
def add_raw_counts_from_file(self, filename):
lines_processed = 0
with open(filename, encoding=default_encoding) as fp:
for line in fp:
line = line.strip(strip_chars)
self.add_raw_counts_from_line(line)
lines_processed += 1
if lines_processed == 0 or args.verbose > 0:
print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
    def cal_discounting_constants(self):
        # For each order N of N-grams, we calculate the discounting constant
        # D_N = n1_N / (n1_N + 2 * n2_N), where n1_N is the number of unique
        # N-grams with count = 1 and n2_N is the number of unique N-grams with
        # count = 2 (i.e., counts-of-counts).
        # This constant is used similarly to absolute discounting.
        # Return value: self.d is a list of floats, where d[N-1] = D_N.
        self.d = [0]  # For the lowest order, i.e., 1-grams, we do not need to
        # discount, so the constant is 0. This is a special case: we currently
        # assume that every word in the vocabulary has been seen, but this may
        # not hold in other scenarios.
for n in range(1, self.ngram_order):
this_order_counts = self.counts[n]
n1 = 0
n2 = 0
for hist, counts_for_hist in this_order_counts.items():
stat = Counter(counts_for_hist.word_to_count.values())
n1 += stat[1]
n2 += stat[2]
assert n1 + 2 * n2 > 0
self.d.append(n1 * 1.0 / (n1 + 2 * n2))
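    # Worked example (made-up counts): if 300 distinct bigrams occur exactly
    # once (n1 = 300) and 100 occur exactly twice (n2 = 100), then
    #   D_2 = n1 / (n1 + 2 * n2) = 300 / (300 + 200) = 0.6,
    # i.e., every bigram count is reduced by 0.6 before renormalization.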
def cal_f(self):
        # f(a_z) is the discounted probability of word z following history a_.
        # Typically f(a_z) is discounted to be less than the ML estimate so we
        # have some leftover probability mass for the words z unseen in the
        # context (a_).
#
# f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams
# f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams
# highest order N-grams
n = self.ngram_order - 1
this_order_counts = self.counts[n]
for hist, counts_for_hist in this_order_counts.items():
for w, c in counts_for_hist.word_to_count.items():
counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
# lower order N-grams
for n in range(0, self.ngram_order - 1):
this_order_counts = self.counts[n]
for hist, counts_for_hist in this_order_counts.items():
n_star_star = 0
for w in counts_for_hist.word_to_count.keys():
n_star_star += len(counts_for_hist.word_to_context[w])
if n_star_star != 0:
for w in counts_for_hist.word_to_count.keys():
n_star_z = len(counts_for_hist.word_to_context[w])
counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
                else:  # Patterns beginning with <s> do not have a "modified count", so use the raw count instead
for w in counts_for_hist.word_to_count.keys():
n_star_z = counts_for_hist.word_to_count[w]
counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
def cal_bow(self):
# Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
# Thus, two sorts of ngrams do not have a bow:
# 1) highest order ngram
# 2) ngrams ending in </s>
#
# bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z))
# Note that Z1 is the set of all words with c(a_z) > 0
# highest order N-grams
n = self.ngram_order - 1
this_order_counts = self.counts[n]
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
counts_for_hist.word_to_bow[w] = None
# lower order N-grams
for n in range(0, self.ngram_order - 1):
this_order_counts = self.counts[n]
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
if w == self.eos_symbol:
counts_for_hist.word_to_bow[w] = None
else:
a_ = hist + (w,)
assert len(a_) < self.ngram_order
assert a_ in self.counts[len(a_)].keys()
a_counts_for_hist = self.counts[len(a_)][a_]
sum_z1_f_a_z = 0
for u in a_counts_for_hist.word_to_count.keys():
sum_z1_f_a_z += a_counts_for_hist.word_to_f[u]
sum_z1_f_z = 0
_ = a_[1:]
_counts_for_hist = self.counts[len(_)][_]
for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1
sum_z1_f_z += _counts_for_hist.word_to_f[u]
counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
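    # Numeric illustration (made-up values): if the discounted probabilities
    # of the words seen after history a_ sum to 0.8 (Sum_Z1 f(a_z) = 0.8) and
    # the same words under the backed-off history sum to 0.5
    # (Sum_Z1 f(_z) = 0.5), then bow(a_) = (1 - 0.8) / (1 - 0.5) = 0.4.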
def print_raw_counts(self, info_string):
# these are useful for debug.
print(info_string)
res = []
for this_order_counts in self.counts:
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
ngram = " ".join(hist) + " " + w
ngram = ngram.strip(strip_chars)
res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
res.sort(reverse=True)
for r in res:
print(r)
def print_modified_counts(self, info_string):
# these are useful for debug.
print(info_string)
res = []
for this_order_counts in self.counts:
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
ngram = " ".join(hist) + " " + w
ngram = ngram.strip(strip_chars)
modified_count = len(counts_for_hist.word_to_context[w])
raw_count = counts_for_hist.word_to_count[w]
if modified_count == 0:
res.append("{0}\t{1}".format(ngram, raw_count))
else:
res.append("{0}\t{1}".format(ngram, modified_count))
res.sort(reverse=True)
for r in res:
print(r)
def print_f(self, info_string):
# these are useful for debug.
print(info_string)
res = []
for this_order_counts in self.counts:
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
ngram = " ".join(hist) + " " + w
ngram = ngram.strip(strip_chars)
f = counts_for_hist.word_to_f[w]
if f == 0: # f(<s>) is always 0
f = 1e-99
res.append("{0}\t{1}".format(ngram, math.log(f, 10)))
res.sort(reverse=True)
for r in res:
print(r)
def print_f_and_bow(self, info_string):
# these are useful for debug.
print(info_string)
res = []
for this_order_counts in self.counts:
for hist, counts_for_hist in this_order_counts.items():
for w in counts_for_hist.word_to_count.keys():
ngram = " ".join(hist) + " " + w
ngram = ngram.strip(strip_chars)
f = counts_for_hist.word_to_f[w]
if f == 0: # f(<s>) is always 0
f = 1e-99
bow = counts_for_hist.word_to_bow[w]
if bow is None:
res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
else:
res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
res.sort(reverse=True)
for r in res:
print(r)
def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
# print as ARPA format.
print('\\data\\', file=fout)
for hist_len in range(self.ngram_order):
# print the number of n-grams.
print('ngram {0}={1}'.format(
hist_len + 1,
sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
file=fout
)
print('', file=fout)
for hist_len in range(self.ngram_order):
print('\\{0}-grams:'.format(hist_len + 1), file=fout)
this_order_counts = self.counts[hist_len]
for hist, counts_for_hist in this_order_counts.items():
for word in counts_for_hist.word_to_count.keys():
ngram = hist + (word,)
prob = counts_for_hist.word_to_f[word]
bow = counts_for_hist.word_to_bow[word]
if prob == 0: # f(<s>) is always 0
prob = 1e-99
line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
if bow is not None:
line += '\t{0}'.format('%.7f' % math.log10(bow))
print(line, file=fout)
print('', file=fout)
print('\\end\\', file=fout)
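    # Shape of the ARPA output produced above (toy numbers for illustration):
    #   \data\
    #   ngram 1=4
    #   ngram 2=6
    #
    #   \1-grams:
    #   -0.3010300   </s>
    #   -99.0000000  <s>   -0.2218487
    #   ...
    #
    #   \end\
    # Note that <s> gets log10(1e-99) = -99 because f(<s>) is always 0.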
if __name__ == "__main__":
ngram_counts = NgramCounts(args.ngram_order)
if args.text is None:
ngram_counts.add_raw_counts_from_standard_input()
else:
assert os.path.isfile(args.text)
ngram_counts.add_raw_counts_from_file(args.text)
ngram_counts.cal_discounting_constants()
ngram_counts.cal_f()
ngram_counts.cal_bow()
if args.lm is None:
ngram_counts.print_as_arpa()
else:
with open(args.lm, 'w', encoding=default_encoding) as f:
ngram_counts.print_as_arpa(fout=f)
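Typical invocations of this script (file names below are placeholders):

    # Build a 2-gram LM from a token-level transcript file:
    #   ./icefall/shared/make_kn_lm.py -ngram-order 2 \
    #       -text transcript_tokens.txt -lm P.arpa
    # Or stream a corpus through stdin/stdout:
    #   cat corpus.txt | ./icefall/shared/make_kn_lm.py -ngram-order 4 > lm.arpa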

View File

@ -4,8 +4,11 @@ profile = "black"
[tool.black]
line-length = 80
exclude = '''
(
  /(
      \.git
    | \.github
    | icefall/shared/*
  )/
)
'''

View File

@ -19,20 +19,21 @@
from pathlib import Path
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import UniqLexicon
ICEFALL_DIR = Path(__file__).resolve().parent.parent
def test():
lang_dir = Path("data/lang/bpe")
lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe"
if not lang_dir.is_dir():
return
# TODO: generate data for testing
compiler = BpeCtcTrainingGraphCompiler(lang_dir)
ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
compiler.compile(ids)
    lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt")
    ids0 = lexicon.words_to_token_ids(["HELLO"])
assert ids[0] == ids0.values().tolist()

124
test/test_lexicon.py Normal file → Executable file
View File

@ -14,80 +14,84 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
You can run this file in one of the two ways:
(1) cd icefall; pytest test/test_lexicon.py
(2) cd icefall; ./test/test_lexicon.py
"""
import os
import shutil
import sys
from pathlib import Path
import k2
import torch

from icefall.lexicon import UniqLexicon

TMP_DIR = "/tmp/icefall-test-lexicon"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent


def generate_test_data():
    # if Path(TMP_DIR).exists():
    #     return
    Path(TMP_DIR).mkdir(exist_ok=True)
    lexicon = """
<UNK> SPN
cat c a t
at a t
at a a t
ac a c
ac a c c
"""
    lexicon_filename = Path(TMP_DIR) / "lexicon.txt"
    with open(lexicon_filename, "w") as f:
        for line in lexicon.strip().split("\n"):
            f.write(f"{line}\n")

    os.system(
        f"""
cd {ICEFALL_DIR}/egs/librispeech/ASR

./local/generate_unique_lexicon.py --lang-dir {TMP_DIR}
./local/prepare_lang.py --lang-dir {TMP_DIR}
"""
    )


def delete_test_data():
    shutil.rmtree(TMP_DIR)


def uniq_lexicon_test():
    lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="uniq_lexicon.txt")
    texts = ["cat cat", "at ac", "ca at cat"]
    token_ids = lexicon.texts_to_token_ids(texts)
    # With tokens SPN=1, a=2, c=3, t=4, the expected pieces are:
    #   [c a t c a t], [a t a c], [SPN a t c a t]  ("ca" is OOV -> SPN)
    expected_ids = [[3, 2, 4, 3, 2, 4], [2, 4, 2, 3], [1, 2, 4, 3, 2, 4]]
    expected_ids = k2.RaggedTensor(expected_ids)
    assert token_ids == expected_ids


def test_main():
    generate_test_data()
    uniq_lexicon_test()
    if USING_PYTEST:
        delete_test_data()


def main():
    test_main()


if __name__ == "__main__" and not USING_PYTEST:
    main()

179
test/test_mmi_graph_compiler.py Executable file
View File

@ -0,0 +1,179 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
You can run this file in one of the two ways:
(1) cd icefall; pytest test/test_mmi_graph_compiler.py
(2) cd icefall; ./test/test_mmi_graph_compiler.py
"""
import copy
import os
import shutil
import sys
from pathlib import Path
import k2
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent
def generate_test_data():
# if Path(TMP_DIR).exists():
# return
Path(TMP_DIR).mkdir(exist_ok=True)
lexicon = """
<UNK> SPN
cat c a t
at a t
at a a t
ac a c
ac a c c
"""
lexicon_filename = Path(TMP_DIR) / "lexicon.txt"
with open(lexicon_filename, "w") as f:
for line in lexicon.strip().split("\n"):
f.write(f"{line}\n")
transcript_words = """
cat at ta
at at cat ta
"""
transcript_words_filename = Path(TMP_DIR) / "transcript_words.txt"
with open(transcript_words_filename, "w") as f:
for line in transcript_words.strip().split("\n"):
f.write(f"{line}\n")
os.system(
f"""
cd {ICEFALL_DIR}/egs/librispeech/ASR
./local/generate_unique_lexicon.py --lang-dir {TMP_DIR}
./local/prepare_lang.py --lang-dir {TMP_DIR}
./local/convert_transcript_words_to_tokens.py \
--lexicon {TMP_DIR}/uniq_lexicon.txt \
--transcript {TMP_DIR}/transcript_words.txt \
--oov "<UNK>" \
> {TMP_DIR}/transcript_tokens.txt
shared/make_kn_lm.py \
-ngram-order 2 \
-text {TMP_DIR}/transcript_tokens.txt \
-lm {TMP_DIR}/P.arpa
python3 -m kaldilm \
--read-symbol-table="{TMP_DIR}/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
{TMP_DIR}/P.arpa > {TMP_DIR}/P.fst.txt
"""
)
def delete_test_data():
shutil.rmtree(TMP_DIR)
def mmi_graph_compiler_test():
graph_compiler = MmiTrainingGraphCompiler(lang_dir=TMP_DIR)
print(graph_compiler.device)
L_inv = graph_compiler.L_inv
L = k2.invert(L_inv)
L.labels_sym = graph_compiler.lexicon.token_table
L.aux_labels_sym = graph_compiler.lexicon.word_table
L.draw(f"{TMP_DIR}/L.svg", title="L")
L_inv.labels_sym = graph_compiler.lexicon.word_table
L_inv.aux_labels_sym = graph_compiler.lexicon.token_table
    L_inv.draw(f"{TMP_DIR}/L_inv.svg", title="L_inv")
ctc_topo_P = graph_compiler.ctc_topo_P
ctc_topo_P.labels_sym = copy.deepcopy(graph_compiler.lexicon.token_table)
ctc_topo_P.labels_sym._id2sym[0] = "<blk>"
ctc_topo_P.labels_sym._sym2id["<blk>"] = 0
ctc_topo_P.aux_labels_sym = graph_compiler.lexicon.token_table
ctc_topo_P.draw(f"{TMP_DIR}/ctc_topo_P.svg", title="ctc_topo_P")
print(ctc_topo_P.num_arcs)
print(k2.connect(ctc_topo_P).num_arcs)
with open(str(TMP_DIR) + "/P.fst.txt") as f:
# P is not an acceptor because there is
# a back-off state, whose incoming arcs
# have label #0 and aux_label 0 (i.e., <eps>).
P = k2.Fsa.from_openfst(f.read(), acceptor=False)
P.labels_sym = graph_compiler.lexicon.token_table
P.aux_labels_sym = graph_compiler.lexicon.token_table
P.draw(f"{TMP_DIR}/P.svg", title="P")
ctc_topo = k2.ctc_topo(max(graph_compiler.lexicon.tokens), False)
ctc_topo.labels_sym = ctc_topo_P.labels_sym
ctc_topo.aux_labels_sym = graph_compiler.lexicon.token_table
ctc_topo.draw(f"{TMP_DIR}/ctc_topo.svg", title="ctc_topo")
print("p num arcs", P.num_arcs)
print("ctc_topo num arcs", ctc_topo.num_arcs)
print("ctc_topo_P num arcs", ctc_topo_P.num_arcs)
texts = ["cat at ac at", "at ac cat zoo", "cat zoo"]
transcript_fsa = graph_compiler.build_transcript_fsa(texts)
transcript_fsa[0].draw(f"{TMP_DIR}/cat_at_ac_at.svg", title="cat_at_ac_at")
transcript_fsa[1].draw(
f"{TMP_DIR}/at_ac_cat_zoo.svg", title="at_ac_cat_zoo"
)
transcript_fsa[2].draw(f"{TMP_DIR}/cat_zoo.svg", title="cat_zoo")
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
num_graphs[0].draw(
f"{TMP_DIR}/num_cat_at_ac_at.svg", title="num_cat_at_ac_at"
)
num_graphs[1].draw(
f"{TMP_DIR}/num_at_ac_cat_zoo.svg", title="num_at_ac_cat_zoo"
)
num_graphs[2].draw(f"{TMP_DIR}/num_cat_zoo.svg", title="num_cat_zoo")
den_graphs[0].draw(
f"{TMP_DIR}/den_cat_at_ac_at.svg", title="den_cat_at_ac_at"
)
den_graphs[2].draw(f"{TMP_DIR}/den_cat_zoo.svg", title="den_cat_zoo")
texts = ["cat at cat", "ac at ca"]
token_ids = graph_compiler.texts_to_ids(texts)
# c a t a t c a t a c a t SPN
expected_token_ids = [[3, 2, 4, 2, 4, 3, 2, 4], [2, 3, 2, 4, 1]]
assert token_ids == expected_token_ids
def test_main():
generate_test_data()
mmi_graph_compiler_test()
if USING_PYTEST:
delete_test_data()
def main():
test_main()
if __name__ == "__main__" and not USING_PYTEST:
main()