Export TorchScript model for Aishell

This commit is contained in:
pkufool 2021-11-18 12:52:26 +08:00
parent 8f91ed2fbe
commit 83e6265f79
21 changed files with 441 additions and 254 deletions

View File

@@ -38,14 +38,13 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_attention_decoder,
 )
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
-    str2bool,
     write_error_stats,
 )
@@ -113,17 +112,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--export",
-        type=str2bool,
-        default=False,
-        help="""When enabled, the averaged model is saved to
-        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
-        pretrained.pt contains a dict {"model": model.state_dict()},
-        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
-        """,
-    )
-
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -544,13 +532,6 @@ def main():
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
 
-    if params.export:
-        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
-        return
-
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])

View File

@@ -0,0 +1,165 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script converts several saved checkpoints
# to a single one using model averaging.

import argparse
import logging
from pathlib import Path

import torch
from conformer import Conformer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=84,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=25,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="conformer_ctc/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="""It contains language related input files such as "lexicon.txt"
        """,
    )

    parser.add_argument(
        "--jit",
        type=str2bool,
        default=True,
        help="""True to save a model after applying torch.jit.script.
        """,
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "feature_dim": 80,
            "subsampling_factor": 4,
            "use_feat_batchnorm": True,
            "attention_dim": 512,
            "nhead": 4,
            "num_decoder_layers": 6,
        }
    )
    return params


def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)

    params = get_params()
    params.update(vars(args))

    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=False,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )
    model.to(device)

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if start >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames))

    model.to("cpu")
    model.eval()

    if params.jit:
        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = params.exp_dir / "cpu_jit.pt"
        model.save(str(filename))
        logging.info(f"Saved to {filename}")
    else:
        logging.info("Not using torch.jit.script")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
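
The two output formats above are consumed differently. A minimal loading sketch (not part of the commit), assuming the default --exp-dir of conformer_ctc/exp:

import torch

# cpu_jit.pt (saved with --jit=True) is self-contained: no Python class
# definition is required, and torch::jit::load can read it from C++.
model = torch.jit.load("conformer_ctc/exp/cpu_jit.pt")
model.eval()

# pretrained.pt (saved with --jit=False) only holds a state_dict, so a
# Conformer with matching hyperparameters must be constructed first:
# checkpoint = torch.load("conformer_ctc/exp/pretrained.pt", map_location="cpu")
# model.load_state_dict(checkpoint["model"])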

View File

@@ -0,0 +1,98 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class LabelSmoothingLoss(torch.nn.Module):
    """
    Implement the LabelSmoothingLoss proposed in the following paper
    https://arxiv.org/pdf/1512.00567.pdf
    (Rethinking the Inception Architecture for Computer Vision)
    """

    def __init__(
        self,
        ignore_index: int = -1,
        label_smoothing: float = 0.1,
        reduction: str = "sum",
    ) -> None:
        """
        Args:
          ignore_index:
            ignored class id
          label_smoothing:
            smoothing rate (0.0 means the conventional cross entropy loss)
          reduction:
            It has the same meaning as the reduction in
            `torch.nn.CrossEntropyLoss`. It can be one of the following three
            values: (1) "none": No reduction will be applied. (2) "mean": the
            mean of the output is taken. (3) "sum": the output will be summed.
        """
        super().__init__()
        assert 0.0 <= label_smoothing < 1.0
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute loss between x and target.

        Args:
          x:
            prediction of dimension
            (batch_size, input_length, number_of_classes).
          target:
            target masked with self.ignore_index of
            dimension (batch_size, input_length).

        Returns:
          A scalar tensor containing the loss without normalization.
        """
        assert x.ndim == 3
        assert target.ndim == 2
        assert x.shape[:2] == target.shape
        num_classes = x.size(-1)
        x = x.reshape(-1, num_classes)
        # Now x is of shape (N*T, C)

        # We don't want to change target in-place below,
        # so we make a copy of it here
        target = target.clone().reshape(-1)

        ignored = target == self.ignore_index
        target[ignored] = 0

        true_dist = torch.nn.functional.one_hot(
            target, num_classes=num_classes
        ).to(x)

        true_dist = (
            true_dist * (1 - self.label_smoothing)
            + self.label_smoothing / num_classes
        )
        # Set the value of ignored indexes to 0
        true_dist[ignored] = 0

        loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
        if self.reduction == "sum":
            return loss.sum()
        elif self.reduction == "mean":
            return loss.sum() / (~ignored).sum()
        else:
            return loss.sum(dim=-1)
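
A short usage sketch (not part of the commit), with shapes as in the forward docstring; positions equal to ignore_index contribute nothing to the loss:

import torch
from label_smoothing import LabelSmoothingLoss

loss_fn = LabelSmoothingLoss(ignore_index=-1, label_smoothing=0.1)
logits = torch.randn(2, 5, 10)         # (batch_size, input_length, num_classes)
target = torch.randint(0, 10, (2, 5))  # (batch_size, input_length)
target[0, -1] = -1                     # a padded position; it is ignored
loss = loss_fn(logits, target)         # scalar, since reduction="sum"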

View File

@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_attention_decoder,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts

def get_parser():
@@ -190,7 +190,6 @@ def get_params() -> AttributeDict:
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
-            "env_info": get_env_info(),
         }
     )
     return params

View File

@@ -38,12 +38,12 @@ from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )

View File

@@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from label_smoothing import LabelSmoothingLoss
 from subsampling import Conv2dSubsampling, VggSubsampling
 from torch.nn.utils.rnn import pad_sequence
@@ -83,8 +84,8 @@ class Transformer(nn.Module):
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
-        # self.encoder_embed converts the input of shape [N, T, num_classes]
-        # to the shape [N, T//subsampling_factor, d_model].
+        # self.encoder_embed converts the input of shape (N, T, num_classes)
+        # to the shape (N, T//subsampling_factor, d_model).
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_classes -> d_model
@@ -152,7 +153,7 @@ class Transformer(nn.Module):
                 d_model, self.decoder_num_class
             )
 
-            self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
+            self.decoder_criterion = LabelSmoothingLoss()
         else:
             self.decoder_criterion = None
@@ -162,7 +163,7 @@ class Transformer(nn.Module):
         """
         Args:
           x:
-            The input tensor. Its shape is [N, T, C].
+            The input tensor. Its shape is (N, T, C).
           supervision:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -171,17 +172,17 @@ class Transformer(nn.Module):
         Returns:
           Return a tuple containing 3 tensors:
-            - CTC output for ctc decoding. Its shape is [N, T, C]
-            - Encoder output with shape [T, N, C]. It can be used as key and
+            - CTC output for ctc decoding. Its shape is (N, T, C)
+            - Encoder output with shape (T, N, C). It can be used as key and
               value for the decoder.
             - Encoder output padding mask. It can be used as
-              memory_key_padding_mask for the decoder. Its shape is [N, T].
+              memory_key_padding_mask for the decoder. Its shape is (N, T).
               It is None if `supervision` is None.
         """
         if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
@@ -195,7 +196,7 @@ class Transformer(nn.Module):
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -206,8 +207,8 @@ class Transformer(nn.Module):
             padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
-            - The encoder output, with shape [T, N, C]
-            - encoder padding mask, with shape [N, T].
+            - The encoder output, with shape (T, N, C)
+            - encoder padding mask, with shape (N, T).
               The mask is None if `supervisions` is None.
               It is used as memory key padding mask in the decoder.
         """
@@ -225,17 +226,18 @@ class Transformer(nn.Module):
         Args:
           x:
             The output tensor from the transformer encoder.
-            Its shape is [T, N, C]
+            Its shape is (T, N, C)
         Returns:
           Return a tensor that can be used for CTC decoding.
-          Its shape is [N, T, C]
+          Its shape is (N, T, C)
         """
         x = self.encoder_output_layer(x)
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
         return x
 
+    @torch.jit.export
     def decoder_forward(
         self,
         memory: torch.Tensor,
@@ -247,7 +249,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -264,11 +266,15 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
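
The float(...) casts introduced above are presumably needed because pad_sequence annotates padding_value as float, and torch.jit.script enforces that annotation where eager mode accepts a plain int. A hedged sketch of the pattern in isolation:

from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


@torch.jit.script
def pad_batch(ys: List[torch.Tensor], eos_id: int) -> torch.Tensor:
    # Cast the int id explicitly so it matches pad_sequence's
    # `padding_value: float` annotation under torchscript.
    return pad_sequence(ys, batch_first=True, padding_value=float(eos_id))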
@@ -301,18 +307,19 @@ class Transformer(nn.Module):
 
         return decoder_loss
 
+    @torch.jit.export
     def decoder_nll(
         self,
         memory: torch.Tensor,
         memory_key_padding_mask: torch.Tensor,
-        token_ids: List[List[int]],
+        token_ids: List[torch.Tensor],
         sos_id: int,
         eos_id: int,
     ) -> torch.Tensor:
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -328,14 +335,23 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
         # extracted as a separate function.
+        if isinstance(token_ids[0], torch.Tensor):
+            # This branch is executed by torchscript in C++.
+            # See https://github.com/k2-fsa/k2/pull/870
+            # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+            token_ids = [tolist(t) for t in token_ids]
+
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@@ -649,24 +665,24 @@ class PositionalEncoding(nn.Module):
         self.d_model = d_model
         self.xscale = math.sqrt(self.d_model)
         self.dropout = nn.Dropout(p=dropout)
-        self.pe = None
+        # not doing: self.pe = None because of errors thrown by torchscript
+        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
 
     def extend_pe(self, x: torch.Tensor) -> None:
         """Extend the time t in the positional encoding if required.
 
-        The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
-        is [N, T, d_model]. If T > T1, then we change the shape of self.pe
-        to [N, T, d_model]. Otherwise, nothing is done.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (N, T, d_model). Otherwise, nothing is done.
 
         Args:
           x:
-            It is a tensor of shape [N, T, C].
+            It is a tensor of shape (N, T, C).
         Returns:
           Return None.
         """
         if self.pe is not None:
             if self.pe.size(1) >= x.size(1):
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
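
The replacement of self.pe = None with an empty tensor works around torchscript attribute typing: an attribute initialized to None is inferred as NoneType, so the later tensor assignments in extend_pe fail to compile. A minimal sketch of the working pattern, using a hypothetical module that is not from the diff:

import torch


class Pos(torch.nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        # An empty (1, 0, d_model) tensor keeps self.pe typed as Tensor,
        # so scripted methods may reassign it later.
        self.pe = torch.zeros(1, 0, d_model, dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pe.size(1) < x.size(1):
            self.pe = torch.zeros(1, x.size(1), x.size(2)).to(x)
        return x + self.pe[:, : x.size(1), :]


scripted = torch.jit.script(Pos(4))
scripted(torch.randn(2, 3, 4))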
@@ -678,7 +694,7 @@ class PositionalEncoding(nn.Module):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0)
-        # Now pe is of shape [1, T, d_model], where T is x.size(1)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -687,10 +703,10 @@ class PositionalEncoding(nn.Module):
         Args:
           x:
-            Its shape is [N, T, C]
+            Its shape is (N, T, C)
 
         Returns:
-          Return a tensor of shape [N, T, C]
+          Return a tensor of shape (N, T, C)
         """
         self.extend_pe(x)
         x = x * self.xscale + self.pe[:, : x.size(1), :]
@@ -784,73 +800,6 @@ class Noam(object):
             setattr(self, key, value)
 
-class LabelSmoothingLoss(nn.Module):
-    """
-    Label-smoothing loss. KL-divergence between
-    q_{smoothed ground truth prob.}(w)
-    and p_{prob. computed by model}(w) is minimized.
-    Modified from
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa
-
-    Args:
-        size: the number of class
-        padding_idx: padding_idx: ignored class id
-        smoothing: smoothing rate (0.0 means the conventional CE)
-        normalize_length: normalize loss by sequence length if True
-        criterion: loss function to be smoothed
-    """
-
-    def __init__(
-        self,
-        size: int,
-        padding_idx: int = -1,
-        smoothing: float = 0.1,
-        normalize_length: bool = False,
-        criterion: nn.Module = nn.KLDivLoss(reduction="none"),
-    ) -> None:
-        """Construct an LabelSmoothingLoss object."""
-        super(LabelSmoothingLoss, self).__init__()
-        self.criterion = criterion
-        self.padding_idx = padding_idx
-        assert 0.0 < smoothing <= 1.0
-        self.confidence = 1.0 - smoothing
-        self.smoothing = smoothing
-        self.size = size
-        self.true_dist = None
-        self.normalize_length = normalize_length
-
-    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        """
-        Compute loss between x and target.
-
-        Args:
-          x:
-            prediction of dimension
-            (batch_size, input_length, number_of_classes).
-          target:
-            target masked with self.padding_id of
-            dimension (batch_size, input_length).
-
-        Returns:
-          A scalar tensor containing the loss without normalization.
-        """
-        assert x.size(2) == self.size
-        # batch_size = x.size(0)
-        x = x.view(-1, self.size)
-        target = target.view(-1)
-        with torch.no_grad():
-            true_dist = x.clone()
-            true_dist.fill_(self.smoothing / (self.size - 1))
-            ignore = target == self.padding_idx  # (B,)
-            total = len(target) - ignore.sum().item()
-            target = target.masked_fill(ignore, 0)  # avoid -1 index
-            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
-        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
-        # denom = total if self.normalize_length else batch_size
-        denom = total if self.normalize_length else 1
-        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
-
-
 def encoder_padding_mask(
     max_len: int, supervisions: Optional[Supervisions] = None
 ) -> Optional[torch.Tensor]:
@@ -972,10 +921,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
       Return a new list-of-list, where each sublist starts
       with SOS ID.
     """
-    ans = []
-    for utt in token_ids:
-        ans.append([sos_id] + utt)
-    return ans
+    return [[sos_id] + utt for utt in token_ids]

def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
@@ -992,7 +938,9 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
       Return a new list-of-list, where each sublist ends
       with EOS ID.
     """
-    ans = []
-    for utt in token_ids:
-        ans.append(utt + [eos_id])
-    return ans
+    return [utt + [eos_id] for utt in token_ids]
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+    """Used by jit"""
+    return torch.jit.annotate(List[int], t.tolist())
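
The @torch.jit.export decorators added in this file matter because torch.jit.script compiles forward and whatever forward calls; decoder_forward and decoder_nll are only reached by the attention-rescoring code, so they must be exported explicitly to remain callable on the scripted module (including from C++, per the k2 links above). A small illustration with a hypothetical module, not from the diff:

import torch


class Model(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    @torch.jit.export
    def rescore(self, x: torch.Tensor) -> torch.Tensor:
        # Without @torch.jit.export this method would not be compiled,
        # because nothing reachable from forward() calls it.
        return x * 2


m = torch.jit.script(Model())
m.rescore(torch.ones(2))  # available on the scripted module (and from C++)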

View File

@@ -40,14 +40,13 @@ from icefall.decode import (
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
-    str2bool,
     write_error_stats,
 )
@@ -122,17 +121,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--export",
-        type=str2bool,
-        default=False,
-        help="""When enabled, the averaged model is saved to
-        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
-        pretrained.pt contains a dict {"model": model.state_dict()},
-        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
-        """,
-    )
-
    parser.add_argument(
         "--exp-dir",
         type=str,
@@ -671,13 +659,6 @@ def main():
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
 
-    if params.export:
-        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
-        return
-
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])

View File

@@ -36,7 +36,7 @@ from icefall.decode import (
     rescore_with_attention_decoder,
     rescore_with_whole_lattice,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts

def get_parser():
@@ -256,7 +256,6 @@ def main():
     params.num_decoder_layers = 0
     params.update(vars(args))
-    params["env_info"] = get_env_info()
 
     logging.info(f"{params}")
 
     device = torch.device("cpu")

View File

@@ -41,12 +41,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )

View File

@@ -36,10 +36,10 @@ from icefall.decode import (
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,

View File

@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_whole_lattice,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts

def get_parser():
@@ -159,7 +159,6 @@ def main():
     params = get_params()
     params.update(vars(args))
-    params["env_info"] = get_env_info()
 
     logging.info(f"{params}")
 
     device = torch.device("cpu")

View File

@@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )

View File

@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_whole_lattice,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts

def get_parser():
@@ -159,7 +159,6 @@ def main():
     params = get_params()
     params.update(vars(args))
-    params["env_info"] = get_env_info()
 
     logging.info(f"{params}")
 
     device = torch.device("cpu")

View File

@@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )

View File

@@ -34,7 +34,7 @@ from icefall.decode import (
     one_best_decoding,
     rescore_with_whole_lattice,
 )
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts

def get_parser():
@@ -159,7 +159,6 @@ def main():
     params = get_params()
     params.update(vars(args))
-    params["env_info"] = get_env_info()
 
     logging.info(f"{params}")
 
     device = torch.device("cpu")

View File

@@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions,
-    get_env_info,
     setup_logger,
     str2bool,
 )

View File

@@ -14,10 +14,10 @@ from model import Tdnn
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import get_lattice, one_best_decoding
+from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,

View File

@@ -29,7 +29,7 @@ from model import Tdnn
 from torch.nn.utils.rnn import pad_sequence
 
 from icefall.decode import get_lattice, one_best_decoding
-from icefall.utils import AttributeDict, get_env_info, get_texts
+from icefall.utils import AttributeDict, get_texts

def get_parser():
@@ -116,7 +116,6 @@ def main():
     params = get_params()
     params.update(vars(args))
-    params["env_info"] = get_env_info()
 
     logging.info(f"{params}")
 
     device = torch.device("cpu")

View File

@@ -22,15 +22,10 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    get_env_info,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

def get_parser():

icefall/env.py Normal file
View File

@@ -0,0 +1,106 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
#                                        Wei Kang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
import sys
from pathlib import Path
from typing import Any, Dict

import k2
import k2.version
import lhotse
import torch


def get_git_sha1():
    git_commit = (
        subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    dirty_commit = (
        len(
            subprocess.run(
                ["git", "diff", "--shortstat"],
                check=True,
                stdout=subprocess.PIPE,
            )
            .stdout.decode()
            .rstrip("\n")
            .strip()
        )
        > 0
    )
    git_commit = (
        git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
    )
    return git_commit


def get_git_date():
    git_date = (
        subprocess.run(
            ["git", "log", "-1", "--format=%ad", "--date=local"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    return git_date


def get_git_branch_name():
    git_date = (
        subprocess.run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            check=True,
            stdout=subprocess.PIPE,
        )
        .stdout.decode()
        .rstrip("\n")
        .strip()
    )
    return git_date


def get_env_info() -> Dict[str, Any]:
    """Get the environment information."""
    return {
        "k2-version": k2.version.__version__,
        "k2-build-type": k2.version.__build_type__,
        "k2-with-cuda": k2.with_cuda,
        "k2-git-sha1": k2.version.__git_sha1__,
        "k2-git-date": k2.version.__git_date__,
        "lhotse-version": lhotse.__version__,
        "torch-cuda-available": torch.cuda.is_available(),
        "torch-cuda-version": torch.version.cuda,
        "python-version": sys.version[:3],
        "icefall-git-branch": get_git_branch_name(),
        "icefall-git-sha1": get_git_sha1(),
        "icefall-git-date": get_git_date(),
        "icefall-path": str(Path(__file__).resolve().parent.parent),
        "k2-path": str(Path(k2.__file__).resolve()),
        "lhotse-path": str(Path(lhotse.__file__).resolve()),
    }
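
A minimal sketch of calling the relocated helper (assuming icefall is importable); the train.py and decode.py files above now pull it from icefall.env instead of icefall.utils:

import pprint

from icefall.env import get_env_info

# Records k2/lhotse/torch versions, the git state of icefall, and library
# paths, so they can be written into training and decoding logs.
pprint.pprint(get_env_info())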

View File

@@ -21,17 +21,15 @@ import collections
 import logging
 import os
 import subprocess
-import sys
 from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
+from typing import Dict, Iterable, List, TextIO, Tuple, Union
 
 import k2
 import k2.version
 import kaldialign
-import lhotse
 import torch
 import torch.distributed as dist
 from torch.utils.tensorboard import SummaryWriter
@@ -137,85 +135,6 @@ def setup_logger(
     logging.getLogger("").addHandler(console)
 
-def get_git_sha1():
-    git_commit = (
-        subprocess.run(
-            ["git", "rev-parse", "--short", "HEAD"],
-            check=True,
-            stdout=subprocess.PIPE,
-        )
-        .stdout.decode()
-        .rstrip("\n")
-        .strip()
-    )
-    dirty_commit = (
-        len(
-            subprocess.run(
-                ["git", "diff", "--shortstat"],
-                check=True,
-                stdout=subprocess.PIPE,
-            )
-            .stdout.decode()
-            .rstrip("\n")
-            .strip()
-        )
-        > 0
-    )
-    git_commit = (
-        git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
-    )
-    return git_commit
-
-
-def get_git_date():
-    git_date = (
-        subprocess.run(
-            ["git", "log", "-1", "--format=%ad", "--date=local"],
-            check=True,
-            stdout=subprocess.PIPE,
-        )
-        .stdout.decode()
-        .rstrip("\n")
-        .strip()
-    )
-    return git_date
-
-
-def get_git_branch_name():
-    git_date = (
-        subprocess.run(
-            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
-            check=True,
-            stdout=subprocess.PIPE,
-        )
-        .stdout.decode()
-        .rstrip("\n")
-        .strip()
-    )
-    return git_date
-
-
-def get_env_info() -> Dict[str, Any]:
-    """Get the environment information."""
-    return {
-        "k2-version": k2.version.__version__,
-        "k2-build-type": k2.version.__build_type__,
-        "k2-with-cuda": k2.with_cuda,
-        "k2-git-sha1": k2.version.__git_sha1__,
-        "k2-git-date": k2.version.__git_date__,
-        "lhotse-version": lhotse.__version__,
-        "torch-cuda-available": torch.cuda.is_available(),
-        "torch-cuda-version": torch.version.cuda,
-        "python-version": sys.version[:3],
-        "icefall-git-branch": get_git_branch_name(),
-        "icefall-git-sha1": get_git_sha1(),
-        "icefall-git-date": get_git_date(),
-        "icefall-path": str(Path(__file__).resolve().parent.parent),
-        "k2-path": str(Path(k2.__file__).resolve()),
-        "lhotse-path": str(Path(lhotse.__file__).resolve()),
-    }
 
 class AttributeDict(dict):
     def __getattr__(self, key):
         if key in self: