Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 02:34:21 +00:00)
Merge da98aa1b8ea4a7e318b0c91702724cc853bb08ff into 8e6fd97c6b92826b2e13caa1a683f2eb4bf9a832
Commit bff66ce86a
@@ -51,6 +51,7 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from model import Transducer
+from transformer import Transformer
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
@@ -128,6 +129,13 @@ def get_parser():
         help="Maximum number of symbols per frame",
     )
 
+    parser.add_argument(
+        "--encoder-type",
+        type=str,
+        default="conformer",
+        help="Type of the encoder. Valid values are: conformer and transformer",
+    )
+
     return parser
 
 
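As a sanity check, here is a minimal standalone sketch of how the new flag behaves once parsed. Only argparse is used; the surrounding get_parser() plumbing from the recipe is assumed from the hunk above.

import argparse

# The flag exactly as added in this commit, parsed in isolation.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--encoder-type",
    type=str,
    default="conformer",
    help="Type of the encoder. Valid values are: conformer and transformer",
)

args = parser.parse_args([])
print(args.encoder_type)  # conformer (the default)

args = parser.parse_args(["--encoder-type", "transformer"])
print(args.encoder_type)  # transformer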
@@ -150,8 +158,16 @@ def get_params() -> AttributeDict:
 
 
 def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
+    if params.encoder_type == "conformer":
+        Encoder = Conformer
+    elif params.encoder_type == "transformer":
+        Encoder = Transformer
+    else:
+        raise ValueError(
+            f"Unsupported encoder type: {params.encoder_type}"
+            "\nPlease use conformer or transformer"
+        )
+    encoder = Encoder(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
         subsampling_factor=params.subsampling_factor,
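The hunk above replaces the hard-coded Conformer constructor with a small dispatch: bind the chosen class to a local name, then construct it with one shared call. A self-contained sketch of that pattern follows; the stub classes and the parameter values are illustrative stand-ins for the real encoders and params in the recipe.

from dataclasses import dataclass


class Conformer:
    # Stub standing in for the real Conformer encoder.
    def __init__(self, num_features: int, output_dim: int, subsampling_factor: int):
        self.name = "conformer"


class Transformer:
    # Stub standing in for the real Transformer encoder.
    def __init__(self, num_features: int, output_dim: int, subsampling_factor: int):
        self.name = "transformer"


@dataclass
class Params:
    encoder_type: str = "conformer"
    feature_dim: int = 80        # illustrative value
    encoder_out_dim: int = 512   # illustrative value
    subsampling_factor: int = 4  # illustrative value


def get_encoder_model(params: Params):
    # Select the encoder class first...
    if params.encoder_type == "conformer":
        Encoder = Conformer
    elif params.encoder_type == "transformer":
        Encoder = Transformer
    else:
        raise ValueError(
            f"Unsupported encoder type: {params.encoder_type}"
            "\nPlease use conformer or transformer"
        )
    # ...then construct it once, with arguments shared by both encoders.
    return Encoder(
        num_features=params.feature_dim,
        output_dim=params.encoder_out_dim,
        subsampling_factor=params.subsampling_factor,
    )


print(get_encoder_model(Params()).name)                            # conformer
print(get_encoder_model(Params(encoder_type="transformer")).name)  # transformer

This works because both encoders accept the same constructor arguments here, so the branch only has to pick the class.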
@@ -54,7 +54,7 @@ from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam
+from transformer import Noam, Transformer
 
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -138,6 +138,13 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--encoder-type",
+        type=str,
+        default="conformer",
+        help="Type of the encoder. Valid values are: conformer and transformer",
+    )
+
     return parser
 
 
@@ -214,8 +221,16 @@ def get_params() -> AttributeDict:
 
 
 def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
+    if params.encoder_type == "conformer":
+        Encoder = Conformer
+    elif params.encoder_type == "transformer":
+        Encoder = Transformer
+    else:
+        raise ValueError(
+            f"Unsupported encoder type: {params.encoder_type}"
+            "\nPlease use conformer or transformer"
+        )
+    encoder = Encoder(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
         subsampling_factor=params.subsampling_factor,
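For completeness, the error path of the guard above, pulled out into a hypothetical helper for illustration (the diff inlines this check in get_encoder_model rather than defining a separate function):

def check_encoder_type(encoder_type: str) -> None:
    # Mirrors the validation added in this commit.
    if encoder_type not in ("conformer", "transformer"):
        raise ValueError(
            f"Unsupported encoder type: {encoder_type}"
            "\nPlease use conformer or transformer"
        )


check_encoder_type("conformer")  # passes silently
try:
    check_encoder_type("lstm")
except ValueError as e:
    print(e)  # Unsupported encoder type: lstm
              # Please use conformer or transformer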
@@ -360,7 +375,8 @@ def compute_loss(
     params:
       Parameters for training. See :func:`get_params`.
     model:
-      The model for training. It is an instance of Conformer in our case.
+      The model for training. It is an instance of Conformer or Transformer
+      in our case.
     batch:
       A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
       for the content in it.