mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00

Decrease the model size and other fixes

This commit is contained in:
parent ea6b7c5160
commit ff7af3586a
@@ -296,7 +296,7 @@ def beam_search(
         if cached_key not in joint_cache:
             logits = model.joiner(current_encoder_out, decoder_out)

-            # TODO(fangjun): Ccale the blank posterior
+            # TODO(fangjun): Scale the blank posterior

             log_prob = logits.log_softmax(dim=-1)
             # log_prob is (1, 1, 1, vocab_size)
@@ -31,7 +31,6 @@ from decoder import Decoder
 from joiner import Joiner
 from model import Transducer

-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
@@ -39,8 +38,8 @@ from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
-    write_error_stats,
     str2bool,
+    write_error_stats,
 )

@@ -130,9 +129,9 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "embedding_dim": 256,
             "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
+            "attention_dim": 256,
+            "nhead": 4,
+            "dim_feedforward": 1024,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "env_info": get_env_info(),
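The same reduction of attention_dim 512 -> 256, nhead 8 -> 4 and dim_feedforward 2048 -> 1024 is applied to decode.py, export.py, pretrained.py and train.py in the hunks below. As a rough, hedged estimate of what this saves per encoder layer, here is a stand-in comparison using torch's generic nn.TransformerEncoderLayer; icefall's Conformer layer has extra convolution and macaron feed-forward modules, so the absolute numbers differ, but the scaling with d_model and dim_feedforward is similar:

```python
import torch.nn as nn


def count_params(module: nn.Module) -> int:
    """Total number of trainable parameters."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Old hyper-parameters vs. the ones introduced by this commit.
old_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
new_layer = nn.TransformerEncoderLayer(d_model=256, nhead=4, dim_feedforward=1024)

print(count_params(old_layer))  # ~3.2M parameters per layer
print(count_params(new_layer))  # ~0.8M parameters per layer, roughly 4x smaller
```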
@@ -141,7 +140,7 @@ def get_params() -> AttributeDict:
     return params


-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -156,7 +155,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.embedding_dim,
@@ -166,16 +165,16 @@ def get_decoder_model(params: AttributeDict):
     return decoder


-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.vocab_size,
-        output_dim=params.vocab_size,
+        inner_dim=params.embedding_dim,
+        output_dim=params.vocab_size,
     )
     return joiner


-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -404,10 +403,6 @@ def main():
     logging.info(f"Device: {device}")

     lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )

     params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1
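With the BPE model gone, the blank id and the vocabulary size are now derived from the char lang dir. A minimal sketch of that derivation, assuming the usual icefall tokens.txt layout of one "symbol id" pair per line; in the real code icefall.lexicon.Lexicon does this parsing (and also filters special symbols such as <eps>, which is omitted here):

```python
from pathlib import Path
from typing import Tuple


def blank_and_vocab_size(lang_dir: str) -> Tuple[int, int]:
    """Mirrors `params.blank_id = 0` and
    `params.vocab_size = max(lexicon.tokens) + 1` from the hunk above."""
    ids = []
    with open(Path(lang_dir) / "tokens.txt", encoding="utf-8") as f:
        for line in f:
            _sym, idx = line.split()
            ids.append(int(idx))
    blank_id = 0  # <blk> is fixed at id 0 in the char lang dir
    vocab_size = max(ids) + 1
    return blank_id, vocab_size
```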
@@ -96,5 +96,5 @@ class Decoder(nn.Module):
                 assert embeding_out.size(-1) == self.context_size
             embeding_out = self.conv(embeding_out)
             embeding_out = embeding_out.permute(0, 2, 1)
-        embeding_out = self.output_linear(embeding_out)
+        embeding_out = self.output_linear(F.relu(embeding_out))
         return embeding_out
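For orientation, a sketch of the whole stateless decoder as it would look after this change. Only the lines visible in the hunk are taken from the diff; the rest (the embedding, the grouped 1-D convolution over a small left context, the argument names) is reconstructed from the icefall stateless-decoder design and should be read as an approximation, not the committed file:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """Stateless decoder: an embedding plus a 1-D conv over a small left
    context instead of an LSTM. Reconstruction for illustration only."""

    def __init__(
        self, vocab_size: int, embedding_dim: int, blank_id: int, context_size: int
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=blank_id)
        self.blank_id = blank_id
        self.context_size = context_size
        if context_size > 1:
            self.conv = nn.Conv1d(
                embedding_dim,
                embedding_dim,
                kernel_size=context_size,
                groups=embedding_dim,
                bias=False,
            )
        # After this commit the final projection is preceded by a ReLU and maps
        # to vocab_size (see the test_decoder.py hunk further down).
        self.output_linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        embeding_out = self.embedding(y)  # (N, U, embedding_dim)
        if self.context_size > 1:
            embeding_out = embeding_out.permute(0, 2, 1)
            if need_pad:
                embeding_out = F.pad(embeding_out, pad=(self.context_size - 1, 0))
            else:
                # Inference: exactly one frame of left context is supplied.
                assert embeding_out.size(-1) == self.context_size
            embeding_out = self.conv(embeding_out)
            embeding_out = embeding_out.permute(0, 2, 1)
        embeding_out = self.output_linear(F.relu(embeding_out))
        return embeding_out
```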
@@ -22,7 +22,7 @@
 Usage:
 ./transducer_stateless/export.py \
   --exp-dir ./transducer_stateless/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --lang-dir data/lang_char \
   --epoch 20 \
   --avg 10
@@ -39,15 +39,15 @@ To use the generated file with `transducer_stateless/decode.py`, you can do:
         --epoch 9999 \
         --avg 1 \
         --max-duration 1 \
-        --bpe-model data/lang_bpe_500/bpe.model
+        --lang-dir data/lang_char
 """

 import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -55,6 +55,7 @@ from model import Transducer

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, str2bool


@@ -90,10 +91,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="Path to the tokens.txt",
     )

     parser.add_argument(
@@ -120,11 +121,11 @@ def get_params() -> AttributeDict:
         {
             # parameters for conformer
             "feature_dim": 80,
-            "encoder_out_dim": 512,
+            "embedding_dim": 256,
             "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
+            "attention_dim": 256,
+            "nhead": 4,
+            "dim_feedforward": 1024,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "env_info": get_env_info(),
@@ -133,10 +134,10 @@ def get_params() -> AttributeDict:
     return params


-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -147,25 +148,26 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
+        embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
     return decoder


-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.encoder_out_dim,
+        input_dim=params.vocab_size,
+        inner_dim=params.embedding_dim,
         output_dim=params.vocab_size,
     )
     return joiner


-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -193,12 +195,9 @@ def main():

     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = 0
+    params.vocab_size = max(lexicon.tokens) + 1

     logging.info(params)

@@ -22,9 +22,8 @@ class Joiner(nn.Module):
     def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
         super().__init__()

-        self.output_linear = nn.Sequential(
-            nn.Linear(input_dim, inner_dim), nn.Linear(inner_dim, output_dim)
-        )
+        self.inner_linear = nn.Linear(input_dim, inner_dim)
+        self.output_linear = nn.Linear(inner_dim, output_dim)

     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
@@ -32,16 +31,19 @@ class Joiner(nn.Module):
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, C).
+            The pruned output from the encoder. Its shape is (N, T, s_range, C).
           decoder_out:
-            Output from the decoder. Its shape is (N, U, C).
+            The pruned output from the decoder. Its shape is (N, T, s_range, C).
         Returns:
-          Return a tensor of shape (N, T, U, C).
+          Return a tensor of shape (N, T, s_range, C).
         """
         assert encoder_out.ndim == decoder_out.ndim == 4
         assert encoder_out.shape == decoder_out.shape

         logit = encoder_out + decoder_out
+
+        logit = self.inner_linear(logit)
+
         logit = torch.tanh(logit)

         output = self.output_linear(logit)
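Putting the two joiner hunks together, a minimal runnable sketch of the refactored module; the class name, argument names and the forward computation come from the diff, while the shapes used under __main__ are made up for illustration:

```python
import torch
import torch.nn as nn


class Joiner(nn.Module):
    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
        super().__init__()
        # The former nn.Sequential of two back-to-back nn.Linear layers is
        # split so that a tanh non-linearity can sit between them.
        self.inner_linear = nn.Linear(input_dim, inner_dim)
        self.output_linear = nn.Linear(inner_dim, output_dim)

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        # Both inputs are pruned lattices of shape (N, T, s_range, C).
        assert encoder_out.ndim == decoder_out.ndim == 4
        assert encoder_out.shape == decoder_out.shape

        logit = encoder_out + decoder_out
        logit = self.inner_linear(logit)
        logit = torch.tanh(logit)
        return self.output_linear(logit)


if __name__ == "__main__":
    # Hypothetical sizes, only to show the expected shapes.
    N, T, s_range, C = 2, 10, 3, 30
    joiner = Joiner(input_dim=C, inner_dim=256, output_dim=C)
    enc = torch.randn(N, T, s_range, C)
    dec = torch.randn(N, T, s_range, C)
    print(joiner(enc, dec).shape)  # torch.Size([2, 10, 3, 30])
```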
@@ -32,7 +32,7 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
-        prune_range: int = 5,
+        prune_range: int = 3,
         lm_scale: float = 0.0,
         am_scale: float = 0.0,
     ):
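The default prune range drops from 5 to 3. Informally, the prune range is how many consecutive symbol positions of the (T, U) RNN-T lattice are evaluated for each acoustic frame. The toy sketch below only visualizes that idea with a made-up monotone alignment; the real windows are computed by k2 inside the pruned RNN-T loss:

```python
import torch

T, U, prune_range = 6, 10, 3  # hypothetical lattice sizes
# Made-up monotone "alignment": the first symbol index kept for each frame.
starts = [min(t * U // T, U - prune_range) for t in range(T)]
windows = torch.tensor([[s + k for k in range(prune_range)] for s in starts])
print(windows)
# Each row lists the prune_range symbol positions evaluated for that frame;
# with the previous default there would be 5 columns instead of 3.
```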
@@ -51,6 +51,20 @@ class Transducer(nn.Module):
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
+          prune_range:
+            The prune range for rnnt loss, it means how many symbols(context)
+            we are considering for each frame to compute the loss.
+          am_scale:
+            The scale to smooth the loss with am (output of encoder network)
+            part
+          lm_scale:
+            The scale to smooth the loss with lm (output of predictor network)
+            part
+          Note:
+            Regarding am_scale & lm_scale, it will make the loss-function one of
+            the form:
+              lm_scale * lm_probs + am_scale * am_probs +
+              (1-lm_scale-am_scale) * combined_probs
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
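A hedged restatement of the smoothing formula quoted in the docstring, with lm_probs, am_probs and combined_probs standing in for the three log-probability terms that the smoothed RNN-T loss computes internally; this shows only the mixing step, not the loss itself:

```python
import torch


def smoothed(
    lm_probs: torch.Tensor,
    am_probs: torch.Tensor,
    combined_probs: torch.Tensor,
    lm_scale: float = 0.5,
    am_scale: float = 0.0,
) -> torch.Tensor:
    # lm_scale * lm_probs + am_scale * am_probs
    #   + (1 - lm_scale - am_scale) * combined_probs
    # The defaults here follow the new train.py defaults further down.
    return (
        lm_scale * lm_probs
        + am_scale * am_probs
        + (1.0 - lm_scale - am_scale) * combined_probs
    )
```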
@@ -20,7 +20,7 @@ Usage:
 (1) greedy search
 ./transducer_stateless/pretrained.py \
     --checkpoint ./transducer_stateless/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --lang-dir ./data/lang_char \
     --method greedy_search \
     /path/to/foo.wav \
     /path/to/bar.wav \
@@ -28,7 +28,7 @@ Usage:
 (1) beam search
 ./transducer_stateless/pretrained.py \
     --checkpoint ./transducer_stateless/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --lang-dir ./data/lang_char \
     --method beam_search \
     --beam-size 4 \
     /path/to/foo.wav \
@@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
 import argparse
 import logging
 import math
-from typing import List
+from pathlib import Path
+from typing import List

 import kaldifeat
 import torch
 import torch.nn as nn
 import torchaudio
 from beam_search import beam_search, greedy_search
 from conformer import Conformer
@@ -58,9 +59,8 @@ from model import Transducer
 from torch.nn.utils.rnn import pad_sequence

 from icefall.env import get_env_info
-from icefall.utils import AttributeDict
 from icefall.lexicon import Lexicon
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.utils import AttributeDict


 def get_parser():
@@ -137,11 +137,11 @@ def get_params() -> AttributeDict:
             "sample_rate": 16000,
             # parameters for conformer
             "feature_dim": 80,
-            "encoder_out_dim": 512,
+            "embedding_dim": 256,
             "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
+            "attention_dim": 256,
+            "nhead": 4,
+            "dim_feedforward": 1024,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "env_info": get_env_info(),
@@ -150,10 +150,10 @@ def get_params() -> AttributeDict:
     return params


-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -164,25 +164,26 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
+        embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
     return decoder


-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.encoder_out_dim,
+        input_dim=params.vocab_size,
+        inner_dim=params.embedding_dim,
         output_dim=params.vocab_size,
     )
     return joiner


-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -235,12 +236,8 @@ def main():
     logging.info(f"device: {device}")

     lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )

-    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+    params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1

     logging.info("Creating model")
@@ -42,12 +42,12 @@ def test_decoder():
     U = 20
     x = torch.randint(low=0, high=vocab_size, size=(N, U))
     y = decoder(x)
-    assert y.shape == (N, U, embedding_dim)
+    assert y.shape == (N, U, vocab_size)

     # for inference
     x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
     y = decoder(x, need_pad=False)
-    assert y.shape == (N, 1, embedding_dim)
+    assert y.shape == (N, 1, vocab_size)


 def main():
@@ -131,7 +131,7 @@ def get_parser():
     parser.add_argument(
         "--prune-range",
         type=int,
-        default=5,
+        default=3,
         help="The prune range for rnnt loss, it means how many symbols(context)"
         "we are using to compute the loss",
     )
@@ -139,7 +139,7 @@ def get_parser():
     parser.add_argument(
         "--lm-scale",
         type=float,
-        default=0.0,
+        default=0.5,
         help="The scale to smooth the loss with lm "
         "(output of prediction network) part.",
     )
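With --lm-scale now defaulting to 0.5 (and am_scale staying at its 0.0 default from model.py), half of the smoothed loss weight sits on the LM-only term and half on the combined term. A trivial check of the weighting implied by the formula quoted in the model.py docstring above; nothing here comes from train.py beyond the two defaults:

```python
lm_scale, am_scale = 0.5, 0.0  # new defaults
weights = {
    "lm_probs": lm_scale,
    "am_probs": am_scale,
    "combined_probs": 1.0 - lm_scale - am_scale,
}
assert abs(sum(weights.values()) - 1.0) < 1e-9  # the three weights sum to 1
print(weights)  # {'lm_probs': 0.5, 'am_probs': 0.0, 'combined_probs': 0.5}
```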
@@ -212,9 +212,9 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
+            "attention_dim": 256,
+            "nhead": 4,
+            "dim_feedforward": 1024,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             # parameters for decoder
@@ -228,7 +228,7 @@ def get_params() -> AttributeDict:
     return params


-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict):
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.embedding_dim,
|
||||
@ -253,7 +253,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict):
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.vocab_size,
|
||||
inner_dim=params.embedding_dim,
|
||||
@ -262,7 +262,7 @@ def get_joiner_model(params: AttributeDict):
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict):
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|