mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 14:14:19 +00:00
Use Emformer model as RNN-T encoder.
This commit is contained in:
parent
e867a62d32
commit
5728a4456e
104
egs/librispeech/ASR/transducer_emformer/noam.py
Normal file
104
egs/librispeech/ASR/transducer_emformer/noam.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Noam(object):
|
||||
"""
|
||||
Implements Noam optimizer.
|
||||
|
||||
Proposed in
|
||||
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
|
||||
|
||||
Modified from
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
|
||||
|
||||
Args:
|
||||
params:
|
||||
iterable of parameters to optimize or dicts defining parameter groups
|
||||
model_size:
|
||||
attention dimension of the transformer model
|
||||
factor:
|
||||
learning rate factor
|
||||
warm_step:
|
||||
warmup steps
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
model_size: int = 256,
|
||||
factor: float = 10.0,
|
||||
warm_step: int = 25000,
|
||||
weight_decay=0,
|
||||
) -> None:
|
||||
"""Construct an Noam object."""
|
||||
self.optimizer = torch.optim.Adam(
|
||||
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
|
||||
)
|
||||
self._step = 0
|
||||
self.warmup = warm_step
|
||||
self.factor = factor
|
||||
self.model_size = model_size
|
||||
self._rate = 0
|
||||
|
||||
@property
|
||||
def param_groups(self):
|
||||
"""Return param_groups."""
|
||||
return self.optimizer.param_groups
|
||||
|
||||
def step(self):
|
||||
"""Update parameters and rate."""
|
||||
self._step += 1
|
||||
rate = self.rate()
|
||||
for p in self.optimizer.param_groups:
|
||||
p["lr"] = rate
|
||||
self._rate = rate
|
||||
self.optimizer.step()
|
||||
|
||||
def rate(self, step=None):
|
||||
"""Implement `lrate` above."""
|
||||
if step is None:
|
||||
step = self._step
|
||||
return (
|
||||
self.factor
|
||||
* self.model_size ** (-0.5)
|
||||
* min(step ** (-0.5), step * self.warmup ** (-1.5))
|
||||
)
|
||||
|
||||
def zero_grad(self):
|
||||
"""Reset gradient."""
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def state_dict(self):
|
||||
"""Return state_dict."""
|
||||
return {
|
||||
"_step": self._step,
|
||||
"warmup": self.warmup,
|
||||
"factor": self.factor,
|
||||
"model_size": self.model_size,
|
||||
"_rate": self._rate,
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Load state_dict."""
|
||||
for key, value in state_dict.items():
|
||||
if key == "optimizer":
|
||||
self.optimizer.load_state_dict(state_dict["optimizer"])
|
||||
else:
|
||||
setattr(self, key, value)
|
@ -21,11 +21,11 @@ Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless/train.py \
|
||||
./transducer_emformer/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--exp-dir transducer_emformer/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
"""
|
||||
@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@ -43,18 +44,18 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from emformer import Emformer
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from noam import Noam
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
@ -111,7 +112,7 @@ def get_parser():
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
transducer_stateless/exp/epoch-{start_epoch-1}.pt
|
||||
transducer_emformer/exp/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
@ -127,7 +128,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless/exp",
|
||||
default="transducer_emformer/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@ -279,7 +280,7 @@ def get_params() -> AttributeDict:
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000, # For the 100h subset, use 800
|
||||
"log_diagnostics": False,
|
||||
# parameters for conformer
|
||||
# parameters for Emformer
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
"attention_dim": 512,
|
||||
@ -287,10 +288,13 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"left_context_length": 120, # 120 frames
|
||||
"segment_length": 16,
|
||||
"right_context_length": 4,
|
||||
# parameters for decoder
|
||||
"embedding_dim": 512,
|
||||
# parameters for Noam
|
||||
"warm_step": 80000, # For the 100h subset, use 30000
|
||||
"warm_step": 80000, # For the 100h subset, use 20000
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -299,8 +303,7 @@ def get_params() -> AttributeDict:
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
# TODO: We can add an option to switch between Conformer and Transformer
|
||||
encoder = Conformer(
|
||||
encoder = Emformer(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.vocab_size,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
@ -309,6 +312,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
left_context_length=params.left_context_length,
|
||||
segment_length=params.segment_length,
|
||||
right_context_length=params.right_context_length,
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -496,7 +502,11 @@ def compute_loss(
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -725,7 +735,7 @@ def run(rank, world_size, args):
|
||||
params.update(vars(args))
|
||||
if params.full_libri is False:
|
||||
params.valid_interval = 800
|
||||
params.warm_step = 30000
|
||||
params.warm_step = 20000
|
||||
|
||||
fix_random_seed(params.seed)
|
||||
if world_size > 1:
|
||||
|
Loading…
x
Reference in New Issue
Block a user