mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Merge branch 'master' into sort_results
This commit is contained in:
commit
dc6499a052
@ -22,8 +22,80 @@ ls -lh $repo/test_wavs/*.wav
|
||||
|
||||
pushd $repo/exp
|
||||
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
|
||||
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
log "Test exporting to ONNX format"
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--onnx 1
|
||||
|
||||
log "Export to torchscript model"
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit 1
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir $repo/exp \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--jit-trace 1
|
||||
|
||||
ls -lh $repo/exp/*.onnx
|
||||
ls -lh $repo/exp/*.pt
|
||||
|
||||
log "Decode with ONNX models"
|
||||
|
||||
./pruned_transducer_stateless3/onnx_check.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
||||
--onnx-joiner-filename $repo/exp/joiner.onnx
|
||||
|
||||
./pruned_transducer_stateless3/onnx_check_all_in_one.py \
|
||||
--jit-filename $repo/exp/cpu_jit.pt \
|
||||
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
|
||||
|
||||
./pruned_transducer_stateless3/onnx_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder.onnx \
|
||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
log "Decode with models exported by torch.jit.trace()"
|
||||
|
||||
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
|
||||
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
|
||||
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
log "Decode with models exported by torch.jit.script()"
|
||||
|
||||
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
|
||||
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
|
||||
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
|
||||
$repo/test_wavs/1089-134686-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0001.wav \
|
||||
$repo/test_wavs/1221-135766-0002.wav
|
||||
|
||||
|
||||
for sym in 1 2 3; do
|
||||
log "Greedy search with --max-sym-per-frame $sym"
|
||||
|
||||
|
@ -35,7 +35,7 @@ on:
|
||||
|
||||
jobs:
|
||||
run_librispeech_pruned_transducer_stateless3_2022_05_13:
|
||||
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
|
@ -79,6 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
|
||||
cd -
|
||||
|
||||
# install lhotse
|
||||
RUN pip install torchaudio==0.7.2
|
||||
RUN pip install git+https://github.com/lhotse-speech/lhotse
|
||||
#RUN pip install lhotse
|
||||
|
||||
|
@ -447,6 +447,17 @@ def compute_loss(
|
||||
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = supervisions["num_frames"].sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -605,6 +605,15 @@ def compute_loss(
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -457,9 +457,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -674,13 +671,7 @@ def train_one_epoch(
|
||||
global_step=params.batch_idx_train,
|
||||
)
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -728,7 +719,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -738,7 +728,6 @@ def train_one_epoch(
|
||||
sampler=train_dl.sampler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
@ -893,13 +882,14 @@ def run(rank, world_size, args):
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
)
|
||||
if params.start_batch <= 0:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
)
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
fix_random_seed(params.seed + epoch)
|
||||
|
@ -155,7 +155,8 @@ class Conformer(EncoderInterface):
|
||||
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||
|
||||
assert x.size(0) == lengths.max().item()
|
||||
if not torch.jit.is_tracing():
|
||||
assert x.size(0) == lengths.max().item()
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
|
||||
@ -787,6 +788,14 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
) -> None:
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
if torch.jit.is_tracing():
|
||||
# 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e.,
|
||||
# It assumes that the maximum input won't have more than
|
||||
# 10k frames.
|
||||
#
|
||||
# TODO(fangjun): Use torch.jit.script() for this module
|
||||
max_len = 10000
|
||||
|
||||
self.d_model = d_model
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
@ -992,7 +1001,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
|
||||
time1 means the length of query vector.
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
@ -1006,20 +1015,32 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
|
||||
time2 = time1 + left_context
|
||||
assert (
|
||||
n == left_context + 2 * time1 - 1
|
||||
), f"{n} == {left_context} + 2 * {time1} - 1"
|
||||
if not torch.jit.is_tracing():
|
||||
assert (
|
||||
n == left_context + 2 * time1 - 1
|
||||
), f"{n} == {left_context} + 2 * {time1} - 1"
|
||||
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time2),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
if torch.jit.is_tracing():
|
||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||
cols = torch.arange(time2)
|
||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
|
||||
x = x.reshape(-1, n)
|
||||
x = torch.gather(x, dim=1, index=indexes)
|
||||
x = x.reshape(batch_size, num_heads, time1, time2)
|
||||
return x
|
||||
else:
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time2),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
|
||||
def multi_head_attention_forward(
|
||||
self,
|
||||
@ -1090,13 +1111,15 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
"""
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == embed_dim_to_check
|
||||
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||
if not torch.jit.is_tracing():
|
||||
assert embed_dim == embed_dim_to_check
|
||||
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||
|
||||
head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
head_dim * num_heads == embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
if not torch.jit.is_tracing():
|
||||
assert (
|
||||
head_dim * num_heads == embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
scaling = float(head_dim) ** -0.5
|
||||
|
||||
@ -1209,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
src_len = k.size(0)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
if key_padding_mask is not None and not torch.jit.is_tracing():
|
||||
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
||||
key_padding_mask.size(0), bsz
|
||||
)
|
||||
@ -1220,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
||||
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
if not torch.jit.is_tracing():
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
|
||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
||||
p = p.permute(0, 2, 3, 1)
|
||||
@ -1255,11 +1280,12 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
bsz * num_heads, tgt_len, -1
|
||||
)
|
||||
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
]
|
||||
if not torch.jit.is_tracing():
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
]
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
@ -1318,7 +1344,14 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
|
||||
if not torch.jit.is_tracing():
|
||||
assert list(attn_output.size()) == [
|
||||
bsz * num_heads,
|
||||
tgt_len,
|
||||
head_dim,
|
||||
]
|
||||
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
.contiguous()
|
||||
|
@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -77,7 +79,9 @@ class Decoder(nn.Module):
|
||||
# It is to support torch script
|
||||
self.conv = nn.Identity()
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
def forward(
|
||||
self, y: torch.Tensor, need_pad: Union[bool, torch.Tensor] = True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
@ -88,18 +92,24 @@ class Decoder(nn.Module):
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
if isinstance(need_pad, torch.Tensor):
|
||||
# This is for torch.jit.trace(), which cannot handle the case
|
||||
# when the input argument is not a tensor.
|
||||
need_pad = bool(need_pad)
|
||||
|
||||
y = y.to(torch.int64)
|
||||
embedding_out = self.embedding(y)
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
if need_pad:
|
||||
embedding_out = F.pad(
|
||||
embedding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
if not torch.jit.is_tracing():
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
|
@ -52,10 +52,10 @@ class Joiner(nn.Module):
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
|
||||
assert encoder_out.ndim == decoder_out.ndim
|
||||
assert encoder_out.ndim in (2, 4)
|
||||
assert encoder_out.shape == decoder_out.shape
|
||||
if not torch.jit.is_tracing():
|
||||
assert encoder_out.ndim == decoder_out.ndim
|
||||
assert encoder_out.ndim in (2, 4)
|
||||
assert encoder_out.shape == decoder_out.shape
|
||||
|
||||
if project_input:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
||||
|
@ -152,7 +152,8 @@ class BasicNorm(torch.nn.Module):
|
||||
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
assert x.shape[self.channel_dim] == self.num_channels
|
||||
if not torch.jit.is_tracing():
|
||||
assert x.shape[self.channel_dim] == self.num_channels
|
||||
scales = (
|
||||
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
|
||||
+ self.eps.exp()
|
||||
@ -423,7 +424,7 @@ class ActivationBalancer(torch.nn.Module):
|
||||
self.max_abs = max_abs
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
return ActivationBalancerFunction.apply(
|
||||
@ -472,7 +473,7 @@ class DoubleSwish(torch.nn.Module):
|
||||
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
||||
that we approximate closely with x * sigmoid(x-1).
|
||||
"""
|
||||
if torch.jit.is_scripting():
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x * torch.sigmoid(x - 1.0)
|
||||
else:
|
||||
return DoubleSwishFunction.apply(x)
|
||||
@ -494,9 +495,6 @@ class ScaledEmbedding(nn.Module):
|
||||
embedding_dim (int): the size of each embedding vector
|
||||
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
|
||||
(initialized to zeros) whenever it encounters the index.
|
||||
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
|
||||
is renormalized to have norm :attr:`max_norm`.
|
||||
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
|
||||
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
|
||||
the words in the mini-batch. Default ``False``.
|
||||
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
|
||||
@ -505,7 +503,7 @@ class ScaledEmbedding(nn.Module):
|
||||
initial_speed (float, optional): This affects how fast the parameter will
|
||||
learn near the start of training; you can set it to a value less than
|
||||
one if you suspect that a module is contributing to instability near
|
||||
the start of training. Nnote: regardless of the use of this option,
|
||||
the start of training. Note: regardless of the use of this option,
|
||||
it's best to use schedulers like Noam that have a warm-up period.
|
||||
Alternatively you can set it to more than 1 if you want it to
|
||||
initially train faster. Must be greater than 0.
|
||||
|
@ -503,9 +503,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -724,13 +721,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -765,7 +756,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -777,7 +767,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
@ -944,7 +933,7 @@ def run(rank, world_size, args):
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
if params.start_batch <= 0 and not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
|
@ -19,14 +19,67 @@
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
It will also generate 3 other files: `encoder_jit_script.pt`,
|
||||
`decoder_jit_script.pt`, and `joiner_jit_script.pt`.
|
||||
|
||||
(2) Export to torchscript model using torch.jit.trace()
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit-trace 1
|
||||
|
||||
It will generates 3 files: `encoder_jit_trace.pt`,
|
||||
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
|
||||
|
||||
|
||||
(3) Export to ONNX format
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--onnx 1
|
||||
|
||||
It will generate the following three files in the given `exp_dir`.
|
||||
Check `onnx_check.py` for how to use them.
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
|
||||
|
||||
(4) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file exp_dir/pretrained.pt
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless3/decode.py`,
|
||||
you can do:
|
||||
@ -42,14 +95,31 @@ you can do:
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
|
||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
@ -114,6 +184,42 @@ def get_parser():
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate 4 files:
|
||||
- encoder_jit_script.pt
|
||||
- decoder_jit_script.pt
|
||||
- joiner_jit_script.pt
|
||||
- cpu_jit.pt (which combines the above 3 files)
|
||||
|
||||
Check ./jit_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-trace",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.trace.
|
||||
It will generate 3 files:
|
||||
- encoder_jit_trace.pt
|
||||
- decoder_jit_trace.pt
|
||||
- joiner_jit_trace.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If True, --jit is ignored and it exports the model
|
||||
to onnx format. Three files will be generated:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
|
||||
Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -139,6 +245,299 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def export_encoder_model_jit_script(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.script()
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
script_model = torch.jit.script(encoder_model)
|
||||
script_model.save(encoder_filename)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_script(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.script()
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The input decoder model
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
script_model = torch.jit.script(decoder_model)
|
||||
script_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_script(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
Args:
|
||||
joiner_model:
|
||||
The input joiner model
|
||||
joiner_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
script_model = torch.jit.script(joiner_model)
|
||||
script_model.save(joiner_filename)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
def export_encoder_model_jit_trace(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.trace()
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
traced_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||
traced_model.save(encoder_filename)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_trace(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.trace()
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The input decoder model
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||
traced_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_trace(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
Note: The argument project_input is fixed to True. A user should not
|
||||
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||
will do that for the user.
|
||||
|
||||
Args:
|
||||
joiner_model:
|
||||
The input joiner model
|
||||
joiner_filename:
|
||||
The filename to save the exported model.
|
||||
|
||||
"""
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||
traced_model.save(joiner_filename)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
def export_encoder_model_onnx(
|
||||
encoder_model: nn.Module,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||
|
||||
and it has two outputs:
|
||||
|
||||
- encoder_out, a tensor of shape (N, T, C)
|
||||
- encoder_out_lens, a tensor of shape (N,)
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
# encoder_model = torch.jit.script(encoder_model)
|
||||
# It throws the following error for the above statement
|
||||
#
|
||||
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
||||
# 11 is not supported. Please feel free to request support or
|
||||
# submit a pull request on PyTorch GitHub.
|
||||
#
|
||||
# I cannot find which statement causes the above error.
|
||||
# torch.onnx.export() will use torch.jit.trace() internally, which
|
||||
# works well for the current reworked model
|
||||
warmup = 1.0
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
(x, x_lens, warmup),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_lens", "warmup"],
|
||||
output_names=["encoder_out", "encoder_out_lens"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
"encoder_out": {0: "N", 1: "T"},
|
||||
"encoder_out_lens": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX format.
|
||||
|
||||
The exported model has one input:
|
||||
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
|
||||
and has one output:
|
||||
|
||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = False # Always False, so we can use torch.jit.trace() here
|
||||
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
||||
# in this case
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
(y, need_pad),
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y", "need_pad"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
|
||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||
|
||||
and has one output:
|
||||
|
||||
- joiner_out: a tensor of shape (N, vocab_size)
|
||||
|
||||
Note: The argument project_input is fixed to True. A user should not
|
||||
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||
will do that for the user.
|
||||
"""
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
project_input = True
|
||||
# Note: It uses torch.jit.trace() internally
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(encoder_out, decoder_out, project_input),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out", "decoder_out", "project_input"],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
def export_all_in_one_onnx(
|
||||
encoder_filename: str,
|
||||
decoder_filename: str,
|
||||
joiner_filename: str,
|
||||
all_in_one_filename: str,
|
||||
):
|
||||
encoder_onnx = onnx.load(encoder_filename)
|
||||
decoder_onnx = onnx.load(decoder_filename)
|
||||
joiner_onnx = onnx.load(joiner_filename)
|
||||
|
||||
encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
|
||||
decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
|
||||
joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
|
||||
|
||||
combined_model = onnx.compose.merge_models(
|
||||
encoder_onnx, decoder_onnx, io_map={}
|
||||
)
|
||||
combined_model = onnx.compose.merge_models(
|
||||
combined_model, joiner_onnx, io_map={}
|
||||
)
|
||||
onnx.save(combined_model, all_in_one_filename)
|
||||
logging.info(f"Saved to {all_in_one_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
@ -165,7 +564,7 @@ def main():
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
model = get_transducer_model(params, enable_giga=False)
|
||||
|
||||
model.to(device)
|
||||
|
||||
@ -185,7 +584,9 @@ def main():
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
@ -196,14 +597,47 @@ def main():
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
model.eval()
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
if params.jit:
|
||||
if params.onnx is True:
|
||||
opset_version = 11
|
||||
logging.info("Exporting to onnx format")
|
||||
encoder_filename = params.exp_dir / "encoder.onnx"
|
||||
export_encoder_model_onnx(
|
||||
model.encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||
export_decoder_model_onnx(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||
export_joiner_model_onnx(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
all_in_one_filename = params.exp_dir / "all_in_one.onnx"
|
||||
export_all_in_one_onnx(
|
||||
encoder_filename,
|
||||
decoder_filename,
|
||||
joiner_filename,
|
||||
all_in_one_filename,
|
||||
)
|
||||
elif params.jit is True:
|
||||
logging.info("Using torch.jit.script()")
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
@ -214,8 +648,29 @@ def main():
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
# Also export encoder/decoder/joiner separately
|
||||
encoder_filename = params.exp_dir / "encoder_jit_script.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder_jit_script.pt"
|
||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner_jit_script.pt"
|
||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||
|
||||
elif params.jit_trace is True:
|
||||
logging.info("Using torch.jit.trace()")
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||
else:
|
||||
logging.info("Not using torch.jit.script")
|
||||
logging.info("Not using torchscript")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
|
338
egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
Executable file
338
egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
Executable file
@ -0,0 +1,338 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, either exported by `torch.jit.trace()`
|
||||
or by `torch.jit.script()`, and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit-trace 1
|
||||
|
||||
or
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_trace.pt \
|
||||
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_trace.pt \
|
||||
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_trace.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
or
|
||||
|
||||
./pruned_transducer_stateless3/jit_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_script.pt \
|
||||
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_script.pt \
|
||||
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_script.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Context size of the decoder model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: torch.jit.ScriptModule,
|
||||
joiner: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
context_size: int,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
decoder:
|
||||
The decoder model.
|
||||
joiner:
|
||||
The joiner model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = encoder_out.device
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
).squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
)
|
||||
decoder_out = decoder_out.squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
encoder = torch.jit.load(args.encoder_model_filename)
|
||||
decoder = torch.jit.load(args.decoder_model_filename)
|
||||
joiner = torch.jit.load(args.joiner_model_filename)
|
||||
|
||||
encoder.eval()
|
||||
decoder.eval()
|
||||
joiner.eval()
|
||||
|
||||
encoder.to(device)
|
||||
decoder.to(device)
|
||||
joiner.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context_size=args.context_size,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = sp.decode(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
199
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
Executable file
199
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
Executable file
@ -0,0 +1,199 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script checks that exported onnx models produce the same output
|
||||
with the given torchscript model for the same input.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-encoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-decoder-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-joiner-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx joiner model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_encoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_session: ort.InferenceSession,
|
||||
):
|
||||
encoder_inputs = encoder_session.get_inputs()
|
||||
assert encoder_inputs[0].name == "x"
|
||||
assert encoder_inputs[1].name == "x_lens"
|
||||
assert encoder_inputs[0].shape == ["N", "T", 80]
|
||||
assert encoder_inputs[1].shape == ["N"]
|
||||
|
||||
for N in [1, 5]:
|
||||
for T in [12, 25]:
|
||||
print("N, T", N, T)
|
||||
x = torch.rand(N, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
|
||||
x_lens[0] = T
|
||||
|
||||
encoder_inputs = {
|
||||
"x": x.numpy(),
|
||||
"x_lens": x_lens.numpy(),
|
||||
}
|
||||
encoder_out, encoder_out_lens = encoder_session.run(
|
||||
["encoder_out", "encoder_out_lens"],
|
||||
encoder_inputs,
|
||||
)
|
||||
|
||||
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
|
||||
|
||||
encoder_out = torch.from_numpy(encoder_out)
|
||||
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
|
||||
(encoder_out - torch_encoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_decoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
decoder_session: ort.InferenceSession,
|
||||
):
|
||||
decoder_inputs = decoder_session.get_inputs()
|
||||
assert decoder_inputs[0].name == "y"
|
||||
assert decoder_inputs[0].shape == ["N", 2]
|
||||
for N in [1, 5, 10]:
|
||||
y = torch.randint(low=1, high=500, size=(10, 2))
|
||||
|
||||
decoder_inputs = {"y": y.numpy()}
|
||||
decoder_out = decoder_session.run(
|
||||
["decoder_out"],
|
||||
decoder_inputs,
|
||||
)[0]
|
||||
decoder_out = torch.from_numpy(decoder_out)
|
||||
|
||||
torch_decoder_out = model.decoder(y, need_pad=False)
|
||||
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
|
||||
(decoder_out - torch_decoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_joiner(
|
||||
model: torch.jit.ScriptModule,
|
||||
joiner_session: ort.InferenceSession,
|
||||
):
|
||||
joiner_inputs = joiner_session.get_inputs()
|
||||
assert joiner_inputs[0].name == "encoder_out"
|
||||
assert joiner_inputs[0].shape == ["N", 512]
|
||||
|
||||
assert joiner_inputs[1].name == "decoder_out"
|
||||
assert joiner_inputs[1].shape == ["N", 512]
|
||||
|
||||
for N in [1, 5, 10]:
|
||||
encoder_out = torch.rand(N, 512)
|
||||
decoder_out = torch.rand(N, 512)
|
||||
|
||||
joiner_inputs = {
|
||||
"encoder_out": encoder_out.numpy(),
|
||||
"decoder_out": decoder_out.numpy(),
|
||||
}
|
||||
joiner_out = joiner_session.run(["logit"], joiner_inputs)[0]
|
||||
joiner_out = torch.from_numpy(joiner_out)
|
||||
|
||||
torch_joiner_out = model.joiner(
|
||||
encoder_out,
|
||||
decoder_out,
|
||||
project_input=True,
|
||||
)
|
||||
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
|
||||
(joiner_out - torch_joiner_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = torch.jit.load(args.jit_filename)
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
logging.info("Test encoder")
|
||||
encoder_session = ort.InferenceSession(
|
||||
args.onnx_encoder_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_encoder(model, encoder_session)
|
||||
|
||||
logging.info("Test decoder")
|
||||
decoder_session = ort.InferenceSession(
|
||||
args.onnx_decoder_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_decoder(model, decoder_session)
|
||||
|
||||
logging.info("Test joiner")
|
||||
joiner_session = ort.InferenceSession(
|
||||
args.onnx_joiner_filename,
|
||||
sess_options=options,
|
||||
)
|
||||
test_joiner(model, joiner_session)
|
||||
logging.info("Finished checking ONNX models")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220727)
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
284
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py
Executable file
284
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py
Executable file
@ -0,0 +1,284 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Yunus Emre Ozkose)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script checks that exported onnx models produce the same output
|
||||
with the given torchscript model for the same input.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import onnx
|
||||
import onnx_graphsurgeon as gs
|
||||
import onnxruntime
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the torchscript model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx-all-in-one-filename",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to the onnx all in one model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def test_encoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_session: ort.InferenceSession,
|
||||
):
|
||||
encoder_inputs = encoder_session.get_inputs()
|
||||
assert encoder_inputs[0].shape == ["N", "T", 80]
|
||||
assert encoder_inputs[1].shape == ["N"]
|
||||
encoder_input_names = [i.name for i in encoder_inputs]
|
||||
encoder_output_names = [i.name for i in encoder_session.get_outputs()]
|
||||
|
||||
for N in [1, 5]:
|
||||
for T in [12, 25]:
|
||||
print("N, T", N, T)
|
||||
x = torch.rand(N, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
|
||||
x_lens[0] = T
|
||||
|
||||
encoder_inputs = {
|
||||
encoder_input_names[0]: x.numpy(),
|
||||
encoder_input_names[1]: x_lens.numpy(),
|
||||
}
|
||||
encoder_out, encoder_out_lens = encoder_session.run(
|
||||
[encoder_output_names[1], encoder_output_names[0]],
|
||||
encoder_inputs,
|
||||
)
|
||||
|
||||
torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)
|
||||
|
||||
encoder_out = torch.from_numpy(encoder_out)
|
||||
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
|
||||
(encoder_out - torch_encoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_decoder(
|
||||
model: torch.jit.ScriptModule,
|
||||
decoder_session: ort.InferenceSession,
|
||||
):
|
||||
decoder_inputs = decoder_session.get_inputs()
|
||||
assert decoder_inputs[0].shape == ["N", 2]
|
||||
decoder_input_names = [i.name for i in decoder_inputs]
|
||||
decoder_output_names = [i.name for i in decoder_session.get_outputs()]
|
||||
|
||||
for N in [1, 5, 10]:
|
||||
y = torch.randint(low=1, high=500, size=(10, 2))
|
||||
|
||||
decoder_inputs = {decoder_input_names[0]: y.numpy()}
|
||||
decoder_out = decoder_session.run(
|
||||
[decoder_output_names[0]],
|
||||
decoder_inputs,
|
||||
)[0]
|
||||
decoder_out = torch.from_numpy(decoder_out)
|
||||
|
||||
torch_decoder_out = model.decoder(y, need_pad=False)
|
||||
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
|
||||
(decoder_out - torch_decoder_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def test_joiner(
|
||||
model: torch.jit.ScriptModule,
|
||||
joiner_session: ort.InferenceSession,
|
||||
):
|
||||
joiner_inputs = joiner_session.get_inputs()
|
||||
assert joiner_inputs[0].shape == ["N", 512]
|
||||
assert joiner_inputs[1].shape == ["N", 512]
|
||||
joiner_input_names = [i.name for i in joiner_inputs]
|
||||
joiner_output_names = [i.name for i in joiner_session.get_outputs()]
|
||||
|
||||
for N in [1, 5, 10]:
|
||||
encoder_out = torch.rand(N, 512)
|
||||
decoder_out = torch.rand(N, 512)
|
||||
|
||||
joiner_inputs = {
|
||||
joiner_input_names[0]: encoder_out.numpy(),
|
||||
joiner_input_names[1]: decoder_out.numpy(),
|
||||
}
|
||||
joiner_out = joiner_session.run(
|
||||
[joiner_output_names[0]], joiner_inputs
|
||||
)[0]
|
||||
joiner_out = torch.from_numpy(joiner_out)
|
||||
|
||||
torch_joiner_out = model.joiner(
|
||||
encoder_out,
|
||||
decoder_out,
|
||||
project_input=True,
|
||||
)
|
||||
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
|
||||
(joiner_out - torch_joiner_out).abs().max()
|
||||
)
|
||||
|
||||
|
||||
def extract_sub_model(
|
||||
onnx_graph: onnx.ModelProto,
|
||||
input_op_names: list,
|
||||
output_op_names: list,
|
||||
non_verbose=False,
|
||||
):
|
||||
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
|
||||
graph = gs.import_onnx(onnx_graph)
|
||||
graph.cleanup().toposort()
|
||||
|
||||
# Extraction of input OP and output OP
|
||||
graph_node_inputs = [
|
||||
graph_nodes
|
||||
for graph_nodes in graph.nodes
|
||||
for graph_nodes_input in graph_nodes.inputs
|
||||
if graph_nodes_input.name in input_op_names
|
||||
]
|
||||
graph_node_outputs = [
|
||||
graph_nodes
|
||||
for graph_nodes in graph.nodes
|
||||
for graph_nodes_output in graph_nodes.outputs
|
||||
if graph_nodes_output.name in output_op_names
|
||||
]
|
||||
|
||||
# Init graph INPUT/OUTPUT
|
||||
graph.inputs.clear()
|
||||
graph.outputs.clear()
|
||||
|
||||
# Update graph INPUT/OUTPUT
|
||||
graph.inputs = [
|
||||
graph_node_input
|
||||
for graph_node in graph_node_inputs
|
||||
for graph_node_input in graph_node.inputs
|
||||
if graph_node_input.shape
|
||||
]
|
||||
graph.outputs = [
|
||||
graph_node_output
|
||||
for graph_node in graph_node_outputs
|
||||
for graph_node_output in graph_node.outputs
|
||||
]
|
||||
|
||||
# Cleanup
|
||||
graph.cleanup().toposort()
|
||||
|
||||
# Shape Estimation
|
||||
extracted_graph = None
|
||||
try:
|
||||
extracted_graph = onnx.shape_inference.infer_shapes(
|
||||
gs.export_onnx(graph)
|
||||
)
|
||||
except Exception:
|
||||
extracted_graph = gs.export_onnx(graph)
|
||||
if not non_verbose:
|
||||
print(
|
||||
"WARNING: "
|
||||
+ "The input shape of the next OP does not match the output shape. "
|
||||
+ "Be sure to open the .onnx file to verify the certainty of the geometry."
|
||||
)
|
||||
return extracted_graph
|
||||
|
||||
|
||||
def extract_encoder(onnx_model: onnx.ModelProto):
|
||||
encoder_ = extract_sub_model(
|
||||
onnx_model,
|
||||
["encoder/x", "encoder/x_lens"],
|
||||
["encoder/encoder_out", "encoder/encoder_out_lens"],
|
||||
False,
|
||||
)
|
||||
onnx.save(encoder_, "tmp_encoder.onnx")
|
||||
onnx.checker.check_model(encoder_)
|
||||
sess = onnxruntime.InferenceSession("tmp_encoder.onnx")
|
||||
os.remove("tmp_encoder.onnx")
|
||||
return sess
|
||||
|
||||
|
||||
def extract_decoder(onnx_model: onnx.ModelProto):
|
||||
decoder_ = extract_sub_model(
|
||||
onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
|
||||
)
|
||||
onnx.save(decoder_, "tmp_decoder.onnx")
|
||||
onnx.checker.check_model(decoder_)
|
||||
sess = onnxruntime.InferenceSession("tmp_decoder.onnx")
|
||||
os.remove("tmp_decoder.onnx")
|
||||
return sess
|
||||
|
||||
|
||||
def extract_joiner(onnx_model: onnx.ModelProto):
|
||||
joiner_ = extract_sub_model(
|
||||
onnx_model,
|
||||
["joiner/encoder_out", "joiner/decoder_out"],
|
||||
["joiner/logit"],
|
||||
False,
|
||||
)
|
||||
onnx.save(joiner_, "tmp_joiner.onnx")
|
||||
onnx.checker.check_model(joiner_)
|
||||
sess = onnxruntime.InferenceSession("tmp_joiner.onnx")
|
||||
os.remove("tmp_joiner.onnx")
|
||||
return sess
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
model = torch.jit.load(args.jit_filename)
|
||||
onnx_model = onnx.load(args.onnx_all_in_one_filename)
|
||||
|
||||
options = ort.SessionOptions()
|
||||
options.inter_op_num_threads = 1
|
||||
options.intra_op_num_threads = 1
|
||||
|
||||
logging.info("Test encoder")
|
||||
encoder_session = extract_encoder(onnx_model)
|
||||
test_encoder(model, encoder_session)
|
||||
|
||||
logging.info("Test decoder")
|
||||
decoder_session = extract_decoder(onnx_model)
|
||||
test_decoder(model, decoder_session)
|
||||
|
||||
logging.info("Test joiner")
|
||||
joiner_session = extract_joiner(onnx_model)
|
||||
test_joiner(model, joiner_session)
|
||||
logging.info("Finished checking ONNX models")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220727)
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
337
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
Executable file
337
egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
Executable file
@ -0,0 +1,337 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads ONNX models and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--onnx 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless3/jit_trace_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
|
||||
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
|
||||
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Context size of the decoder model",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: ort.InferenceSession,
|
||||
joiner: ort.InferenceSession,
|
||||
encoder_out: np.ndarray,
|
||||
encoder_out_lens: np.ndarray,
|
||||
context_size: int,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
decoder:
|
||||
The decoder model.
|
||||
joiner:
|
||||
The joiner model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
context_size:
|
||||
The context size of the decoder model.
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
encoder_out = torch.from_numpy(encoder_out)
|
||||
encoder_out_lens = torch.from_numpy(encoder_out_lens)
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input_nodes = decoder.get_inputs()
|
||||
decoder_output_nodes = decoder.get_outputs()
|
||||
|
||||
joiner_input_nodes = joiner.get_inputs()
|
||||
joiner_output_nodes = joiner.get_outputs()
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = decoder.run(
|
||||
[decoder_output_nodes[0].name],
|
||||
{
|
||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||
},
|
||||
)[0].squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = joiner.run(
|
||||
[joiner_output_nodes[0].name],
|
||||
{
|
||||
joiner_input_nodes[0].name: current_encoder_out.numpy(),
|
||||
joiner_input_nodes[1].name: decoder_out,
|
||||
},
|
||||
)[0]
|
||||
logits = torch.from_numpy(logits)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = decoder.run(
|
||||
[decoder_output_nodes[0].name],
|
||||
{
|
||||
decoder_input_nodes[0].name: decoder_input.numpy(),
|
||||
},
|
||||
)[0].squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
encoder = ort.InferenceSession(
|
||||
args.encoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
decoder = ort.InferenceSession(
|
||||
args.decoder_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
joiner = ort.InferenceSession(
|
||||
args.joiner_model_filename,
|
||||
sess_options=session_opts,
|
||||
)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = args.sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
|
||||
|
||||
encoder_input_nodes = encoder.get_inputs()
|
||||
encoder_out_nodes = encoder.get_outputs()
|
||||
encoder_out, encoder_out_lens = encoder.run(
|
||||
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
|
||||
{
|
||||
encoder_input_nodes[0].name: features.numpy(),
|
||||
encoder_input_nodes[1].name: feature_lengths.numpy(),
|
||||
},
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
context_size=args.context_size,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = sp.decode(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -15,7 +15,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless3/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless3/pretrained.py \
|
||||
|
@ -0,0 +1,207 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
|
||||
`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts:
|
||||
`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.
|
||||
|
||||
The scaled version are required only in the training time. It simplifies our
|
||||
life by converting them to their non-scaled version during inference.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
|
||||
|
||||
|
||||
def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
|
||||
"""Convert an instance of ScaledLinear to nn.Linear.
|
||||
|
||||
Args:
|
||||
scaled_linear:
|
||||
The layer to be converted.
|
||||
Returns:
|
||||
Return a linear layer. It satisfies:
|
||||
|
||||
scaled_linear(x) == linear(x)
|
||||
|
||||
for any given input tensor `x`.
|
||||
"""
|
||||
assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)
|
||||
|
||||
weight = scaled_linear.get_weight()
|
||||
bias = scaled_linear.get_bias()
|
||||
has_bias = bias is not None
|
||||
|
||||
linear = torch.nn.Linear(
|
||||
in_features=scaled_linear.in_features,
|
||||
out_features=scaled_linear.out_features,
|
||||
bias=True, # otherwise, it throws errors when converting to PNNX format.
|
||||
device=weight.device,
|
||||
)
|
||||
linear.weight.data.copy_(weight)
|
||||
|
||||
if has_bias:
|
||||
linear.bias.data.copy_(bias)
|
||||
else:
|
||||
linear.bias.data.zero_()
|
||||
|
||||
return linear
|
||||
|
||||
|
||||
def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d:
|
||||
"""Convert an instance of ScaledConv1d to nn.Conv1d.
|
||||
|
||||
Args:
|
||||
scaled_conv1d:
|
||||
The layer to be converted.
|
||||
Returns:
|
||||
Return an instance of nn.Conv1d that has the same `forward()` behavior
|
||||
of the given `scaled_conv1d`.
|
||||
"""
|
||||
assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d)
|
||||
|
||||
weight = scaled_conv1d.get_weight()
|
||||
bias = scaled_conv1d.get_bias()
|
||||
has_bias = bias is not None
|
||||
|
||||
conv1d = nn.Conv1d(
|
||||
in_channels=scaled_conv1d.in_channels,
|
||||
out_channels=scaled_conv1d.out_channels,
|
||||
kernel_size=scaled_conv1d.kernel_size,
|
||||
stride=scaled_conv1d.stride,
|
||||
padding=scaled_conv1d.padding,
|
||||
dilation=scaled_conv1d.dilation,
|
||||
groups=scaled_conv1d.groups,
|
||||
bias=scaled_conv1d.bias is not None,
|
||||
padding_mode=scaled_conv1d.padding_mode,
|
||||
)
|
||||
|
||||
conv1d.weight.data.copy_(weight)
|
||||
if has_bias:
|
||||
conv1d.bias.data.copy_(bias)
|
||||
|
||||
return conv1d
|
||||
|
||||
|
||||
def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
|
||||
"""Convert an instance of ScaledConv2d to nn.Conv2d.
|
||||
|
||||
Args:
|
||||
scaled_conv2d:
|
||||
The layer to be converted.
|
||||
Returns:
|
||||
Return an instance of nn.Conv2d that has the same `forward()` behavior
|
||||
of the given `scaled_conv2d`.
|
||||
"""
|
||||
assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d)
|
||||
|
||||
weight = scaled_conv2d.get_weight()
|
||||
bias = scaled_conv2d.get_bias()
|
||||
has_bias = bias is not None
|
||||
|
||||
conv2d = nn.Conv2d(
|
||||
in_channels=scaled_conv2d.in_channels,
|
||||
out_channels=scaled_conv2d.out_channels,
|
||||
kernel_size=scaled_conv2d.kernel_size,
|
||||
stride=scaled_conv2d.stride,
|
||||
padding=scaled_conv2d.padding,
|
||||
dilation=scaled_conv2d.dilation,
|
||||
groups=scaled_conv2d.groups,
|
||||
bias=scaled_conv2d.bias is not None,
|
||||
padding_mode=scaled_conv2d.padding_mode,
|
||||
)
|
||||
|
||||
conv2d.weight.data.copy_(weight)
|
||||
if has_bias:
|
||||
conv2d.bias.data.copy_(bias)
|
||||
|
||||
return conv2d
|
||||
|
||||
|
||||
def scaled_embedding_to_embedding(
|
||||
scaled_embedding: ScaledEmbedding,
|
||||
) -> nn.Embedding:
|
||||
"""Convert an instance of ScaledEmbedding to nn.Embedding.
|
||||
|
||||
Args:
|
||||
scaled_embedding:
|
||||
The layer to be converted.
|
||||
Returns:
|
||||
Return an instance of nn.Embedding that has the same `forward()` behavior
|
||||
of the given `scaled_embedding`.
|
||||
"""
|
||||
assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)
|
||||
embedding = nn.Embedding(
|
||||
num_embeddings=scaled_embedding.num_embeddings,
|
||||
embedding_dim=scaled_embedding.embedding_dim,
|
||||
padding_idx=scaled_embedding.padding_idx,
|
||||
scale_grad_by_freq=scaled_embedding.scale_grad_by_freq,
|
||||
sparse=scaled_embedding.sparse,
|
||||
)
|
||||
weight = scaled_embedding.weight
|
||||
scale = scaled_embedding.scale
|
||||
|
||||
embedding.weight.data.copy_(weight * scale.exp())
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
|
||||
"""Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
|
||||
in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
|
||||
and `nn.Conv2d`.
|
||||
|
||||
Args:
|
||||
model:
|
||||
The model to be converted.
|
||||
inplace:
|
||||
If True, the input model is modified inplace.
|
||||
If False, the input model is copied and we modify the copied version.
|
||||
Return:
|
||||
Return a model without scaled layers.
|
||||
"""
|
||||
if not inplace:
|
||||
model = copy.deepcopy(model)
|
||||
|
||||
excluded_patterns = r"self_attn\.(in|out)_proj"
|
||||
p = re.compile(excluded_patterns)
|
||||
|
||||
d = {}
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, ScaledLinear):
|
||||
if p.search(name) is not None:
|
||||
continue
|
||||
d[name] = scaled_linear_to_linear(m)
|
||||
elif isinstance(m, ScaledConv1d):
|
||||
d[name] = scaled_conv1d_to_conv1d(m)
|
||||
elif isinstance(m, ScaledConv2d):
|
||||
d[name] = scaled_conv2d_to_conv2d(m)
|
||||
elif isinstance(m, ScaledEmbedding):
|
||||
d[name] = scaled_embedding_to_embedding(m)
|
||||
|
||||
for k, v in d.items():
|
||||
if "." in k:
|
||||
parent, child = k.rsplit(".", maxsplit=1)
|
||||
setattr(model.get_submodule(parent), child, v)
|
||||
else:
|
||||
setattr(model, k, v)
|
||||
|
||||
return model
|
@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless3/test_scaling_converter.py
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
|
||||
from scaling_converter import (
|
||||
convert_scaled_to_non_scaled,
|
||||
scaled_conv1d_to_conv1d,
|
||||
scaled_conv2d_to_conv2d,
|
||||
scaled_embedding_to_embedding,
|
||||
scaled_linear_to_linear,
|
||||
)
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.unk_id = 2
|
||||
|
||||
params.dynamic_chunk_training = False
|
||||
params.short_chunk_size = 25
|
||||
params.num_left_chunks = 4
|
||||
params.causal_convolution = False
|
||||
|
||||
model = get_transducer_model(params, enable_giga=False)
|
||||
return model
|
||||
|
||||
|
||||
def test_scaled_linear_to_linear():
|
||||
N = 5
|
||||
in_features = 10
|
||||
out_features = 20
|
||||
for bias in [True, False]:
|
||||
scaled_linear = ScaledLinear(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
)
|
||||
linear = scaled_linear_to_linear(scaled_linear)
|
||||
x = torch.rand(N, in_features)
|
||||
|
||||
y1 = scaled_linear(x)
|
||||
y2 = linear(x)
|
||||
assert torch.allclose(y1, y2)
|
||||
|
||||
jit_scaled_linear = torch.jit.script(scaled_linear)
|
||||
jit_linear = torch.jit.script(linear)
|
||||
|
||||
y3 = jit_scaled_linear(x)
|
||||
y4 = jit_linear(x)
|
||||
|
||||
assert torch.allclose(y3, y4)
|
||||
assert torch.allclose(y1, y4)
|
||||
|
||||
|
||||
def test_scaled_conv1d_to_conv1d():
|
||||
in_channels = 3
|
||||
for bias in [True, False]:
|
||||
scaled_conv1d = ScaledConv1d(
|
||||
in_channels,
|
||||
6,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
conv1d = scaled_conv1d_to_conv1d(scaled_conv1d)
|
||||
|
||||
x = torch.rand(20, in_channels, 10)
|
||||
y1 = scaled_conv1d(x)
|
||||
y2 = conv1d(x)
|
||||
assert torch.allclose(y1, y2)
|
||||
|
||||
jit_scaled_conv1d = torch.jit.script(scaled_conv1d)
|
||||
jit_conv1d = torch.jit.script(conv1d)
|
||||
|
||||
y3 = jit_scaled_conv1d(x)
|
||||
y4 = jit_conv1d(x)
|
||||
|
||||
assert torch.allclose(y3, y4)
|
||||
assert torch.allclose(y1, y4)
|
||||
|
||||
|
||||
def test_scaled_conv2d_to_conv2d():
|
||||
in_channels = 1
|
||||
for bias in [True, False]:
|
||||
scaled_conv2d = ScaledConv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=3,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
conv2d = scaled_conv2d_to_conv2d(scaled_conv2d)
|
||||
|
||||
x = torch.rand(20, in_channels, 10, 20)
|
||||
y1 = scaled_conv2d(x)
|
||||
y2 = conv2d(x)
|
||||
assert torch.allclose(y1, y2)
|
||||
|
||||
jit_scaled_conv2d = torch.jit.script(scaled_conv2d)
|
||||
jit_conv2d = torch.jit.script(conv2d)
|
||||
|
||||
y3 = jit_scaled_conv2d(x)
|
||||
y4 = jit_conv2d(x)
|
||||
|
||||
assert torch.allclose(y3, y4)
|
||||
assert torch.allclose(y1, y4)
|
||||
|
||||
|
||||
def test_scaled_embedding_to_embedding():
|
||||
scaled_embedding = ScaledEmbedding(
|
||||
num_embeddings=500,
|
||||
embedding_dim=10,
|
||||
padding_idx=0,
|
||||
)
|
||||
embedding = scaled_embedding_to_embedding(scaled_embedding)
|
||||
|
||||
for s in [10, 100, 300, 500, 800, 1000]:
|
||||
x = torch.randint(low=0, high=500, size=(s,))
|
||||
scaled_y = scaled_embedding(x)
|
||||
y = embedding(x)
|
||||
assert torch.equal(scaled_y, y)
|
||||
|
||||
|
||||
def test_convert_scaled_to_non_scaled():
|
||||
for inplace in [False, True]:
|
||||
model = get_model()
|
||||
model.eval()
|
||||
|
||||
orig_model = copy.deepcopy(model)
|
||||
|
||||
converted_model = convert_scaled_to_non_scaled(model, inplace=inplace)
|
||||
|
||||
model = orig_model
|
||||
|
||||
# test encoder
|
||||
N = 2
|
||||
T = 100
|
||||
vocab_size = model.decoder.vocab_size
|
||||
|
||||
x = torch.randn(N, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((N,), x.size(1))
|
||||
|
||||
e1, e1_lens = model.encoder(x, x_lens)
|
||||
e2, e2_lens = converted_model.encoder(x, x_lens)
|
||||
|
||||
assert torch.all(torch.eq(e1_lens, e2_lens))
|
||||
assert torch.allclose(e1, e2), (e1 - e2).abs().max()
|
||||
|
||||
# test decoder
|
||||
U = 50
|
||||
y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
|
||||
|
||||
d1 = model.decoder(y)
|
||||
d2 = model.decoder(y)
|
||||
|
||||
assert torch.allclose(d1, d2)
|
||||
|
||||
# test simple projection
|
||||
lm1 = model.simple_lm_proj(d1)
|
||||
am1 = model.simple_am_proj(e1)
|
||||
|
||||
lm2 = converted_model.simple_lm_proj(d2)
|
||||
am2 = converted_model.simple_am_proj(e2)
|
||||
|
||||
assert torch.allclose(lm1, lm2)
|
||||
assert torch.allclose(am1, am2)
|
||||
|
||||
# test joiner
|
||||
e = torch.rand(2, 3, 4, 512)
|
||||
d = torch.rand(2, 3, 4, 512)
|
||||
|
||||
j1 = model.joiner(e, d)
|
||||
j2 = converted_model.joiner(e, d)
|
||||
assert torch.allclose(j1, j2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
test_scaled_linear_to_linear()
|
||||
test_scaled_conv1d_to_conv1d()
|
||||
test_scaled_conv2d_to_conv2d()
|
||||
test_scaled_embedding_to_embedding()
|
||||
test_convert_scaled_to_non_scaled()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20220730)
|
||||
main()
|
@ -436,13 +436,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
def get_transducer_model(
|
||||
params: AttributeDict,
|
||||
enable_giga: bool = True,
|
||||
) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
decoder_giga = get_decoder_model(params)
|
||||
joiner_giga = get_joiner_model(params)
|
||||
if enable_giga:
|
||||
logging.info("Use giga")
|
||||
decoder_giga = get_decoder_model(params)
|
||||
joiner_giga = get_joiner_model(params)
|
||||
else:
|
||||
logging.info("Disable giga")
|
||||
decoder_giga = None
|
||||
joiner_giga = None
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
@ -1049,14 +1058,15 @@ def run(rank, world_size, args):
|
||||
# It's time consuming to include `giga_train_dl` here
|
||||
# for dl in [train_dl, giga_train_dl]:
|
||||
for dl in [train_dl]:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
||||
)
|
||||
if params.start_batch <= 0:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
|
@ -525,9 +525,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -757,13 +754,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -805,7 +796,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -818,7 +808,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
@ -993,7 +982,7 @@ def run(rank, world_size, args):
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
if params.start_batch <= 0 and not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
|
@ -550,9 +550,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -782,13 +779,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -834,7 +825,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -847,7 +837,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
@ -1025,7 +1014,7 @@ def run(rank, world_size, args):
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
if params.start_batch <= 0 and not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
|
@ -507,9 +507,6 @@ def load_checkpoint_if_available(
|
||||
if "cur_epoch" in saved_params:
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
if "cur_batch_idx" in saved_params:
|
||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -763,13 +760,7 @@ def train_one_epoch(
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -811,7 +802,6 @@ def train_one_epoch(
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
@ -824,7 +814,6 @@ def train_one_epoch(
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
@ -999,7 +988,7 @@ def run(rank, world_size, args):
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
if params.start_batch <= 0 and not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
|
@ -430,6 +430,17 @@ def compute_loss(
|
||||
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = supervisions["num_frames"].sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -349,6 +349,17 @@ def compute_loss(
|
||||
info["frames"] = supervision_segments[:, 2].sum().item()
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = supervisions["num_frames"].sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -403,6 +403,15 @@ def compute_loss(
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -407,6 +407,15 @@ def compute_loss(
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -429,6 +429,15 @@ def compute_loss(
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -417,6 +417,15 @@ def compute_loss(
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -476,6 +476,15 @@ def compute_loss(
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
|
||||
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
|
||||
info["utterances"] = feature.size(0)
|
||||
# averaged input duration in frames over utterances
|
||||
info["utt_duration"] = feature_lens.sum().item()
|
||||
# averaged padding proportion over utterances
|
||||
info["utt_pad_proportion"] = (
|
||||
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
|
||||
)
|
||||
|
||||
return loss, info
|
||||
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
|
||||
# 2022 Xiaomi Corp. (authors: Weiji Zhuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -29,10 +30,18 @@ with word segmenting:
|
||||
|
||||
|
||||
import argparse
|
||||
from multiprocessing import Pool
|
||||
|
||||
import jieba
|
||||
import paddle
|
||||
from tqdm import tqdm
|
||||
|
||||
# In PaddlePaddle 2.x, dynamic graph mode is turned on by default,
|
||||
# and 'data()' is only supported in static graph mode. So if you
|
||||
# want to use this api, should call 'paddle.enable_static()' before
|
||||
# this api to enter static graph mode.
|
||||
paddle.enable_static()
|
||||
paddle.disable_signal_handler()
|
||||
jieba.enable_paddle()
|
||||
|
||||
|
||||
@ -41,14 +50,23 @@ def get_parser():
|
||||
description="Chinese Word Segmentation for text",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-process",
|
||||
"-n",
|
||||
default=20,
|
||||
type=int,
|
||||
help="the number of processes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-file",
|
||||
"-i",
|
||||
default="data/lang_char/text",
|
||||
type=str,
|
||||
help="the input text file for WenetSpeech",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
"-o",
|
||||
default="data/lang_char/text_words_segmentation",
|
||||
type=str,
|
||||
help="the text implemented with words segmenting for WenetSpeech",
|
||||
@ -57,26 +75,33 @@ def get_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def cut(lines):
|
||||
if lines is not None:
|
||||
cut_lines = jieba.cut(lines, use_paddle=True)
|
||||
return [i for i in cut_lines]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
num_process = args.num_process
|
||||
input_file = args.input_file
|
||||
output_file = args.output_file
|
||||
# parallel mode does not support use_paddle
|
||||
# jieba.enable_parallel(num_process)
|
||||
|
||||
f = open(input_file, "r", encoding="utf-8")
|
||||
lines = f.readlines()
|
||||
new_lines = []
|
||||
for i in tqdm(range(len(lines))):
|
||||
x = lines[i].rstrip()
|
||||
seg_list = jieba.cut(x, use_paddle=True)
|
||||
new_line = " ".join(seg_list)
|
||||
new_lines.append(new_line)
|
||||
with open(input_file, "r", encoding="utf-8") as fr:
|
||||
lines = fr.readlines()
|
||||
|
||||
f_new = open(output_file, "w", encoding="utf-8")
|
||||
for line in new_lines:
|
||||
f_new.write(line)
|
||||
f_new.write("\n")
|
||||
with Pool(processes=num_process) as p:
|
||||
new_lines = list(tqdm(p.imap(cut, lines), total=len(lines)))
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as fw:
|
||||
for line in new_lines:
|
||||
fw.write(" ".join(line) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -28,6 +28,7 @@ num_splits=1000
|
||||
# - speech
|
||||
|
||||
dl_dir=$PWD/download
|
||||
lang_char_dir=data/lang_char
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
@ -186,24 +187,27 @@ fi
|
||||
|
||||
if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
|
||||
log "Stage 15: Prepare char based lang"
|
||||
lang_char_dir=data/lang_char
|
||||
mkdir -p $lang_char_dir
|
||||
|
||||
# Prepare text.
|
||||
# Note: in Linux, you can install jq with the following command:
|
||||
# 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
|
||||
# 2. chmod +x ./jq
|
||||
# 3. cp jq /usr/bin
|
||||
if [ ! -f $lang_char_dir/text ]; then
|
||||
gunzip -c data/manifests/supervisions_L.jsonl.gz \
|
||||
| jq 'text' | sed 's/"//g' \
|
||||
if ! which jq; then
|
||||
echo "This script is intended to be used with jq but you have not installed jq
|
||||
Note: in Linux, you can install jq with the following command:
|
||||
1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
|
||||
2. chmod +x ./jq
|
||||
3. cp jq /usr/bin" && exit 1
|
||||
fi
|
||||
if [ ! -f $lang_char_dir/text ] || [ ! -s $lang_char_dir/text ]; then
|
||||
log "Prepare text."
|
||||
gunzip -c data/manifests/wenetspeech_supervisions_L.jsonl.gz \
|
||||
| jq '.text' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
||||
fi
|
||||
|
||||
# The implementation of chinese word segmentation for text,
|
||||
# and it will take about 15 minutes.
|
||||
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
|
||||
python ./local/text2segments.py \
|
||||
python3 ./local/text2segments.py \
|
||||
--num-process $nj \
|
||||
--input-file $lang_char_dir/text \
|
||||
--output-file $lang_char_dir/text_words_segmentation
|
||||
fi
|
||||
@ -212,7 +216,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
|
||||
| sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt
|
||||
|
||||
if [ ! -f $lang_char_dir/words.txt ]; then
|
||||
python ./local/prepare_words.py \
|
||||
python3 ./local/prepare_words.py \
|
||||
--input-file $lang_char_dir/words_no_ids.txt \
|
||||
--output-file $lang_char_dir/words.txt
|
||||
fi
|
||||
@ -221,7 +225,7 @@ fi
|
||||
if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
|
||||
log "Stage 16: Prepare char based L_disambig.pt"
|
||||
if [ ! -f data/lang_char/L_disambig.pt ]; then
|
||||
python ./local/prepare_char.py \
|
||||
python3 ./local/prepare_char.py \
|
||||
--lang-dir data/lang_char
|
||||
fi
|
||||
fi
|
||||
@ -232,9 +236,8 @@ if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
|
||||
# It will take about 20 minutes.
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
lang_char_dir=data/lang_char
|
||||
if [ ! -f $lang_char_dir/3-gram.unpruned.arpa ]; then
|
||||
python ./shared/make_kn_lm.py \
|
||||
python3 ./shared/make_kn_lm.py \
|
||||
-ngram-order 3 \
|
||||
-text $lang_char_dir/text_words_segmentation \
|
||||
-lm $lang_char_dir/3-gram.unpruned.arpa
|
||||
@ -253,6 +256,5 @@ fi
|
||||
|
||||
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
||||
log "Stage 18: Compile LG"
|
||||
lang_char_dir=data/lang_char
|
||||
python ./local/compile_lg.py --lang-dir $lang_char_dir
|
||||
fi
|
||||
|
@ -1144,8 +1144,6 @@ def display_and_save_batch(
|
||||
y = graph_compiler.texts_to_ids(texts)
|
||||
if type(y) == list:
|
||||
y = k2.RaggedTensor(y)
|
||||
else:
|
||||
y = y
|
||||
|
||||
num_tokens = sum(len(i) for i in y)
|
||||
logging.info(f"num tokens: {num_tokens}")
|
||||
|
@ -330,12 +330,17 @@ class Nbest(object):
|
||||
# We use a word fsa to intersect with k2.invert(lattice)
|
||||
word_fsa = k2.invert(self.fsa)
|
||||
|
||||
word_fsa.scores.zero_()
|
||||
if hasattr(lattice, "aux_labels"):
|
||||
# delete token IDs as it is not needed
|
||||
del word_fsa.aux_labels
|
||||
|
||||
word_fsa.scores.zero_()
|
||||
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
|
||||
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
|
||||
word_fsa
|
||||
)
|
||||
else:
|
||||
word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
|
||||
word_fsa
|
||||
)
|
||||
|
||||
path_to_utt_map = self.shape.row_ids(1)
|
||||
|
||||
|
@ -544,9 +544,10 @@ class MetricsTracker(collections.defaultdict):
|
||||
else:
|
||||
raise ValueError(f"Unexpected key: {k}")
|
||||
frames = "%.2f" % self["frames"]
|
||||
ans_frames += "over " + str(frames) + " frames; "
|
||||
utterances = "%.2f" % self["utterances"]
|
||||
ans_utterances += "over " + str(utterances) + " utterances."
|
||||
ans_frames += "over " + str(frames) + " frames. "
|
||||
if ans_utterances != "":
|
||||
utterances = "%.2f" % self["utterances"]
|
||||
ans_utterances += "over " + str(utterances) + " utterances."
|
||||
|
||||
return ans_frames + ans_utterances
|
||||
|
||||
|
@ -20,3 +20,7 @@ sentencepiece==0.1.96
|
||||
tensorboard==2.8.0
|
||||
typeguard==2.13.3
|
||||
multi_quantization
|
||||
|
||||
onnx
|
||||
onnxruntime
|
||||
onnx_graphsurgeon -i https://pypi.ngc.nvidia.com
|
||||
|
@ -4,3 +4,7 @@ sentencepiece>=0.1.96
|
||||
tensorboard
|
||||
typeguard
|
||||
multi_quantization
|
||||
onnx
|
||||
onnxruntime
|
||||
--extra-index-url https://pypi.ngc.nvidia.com
|
||||
onnx_graphsurgeon
|
||||
|
Loading…
x
Reference in New Issue
Block a user