Export torch script model for Aishell

This commit is contained in:
pkufool 2021-11-18 12:52:26 +08:00
parent 8f91ed2fbe
commit 83e6265f79
21 changed files with 441 additions and 254 deletions

View File

@ -38,14 +38,13 @@ from icefall.decode import (
one_best_decoding,
rescore_with_attention_decoder,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_env_info,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@ -113,17 +112,6 @@ def get_parser():
""",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -544,13 +532,6 @@ def main():
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])

View File

@ -0,0 +1,165 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=84,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=25,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""It contains language related input files such as "lexicon.txt"
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=True,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"subsampling_factor": 4,
"use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 4,
"num_decoder_layers": 6,
}
)
return params
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
use_feat_batchnorm=params.use_feat_batchnorm,
)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,98 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
"""
Implement the LabelSmoothingLoss proposed in the following paper
https://arxiv.org/pdf/1512.00567.pdf
(Rethinking the Inception Architecture for Computer Vision)
"""
def __init__(
self,
ignore_index: int = -1,
label_smoothing: float = 0.1,
reduction: str = "sum",
) -> None:
"""
Args:
ignore_index:
ignored class id
label_smoothing:
smoothing rate (0.0 means the conventional cross entropy loss)
reduction:
It has the same meaning as the reduction in
`torch.nn.CrossEntropyLoss`. It can be one of the following three
values: (1) "none": No reduction will be applied. (2) "mean": the
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.ignore_index of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.ndim == 3
assert target.ndim == 2
assert x.shape[:2] == target.shape
num_classes = x.size(-1)
x = x.reshape(-1, num_classes)
# Now x is of shape (N*T, C)
# We don't want to change target in-place below,
# so we make a copy of it here
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
target[ignored] = 0
true_dist = torch.nn.functional.one_hot(
target, num_classes=num_classes
).to(x)
true_dist = (
true_dist * (1 - self.label_smoothing)
+ self.label_smoothing / num_classes
)
# Set the value of ignored indexes to 0
true_dist[ignored] = 0
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
return loss.sum()
elif self.reduction == "mean":
return loss.sum() / (~ignored).sum()
else:
return loss.sum(dim=-1)

View File

@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_attention_decoder,
)
from icefall.utils import AttributeDict, get_env_info, get_texts
from icefall.utils import AttributeDict, get_texts
def get_parser():
@ -190,7 +190,6 @@ def get_params() -> AttributeDict:
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params

View File

@ -38,12 +38,12 @@ from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
get_env_info,
setup_logger,
str2bool,
)

View File

@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
@ -83,8 +84,8 @@ class Transformer(nn.Module):
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape [N, T, num_classes]
# to the shape [N, T//subsampling_factor, d_model].
# self.encoder_embed converts the input of shape (N, T, num_classes)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_classes -> d_model
@ -152,7 +153,7 @@ class Transformer(nn.Module):
d_model, self.decoder_num_class
)
self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
self.decoder_criterion = LabelSmoothingLoss()
else:
self.decoder_criterion = None
@ -162,7 +163,7 @@ class Transformer(nn.Module):
"""
Args:
x:
The input tensor. Its shape is [N, T, C].
The input tensor. Its shape is (N, T, C).
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@ -171,17 +172,17 @@ class Transformer(nn.Module):
Returns:
Return a tuple containing 3 tensors:
- CTC output for ctc decoding. Its shape is [N, T, C]
- Encoder output with shape [T, N, C]. It can be used as key and
- CTC output for ctc decoding. Its shape is (N, T, C)
- Encoder output with shape (T, N, C). It can be used as key and
value for the decoder.
- Encoder output padding mask. It can be used as
memory_key_padding_mask for the decoder. Its shape is [N, T].
memory_key_padding_mask for the decoder. Its shape is (N, T).
It is None if `supervision` is None.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
@ -195,7 +196,7 @@ class Transformer(nn.Module):
Args:
x:
The model input. Its shape is [N, T, C].
The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
@ -206,8 +207,8 @@ class Transformer(nn.Module):
padding mask for the decoder.
Returns:
Return a tuple with two tensors:
- The encoder output, with shape [T, N, C]
- encoder padding mask, with shape [N, T].
- The encoder output, with shape (T, N, C)
- encoder padding mask, with shape (N, T).
The mask is None if `supervisions` is None.
It is used as memory key padding mask in the decoder.
"""
@ -225,17 +226,18 @@ class Transformer(nn.Module):
Args:
x:
The output tensor from the transformer encoder.
Its shape is [T, N, C]
Its shape is (T, N, C)
Returns:
Return a tensor that can be used for CTC decoding.
Its shape is [N, T, C]
Its shape is (N, T, C)
"""
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
return x
@torch.jit.export
def decoder_forward(
self,
memory: torch.Tensor,
@ -247,7 +249,7 @@ class Transformer(nn.Module):
"""
Args:
memory:
It's the output of the encoder with shape [T, N, C]
It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
@ -264,11 +266,15 @@ class Transformer(nn.Module):
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device)
@ -301,18 +307,19 @@ class Transformer(nn.Module):
return decoder_loss
@torch.jit.export
def decoder_nll(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
token_ids: List[List[int]],
token_ids: List[torch.Tensor],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
It's the output of the encoder with shape [T, N, C]
It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
@ -328,14 +335,23 @@ class Transformer(nn.Module):
"""
# The common part between this function and decoder_forward could be
# extracted as a separate function.
if isinstance(token_ids[0], torch.Tensor):
# This branch is executed by torchscript in C++.
# See https://github.com/k2-fsa/k2/pull/870
# https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
token_ids = [tolist(t) for t in token_ids]
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@ -649,25 +665,25 @@ class PositionalEncoding(nn.Module):
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
self.pe = None
# not doing: self.pe = None because of errors thrown by torchscript
self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
is [N, T, d_model]. If T > T1, then we change the shape of self.pe
to [N, T, d_model]. Otherwise, nothing is done.
The shape of `self.pe` is (1, T1, d_model). The shape of the input x
is (N, T, d_model). If T > T1, then we change the shape of self.pe
to (N, T, d_model). Otherwise, nothing is done.
Args:
x:
It is a tensor of shape [N, T, C].
It is a tensor of shape (N, T, C).
Returns:
Return None.
"""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
@ -678,7 +694,7 @@ class PositionalEncoding(nn.Module):
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
# Now pe is of shape [1, T, d_model], where T is x.size(1)
# Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -687,10 +703,10 @@ class PositionalEncoding(nn.Module):
Args:
x:
Its shape is [N, T, C]
Its shape is (N, T, C)
Returns:
Return a tensor of shape [N, T, C]
Return a tensor of shape (N, T, C)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]
@ -784,73 +800,6 @@ class Noam(object):
setattr(self, key, value)
class LabelSmoothingLoss(nn.Module):
"""
Label-smoothing loss. KL-divergence between
q_{smoothed ground truth prob.}(w)
and p_{prob. computed by model}(w) is minimized.
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
Args:
size: the number of class
padding_idx: padding_idx: ignored class id
smoothing: smoothing rate (0.0 means the conventional CE)
normalize_length: normalize loss by sequence length if True
criterion: loss function to be smoothed
"""
def __init__(
self,
size: int,
padding_idx: int = -1,
smoothing: float = 0.1,
normalize_length: bool = False,
criterion: nn.Module = nn.KLDivLoss(reduction="none"),
) -> None:
"""Construct an LabelSmoothingLoss object."""
super(LabelSmoothingLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
assert 0.0 < smoothing <= 1.0
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.true_dist = None
self.normalize_length = normalize_length
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.padding_id of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.size(2) == self.size
# batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
true_dist = x.clone()
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
# denom = total if self.normalize_length else batch_size
denom = total if self.normalize_length else 1
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
def encoder_padding_mask(
max_len: int, supervisions: Optional[Supervisions] = None
) -> Optional[torch.Tensor]:
@ -972,10 +921,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
Return a new list-of-list, where each sublist starts
with SOS ID.
"""
ans = []
for utt in token_ids:
ans.append([sos_id] + utt)
return ans
return [[sos_id] + utt for utt in token_ids]
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
@ -992,7 +938,9 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
Return a new list-of-list, where each sublist ends
with EOS ID.
"""
ans = []
for utt in token_ids:
ans.append(utt + [eos_id])
return ans
return [utt + [eos_id] for utt in token_ids]
def tolist(t: torch.Tensor) -> List[int]:
"""Used by jit"""
return torch.jit.annotate(List[int], t.tolist())

View File

@ -40,14 +40,13 @@ from icefall.decode import (
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_env_info,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@ -122,17 +121,6 @@ def get_parser():
""",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -671,13 +659,6 @@ def main():
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])

View File

@ -36,7 +36,7 @@ from icefall.decode import (
rescore_with_attention_decoder,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_env_info, get_texts
from icefall.utils import AttributeDict, get_texts
def get_parser():
@ -256,7 +256,6 @@ def main():
params.num_decoder_layers = 0
params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}")
device = torch.device("cpu")

View File

@ -41,12 +41,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
get_env_info,
setup_logger,
str2bool,
)

View File

@ -36,10 +36,10 @@ from icefall.decode import (
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_env_info,
get_texts,
setup_logger,
store_transcripts,

View File

@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_env_info, get_texts
from icefall.utils import AttributeDict, get_texts
def get_parser():
@ -159,7 +159,6 @@ def main():
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}")
device = torch.device("cpu")

View File

@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
get_env_info,
setup_logger,
str2bool,
)

View File

@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_env_info, get_texts
from icefall.utils import AttributeDict, get_texts
def get_parser():
@ -159,7 +159,6 @@ def main():
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}")
device = torch.device("cpu")

View File

@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
get_env_info,
setup_logger,
str2bool,
)

View File

@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_env_info, get_texts
from icefall.utils import AttributeDict, get_texts
def get_parser():
@ -159,7 +159,6 @@ def main():
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}")
device = torch.device("cpu")

View File

@ -40,13 +40,13 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
get_env_info,
setup_logger,
str2bool,
)

View File

@ -14,10 +14,10 @@ from model import Tdnn
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice, one_best_decoding
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_env_info,
get_texts,
setup_logger,
store_transcripts,

View File

@ -29,7 +29,7 @@ from model import Tdnn
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_env_info, get_texts
from icefall.utils import AttributeDict, get_texts
def get_parser():
@ -116,7 +116,6 @@ def main():
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
logging.info(f"{params}")
device = torch.device("cpu")

View File

@ -22,15 +22,10 @@ from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_env_info,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():

106
icefall/env.py Normal file
View File

@ -0,0 +1,106 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Wei Kang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict
import k2
import k2.version
import lhotse
import torch
def get_git_sha1():
git_commit = (
subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
dirty_commit = (
len(
subprocess.run(
["git", "diff", "--shortstat"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
> 0
)
git_commit = (
git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
)
return git_commit
def get_git_date():
git_date = (
subprocess.run(
["git", "log", "-1", "--format=%ad", "--date=local"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_date
def get_git_branch_name():
git_date = (
subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_date
def get_env_info() -> Dict[str, Any]:
"""Get the environment information."""
return {
"k2-version": k2.version.__version__,
"k2-build-type": k2.version.__build_type__,
"k2-with-cuda": k2.with_cuda,
"k2-git-sha1": k2.version.__git_sha1__,
"k2-git-date": k2.version.__git_date__,
"lhotse-version": lhotse.__version__,
"torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3],
"icefall-git-branch": get_git_branch_name(),
"icefall-git-sha1": get_git_sha1(),
"icefall-git-date": get_git_date(),
"icefall-path": str(Path(__file__).resolve().parent.parent),
"k2-path": str(Path(k2.__file__).resolve()),
"lhotse-path": str(Path(lhotse.__file__).resolve()),
}

View File

@ -21,17 +21,15 @@ import collections
import logging
import os
import subprocess
import sys
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2
import k2.version
import kaldialign
import lhotse
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
@ -137,85 +135,6 @@ def setup_logger(
logging.getLogger("").addHandler(console)
def get_git_sha1():
git_commit = (
subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
dirty_commit = (
len(
subprocess.run(
["git", "diff", "--shortstat"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
> 0
)
git_commit = (
git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
)
return git_commit
def get_git_date():
git_date = (
subprocess.run(
["git", "log", "-1", "--format=%ad", "--date=local"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_date
def get_git_branch_name():
git_date = (
subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
check=True,
stdout=subprocess.PIPE,
)
.stdout.decode()
.rstrip("\n")
.strip()
)
return git_date
def get_env_info() -> Dict[str, Any]:
"""Get the environment information."""
return {
"k2-version": k2.version.__version__,
"k2-build-type": k2.version.__build_type__,
"k2-with-cuda": k2.with_cuda,
"k2-git-sha1": k2.version.__git_sha1__,
"k2-git-date": k2.version.__git_date__,
"lhotse-version": lhotse.__version__,
"torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3],
"icefall-git-branch": get_git_branch_name(),
"icefall-git-sha1": get_git_sha1(),
"icefall-git-date": get_git_date(),
"icefall-path": str(Path(__file__).resolve().parent.parent),
"k2-path": str(Path(k2.__file__).resolve()),
"lhotse-path": str(Path(lhotse.__file__).resolve()),
}
class AttributeDict(dict):
def __getattr__(self, key):
if key in self: