mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Decrease the model size and other fixes
This commit is contained in:
parent
ea6b7c5160
commit
ff7af3586a
@ -296,7 +296,7 @@ def beam_search(
|
|||||||
if cached_key not in joint_cache:
|
if cached_key not in joint_cache:
|
||||||
logits = model.joiner(current_encoder_out, decoder_out)
|
logits = model.joiner(current_encoder_out, decoder_out)
|
||||||
|
|
||||||
# TODO(fangjun): Ccale the blank posterior
|
# TODO(fangjun): Scale the blank posterior
|
||||||
|
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
log_prob = logits.log_softmax(dim=-1)
|
||||||
# log_prob is (1, 1, 1, vocab_size)
|
# log_prob is (1, 1, 1, vocab_size)
|
||||||
|
@ -31,7 +31,6 @@ from decoder import Decoder
|
|||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
@ -39,8 +38,8 @@ from icefall.utils import (
|
|||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
write_error_stats,
|
|
||||||
str2bool,
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -130,9 +129,9 @@ def get_params() -> AttributeDict:
|
|||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"embedding_dim": 256,
|
"embedding_dim": 256,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"attention_dim": 512,
|
"attention_dim": 256,
|
||||||
"nhead": 8,
|
"nhead": 4,
|
||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 1024,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
@ -141,7 +140,7 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
# TODO: We can add an option to switch between Conformer and Transformer
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
@ -156,7 +155,7 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.embedding_dim,
|
embedding_dim=params.embedding_dim,
|
||||||
@ -166,16 +165,16 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.vocab_size,
|
input_dim=params.vocab_size,
|
||||||
output_dim=params.vocab_size,
|
|
||||||
inner_dim=params.embedding_dim,
|
inner_dim=params.embedding_dim,
|
||||||
|
output_dim=params.vocab_size,
|
||||||
)
|
)
|
||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
@ -404,10 +403,6 @@ def main():
|
|||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
|
||||||
lexicon=lexicon,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
params.blank_id = 0
|
params.blank_id = 0
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
@ -96,5 +96,5 @@ class Decoder(nn.Module):
|
|||||||
assert embeding_out.size(-1) == self.context_size
|
assert embeding_out.size(-1) == self.context_size
|
||||||
embeding_out = self.conv(embeding_out)
|
embeding_out = self.conv(embeding_out)
|
||||||
embeding_out = embeding_out.permute(0, 2, 1)
|
embeding_out = embeding_out.permute(0, 2, 1)
|
||||||
embeding_out = self.output_linear(embeding_out)
|
embeding_out = self.output_linear(F.relu(embeding_out))
|
||||||
return embeding_out
|
return embeding_out
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./transducer_stateless/export.py \
|
./transducer_stateless/export.py \
|
||||||
--exp-dir ./transducer_stateless/exp \
|
--exp-dir ./transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--lang-dir data/lang_char \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -39,15 +39,15 @@ To use the generated file with `transducer_stateless/decode.py`, you can do:
|
|||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--max-duration 1 \
|
--max-duration 1 \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model
|
--lang-dir data/lang_char
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -55,6 +55,7 @@ from model import Transducer
|
|||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, str2bool
|
||||||
|
|
||||||
|
|
||||||
@ -90,10 +91,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--lang-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_char",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -120,11 +121,11 @@ def get_params() -> AttributeDict:
|
|||||||
{
|
{
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"encoder_out_dim": 512,
|
"embedding_dim": 256,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"attention_dim": 512,
|
"attention_dim": 256,
|
||||||
"nhead": 8,
|
"nhead": 4,
|
||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 1024,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
@ -133,10 +134,10 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
output_dim=params.encoder_out_dim,
|
output_dim=params.vocab_size,
|
||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
d_model=params.attention_dim,
|
d_model=params.attention_dim,
|
||||||
nhead=params.nhead,
|
nhead=params.nhead,
|
||||||
@ -147,25 +148,26 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.embedding_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
context_size=params.context_size,
|
context_size=params.context_size,
|
||||||
)
|
)
|
||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.encoder_out_dim,
|
input_dim=params.vocab_size,
|
||||||
|
inner_dim=params.embedding_dim,
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.vocab_size,
|
||||||
)
|
)
|
||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
@ -193,12 +195,9 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
lexicon = Lexicon(params.lang_dir)
|
||||||
sp.load(params.bpe_model)
|
params.blank_id = 0
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
|
||||||
params.vocab_size = sp.get_piece_size()
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -22,9 +22,8 @@ class Joiner(nn.Module):
|
|||||||
def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
|
def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.output_linear = nn.Sequential(
|
self.inner_linear = nn.Linear(input_dim, inner_dim)
|
||||||
nn.Linear(input_dim, inner_dim), nn.Linear(inner_dim, output_dim)
|
self.output_linear = nn.Linear(inner_dim, output_dim)
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||||
@ -32,16 +31,19 @@ class Joiner(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C).
|
The pruned output from the encoder. Its shape is (N, T, s_range, C).
|
||||||
decoder_out:
|
decoder_out:
|
||||||
Output from the decoder. Its shape is (N, U, C).
|
The pruned output from the decoder. Its shape is (N, T, s_range, C).
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, U, C).
|
Return a tensor of shape (N, T, s_range, C).
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == decoder_out.ndim == 4
|
assert encoder_out.ndim == decoder_out.ndim == 4
|
||||||
assert encoder_out.shape == decoder_out.shape
|
assert encoder_out.shape == decoder_out.shape
|
||||||
|
|
||||||
logit = encoder_out + decoder_out
|
logit = encoder_out + decoder_out
|
||||||
|
|
||||||
|
logit = self.inner_linear(logit)
|
||||||
|
|
||||||
logit = torch.tanh(logit)
|
logit = torch.tanh(logit)
|
||||||
|
|
||||||
output = self.output_linear(logit)
|
output = self.output_linear(logit)
|
||||||
|
@ -32,7 +32,7 @@ class Transducer(nn.Module):
|
|||||||
encoder: EncoderInterface,
|
encoder: EncoderInterface,
|
||||||
decoder: nn.Module,
|
decoder: nn.Module,
|
||||||
joiner: nn.Module,
|
joiner: nn.Module,
|
||||||
prune_range: int = 5,
|
prune_range: int = 3,
|
||||||
lm_scale: float = 0.0,
|
lm_scale: float = 0.0,
|
||||||
am_scale: float = 0.0,
|
am_scale: float = 0.0,
|
||||||
):
|
):
|
||||||
@ -51,6 +51,20 @@ class Transducer(nn.Module):
|
|||||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||||
output shape is (N, T, U, C). Note that its output contains
|
output shape is (N, T, U, C). Note that its output contains
|
||||||
unnormalized probs, i.e., not processed by log-softmax.
|
unnormalized probs, i.e., not processed by log-softmax.
|
||||||
|
prune_range:
|
||||||
|
The prune range for rnnt loss, it means how many symbols(context)
|
||||||
|
we are considering for each frame to compute the loss.
|
||||||
|
am_scale:
|
||||||
|
The scale to smooth the loss with am (output of encoder network)
|
||||||
|
part
|
||||||
|
lm_scale:
|
||||||
|
The scale to smooth the loss with lm (output of predictor network)
|
||||||
|
part
|
||||||
|
Note:
|
||||||
|
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||||
|
the form:
|
||||||
|
lm_scale * lm_probs + am_scale * am_probs +
|
||||||
|
(1-lm_scale-am_scale) * combined_probs
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
|
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--lang-dir ./data/lang_char \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav \
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(1) beam search
|
(1) beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--lang-dir ./data/lang_char \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -58,9 +59,8 @@ from model import Transducer
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict
|
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.utils import AttributeDict
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -137,11 +137,11 @@ def get_params() -> AttributeDict:
|
|||||||
"sample_rate": 16000,
|
"sample_rate": 16000,
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"encoder_out_dim": 512,
|
"embedding_dim": 256,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"attention_dim": 512,
|
"attention_dim": 256,
|
||||||
"nhead": 8,
|
"nhead": 4,
|
||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 1024,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
@ -150,10 +150,10 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
output_dim=params.encoder_out_dim,
|
output_dim=params.vocab_size,
|
||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
d_model=params.attention_dim,
|
d_model=params.attention_dim,
|
||||||
nhead=params.nhead,
|
nhead=params.nhead,
|
||||||
@ -164,25 +164,26 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.embedding_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
context_size=params.context_size,
|
context_size=params.context_size,
|
||||||
)
|
)
|
||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.encoder_out_dim,
|
input_dim=params.vocab_size,
|
||||||
|
inner_dim=params.embedding_dim,
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.vocab_size,
|
||||||
)
|
)
|
||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
@ -235,12 +236,8 @@ def main():
|
|||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
|
||||||
lexicon=lexicon,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
params.blank_id = 0
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
logging.info("Creating model")
|
logging.info("Creating model")
|
||||||
|
@ -42,12 +42,12 @@ def test_decoder():
|
|||||||
U = 20
|
U = 20
|
||||||
x = torch.randint(low=0, high=vocab_size, size=(N, U))
|
x = torch.randint(low=0, high=vocab_size, size=(N, U))
|
||||||
y = decoder(x)
|
y = decoder(x)
|
||||||
assert y.shape == (N, U, embedding_dim)
|
assert y.shape == (N, U, vocab_size)
|
||||||
|
|
||||||
# for inference
|
# for inference
|
||||||
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
|
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
|
||||||
y = decoder(x, need_pad=False)
|
y = decoder(x, need_pad=False)
|
||||||
assert y.shape == (N, 1, embedding_dim)
|
assert y.shape == (N, 1, vocab_size)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -131,7 +131,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prune-range",
|
"--prune-range",
|
||||||
type=int,
|
type=int,
|
||||||
default=5,
|
default=3,
|
||||||
help="The prune range for rnnt loss, it means how many symbols(context)"
|
help="The prune range for rnnt loss, it means how many symbols(context)"
|
||||||
"we are using to compute the loss",
|
"we are using to compute the loss",
|
||||||
)
|
)
|
||||||
@ -139,7 +139,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lm-scale",
|
"--lm-scale",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.5,
|
||||||
help="The scale to smooth the loss with lm "
|
help="The scale to smooth the loss with lm "
|
||||||
"(output of prediction network) part.",
|
"(output of prediction network) part.",
|
||||||
)
|
)
|
||||||
@ -212,9 +212,9 @@ def get_params() -> AttributeDict:
|
|||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"attention_dim": 512,
|
"attention_dim": 256,
|
||||||
"nhead": 8,
|
"nhead": 4,
|
||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 1024,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
# parameters for decoder
|
# parameters for decoder
|
||||||
@ -228,7 +228,7 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
# TODO: We can add an option to switch between Conformer and Transformer
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
@ -243,7 +243,7 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.embedding_dim,
|
embedding_dim=params.embedding_dim,
|
||||||
@ -253,7 +253,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.vocab_size,
|
input_dim=params.vocab_size,
|
||||||
inner_dim=params.embedding_dim,
|
inner_dim=params.embedding_dim,
|
||||||
@ -262,7 +262,7 @@ def get_joiner_model(params: AttributeDict):
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user