mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
do some changes for pruned RNN-T
This commit is contained in:
parent
b7e7629168
commit
cd6a2d903b
@ -483,8 +483,9 @@ def main():
|
|||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
sp.load(params.bpe_model)
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
@ -37,6 +37,7 @@ class Decoder(nn.Module):
|
|||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
blank_id: int,
|
blank_id: int,
|
||||||
|
unk_id: int,
|
||||||
context_size: int,
|
context_size: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -47,6 +48,8 @@ class Decoder(nn.Module):
|
|||||||
Dimension of the input embedding.
|
Dimension of the input embedding.
|
||||||
blank_id:
|
blank_id:
|
||||||
The ID of the blank symbol.
|
The ID of the blank symbol.
|
||||||
|
unk_id:
|
||||||
|
The ID of the unk symbol.
|
||||||
context_size:
|
context_size:
|
||||||
Number of previous words to use to predict the next word.
|
Number of previous words to use to predict the next word.
|
||||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||||
@ -58,6 +61,7 @@ class Decoder(nn.Module):
|
|||||||
padding_idx=blank_id,
|
padding_idx=blank_id,
|
||||||
)
|
)
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
|
self.unk_id = unk_id
|
||||||
|
|
||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
self.context_size = context_size
|
self.context_size = context_size
|
||||||
|
@ -319,6 +319,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.embedding_dim,
|
embedding_dim=params.embedding_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
|
unk_id=params.unk_id,
|
||||||
context_size=params.context_size,
|
context_size=params.context_size,
|
||||||
)
|
)
|
||||||
return decoder
|
return decoder
|
||||||
@ -756,8 +757,9 @@ def run(rank, world_size, args):
|
|||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
sp.load(params.bpe_model)
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
@ -12,7 +12,7 @@ The WERs are
|
|||||||
|------------------------------------|------------|------------|------------------------------------------|
|
|------------------------------------|------------|------------|------------------------------------------|
|
||||||
| greedy search | 7.27 | 6.69 | --epoch 29, --avg 13, --max-duration 100 |
|
| greedy search | 7.27 | 6.69 | --epoch 29, --avg 13, --max-duration 100 |
|
||||||
| beam search (beam size 4) | 6.70 | 6.04 | --epoch 29, --avg 13, --max-duration 100 |
|
| beam search (beam size 4) | 6.70 | 6.04 | --epoch 29, --avg 13, --max-duration 100 |
|
||||||
| modified beam search (beam size 4) | 6.72 | 6.12 | --epoch 29, --avg 13, --max-duration 100 |
|
| modified beam search (beam size 4) | 6.77 | 6.12 | --epoch 29, --avg 13, --max-duration 100 |
|
||||||
| fast beam search (set as default) | 7.14 | 6.50 | --epoch 29, --avg 13, --max-duration 1500|
|
| fast beam search (set as default) | 7.14 | 6.50 | --epoch 29, --avg 13, --max-duration 1500|
|
||||||
|
|
||||||
The training command for reproducing is given below:
|
The training command for reproducing is given below:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright 2020 Xiaomi Corp. (authors: Fangjun Kuang
|
# Copyright 2020 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
# Mingshuang Luo)
|
# Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -369,13 +369,158 @@ class HypothesisList(object):
|
|||||||
return ", ".join(s)
|
return ", ".join(s)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||||
|
"""Return a ragged shape with axes [utt][num_hyps].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hyps:
|
||||||
|
len(hyps) == batch_size. It contains the current hypothesis for
|
||||||
|
each utterance in the batch.
|
||||||
|
Returns:
|
||||||
|
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
|
||||||
|
the shape is on CPU.
|
||||||
|
"""
|
||||||
|
num_hyps = [len(h) for h in hyps]
|
||||||
|
|
||||||
|
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
|
||||||
|
# to get exclusive sum later.
|
||||||
|
num_hyps.insert(0, 0)
|
||||||
|
|
||||||
|
num_hyps = torch.tensor(num_hyps)
|
||||||
|
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
|
||||||
|
ans = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
|
||||||
|
)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
def modified_beam_search(
|
def modified_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
beam:
|
||||||
|
Number of active paths during the beam search.
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||||
|
for the i-th utterance.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = model.decoder.unk_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = model.device
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
for i in range(batch_size):
|
||||||
|
B[i].add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||||
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
hyps_shape = _get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
|
A = [list(b) for b in B]
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat(
|
||||||
|
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||||
|
) # (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
|
# decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
|
||||||
|
|
||||||
|
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||||
|
# as index, so we use `to(torch.int64)` below.
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
current_encoder_out,
|
||||||
|
dim=0,
|
||||||
|
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
) # (num_hyps, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
vocab_size = log_probs.size(-1)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
|
||||||
|
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||||
|
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||||
|
)
|
||||||
|
ragged_log_probs = k2.RaggedTensor(
|
||||||
|
shape=log_probs_shape, value=log_probs
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
|
||||||
|
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
new_ys = hyp.ys[:]
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
if new_token != blank_id and new_token != unk_id:
|
||||||
|
new_ys.append(new_token)
|
||||||
|
|
||||||
|
new_log_prob = topk_log_probs[k]
|
||||||
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
|
ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def _deprecated_modified_beam_search(
|
||||||
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
beam: int = 4,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
It decodes only one utterance at a time. We keep it only for reference.
|
||||||
|
The function :func:`modified_beam_search` should be preferred as it
|
||||||
|
supports batch decoding.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model:
|
model:
|
||||||
An instance of `Transducer`.
|
An instance of `Transducer`.
|
||||||
|
@ -74,13 +74,9 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from conformer import Conformer
|
from train import get_params, get_transducer_model
|
||||||
from decoder import Decoder
|
|
||||||
from joiner import Joiner
|
|
||||||
from model import Transducer
|
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
@ -182,7 +178,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=1,
|
||||||
help="""Maximum number of symbols per frame.
|
help="""Maximum number of symbols per frame.
|
||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
@ -190,73 +186,6 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_params() -> AttributeDict:
|
|
||||||
params = AttributeDict(
|
|
||||||
{
|
|
||||||
# parameters for conformer
|
|
||||||
"feature_dim": 80,
|
|
||||||
"subsampling_factor": 4,
|
|
||||||
"attention_dim": 512,
|
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
|
||||||
"num_encoder_layers": 12,
|
|
||||||
"vgg_frontend": False,
|
|
||||||
# parameters for decoder
|
|
||||||
"embedding_dim": 512,
|
|
||||||
"env_info": get_env_info(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
|
||||||
encoder = Conformer(
|
|
||||||
num_features=params.feature_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
d_model=params.attention_dim,
|
|
||||||
nhead=params.nhead,
|
|
||||||
dim_feedforward=params.dim_feedforward,
|
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
|
||||||
vgg_frontend=params.vgg_frontend,
|
|
||||||
)
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
|
||||||
decoder = Decoder(
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
embedding_dim=params.embedding_dim,
|
|
||||||
blank_id=params.blank_id,
|
|
||||||
unk_id=params.unk_id,
|
|
||||||
context_size=params.context_size,
|
|
||||||
)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
|
||||||
joiner = Joiner(
|
|
||||||
input_dim=params.vocab_size,
|
|
||||||
inner_dim=params.embedding_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
)
|
|
||||||
return joiner
|
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
|
||||||
encoder = get_encoder_model(params)
|
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
|
||||||
|
|
||||||
model = Transducer(
|
|
||||||
encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -329,6 +258,14 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
@ -348,12 +285,6 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "modified_beam_search":
|
|
||||||
hyp = modified_beam_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out_i,
|
|
||||||
beam=params.beam_size,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
@ -593,8 +524,5 @@ def main():
|
|||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
torch.set_num_interop_threads(1)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -1,105 +0,0 @@
|
|||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
|
||||||
# Mingshuang Luo)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
|
||||||
"""This class modifies the stateless decoder from the following paper:
|
|
||||||
|
|
||||||
RNN-transducer with stateless prediction network
|
|
||||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
|
||||||
|
|
||||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
|
||||||
network. Different from the above paper, it adds an extra Conv1d
|
|
||||||
right after the embedding layer.
|
|
||||||
|
|
||||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size: int,
|
|
||||||
embedding_dim: int,
|
|
||||||
blank_id: int,
|
|
||||||
unk_id: int,
|
|
||||||
context_size: int,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
vocab_size:
|
|
||||||
Number of tokens of the modeling unit including blank.
|
|
||||||
embedding_dim:
|
|
||||||
Dimension of the input embedding.
|
|
||||||
blank_id:
|
|
||||||
The ID of the blank symbol.
|
|
||||||
unk_id:
|
|
||||||
The ID of the unk symbol.
|
|
||||||
context_size:
|
|
||||||
Number of previous words to use to predict the next word.
|
|
||||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.embedding = nn.Embedding(
|
|
||||||
num_embeddings=vocab_size,
|
|
||||||
embedding_dim=embedding_dim,
|
|
||||||
padding_idx=blank_id,
|
|
||||||
)
|
|
||||||
self.blank_id = blank_id
|
|
||||||
self.unk_id = unk_id
|
|
||||||
assert context_size >= 1, context_size
|
|
||||||
self.context_size = context_size
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
if context_size > 1:
|
|
||||||
self.conv = nn.Conv1d(
|
|
||||||
in_channels=embedding_dim,
|
|
||||||
out_channels=embedding_dim,
|
|
||||||
kernel_size=context_size,
|
|
||||||
padding=0,
|
|
||||||
groups=embedding_dim,
|
|
||||||
bias=False,
|
|
||||||
)
|
|
||||||
self.output_linear = nn.Linear(embedding_dim, vocab_size)
|
|
||||||
|
|
||||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
y:
|
|
||||||
A 2-D tensor of shape (N, U).
|
|
||||||
need_pad:
|
|
||||||
True to left pad the input. Should be True during training.
|
|
||||||
False to not pad the input. Should be False during inference.
|
|
||||||
Returns:
|
|
||||||
Return a tensor of shape (N, U, embedding_dim).
|
|
||||||
"""
|
|
||||||
embedding_out = self.embedding(y)
|
|
||||||
if self.context_size > 1:
|
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
|
||||||
if need_pad is True:
|
|
||||||
embedding_out = F.pad(
|
|
||||||
embedding_out, pad=(self.context_size - 1, 0)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# During inference time, there is no need to do extra padding
|
|
||||||
# as we only need one output
|
|
||||||
assert embedding_out.size(-1) == self.context_size
|
|
||||||
embedding_out = self.conv(embedding_out)
|
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
|
||||||
embedding_out = self.output_linear(F.relu(embedding_out))
|
|
||||||
return embedding_out
|
|
1
egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py
Symbolic link
1
egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless/decoder.py
|
@ -1,7 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
|
||||||
# Mingshuang Luo)
|
# Mingshuang Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -50,15 +50,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from train import get_params, get_transducer_model
|
||||||
from conformer import Conformer
|
|
||||||
from decoder import Decoder
|
|
||||||
from joiner import Joiner
|
|
||||||
from model import Transducer
|
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.utils import str2bool
|
||||||
from icefall.utils import AttributeDict, str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -69,7 +64,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=20,
|
default=30,
|
||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
@ -77,7 +72,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=10,
|
default=13,
|
||||||
help="Number of checkpoints to average. Automatically select "
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
@ -118,73 +113,6 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_params() -> AttributeDict:
|
|
||||||
params = AttributeDict(
|
|
||||||
{
|
|
||||||
# parameters for conformer
|
|
||||||
"feature_dim": 80,
|
|
||||||
"encoder_out_dim": 512,
|
|
||||||
"subsampling_factor": 4,
|
|
||||||
"attention_dim": 512,
|
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
|
||||||
"num_encoder_layers": 12,
|
|
||||||
"vgg_frontend": False,
|
|
||||||
# parameters for decoder
|
|
||||||
"embedding_dim": 512,
|
|
||||||
"env_info": get_env_info(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
encoder = Conformer(
|
|
||||||
num_features=params.feature_dim,
|
|
||||||
output_dim=params.encoder_out_dim,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
d_model=params.attention_dim,
|
|
||||||
nhead=params.nhead,
|
|
||||||
dim_feedforward=params.dim_feedforward,
|
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
|
||||||
vgg_frontend=params.vgg_frontend,
|
|
||||||
)
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|
||||||
decoder = Decoder(
|
|
||||||
vocab_size=params.vocab_size,
|
|
||||||
embedding_dim=params.encoder_out_dim,
|
|
||||||
blank_id=params.blank_id,
|
|
||||||
unk_id=params.unk_id,
|
|
||||||
context_size=params.context_size,
|
|
||||||
)
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|
||||||
joiner = Joiner(
|
|
||||||
input_dim=params.encoder_out_dim,
|
|
||||||
inner_dim=params.embedding_dim,
|
|
||||||
output_dim=params.vocab_size,
|
|
||||||
)
|
|
||||||
return joiner
|
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|
||||||
encoder = get_encoder_model(params)
|
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
|
||||||
|
|
||||||
model = Transducer(
|
|
||||||
encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
joiner=joiner,
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user