Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Remove batchnorm, weight decay, and SOS from transducer conformer encoder (#155)
* Remove batchnorm, weight decay, and SOS.
* Make --context-size configurable.
* Update results.
Parent: 8187d6236c
Commit: 14c93add50
@@ -74,11 +74,11 @@ jobs:
           mkdir tmp
           cd tmp
           git lfs install
-          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22
+          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27
           cd ..
           tree tmp
-          soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
-          ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
+          soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
+          ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav

       - name: Run greedy search decoding
         shell: bash
@@ -87,11 +87,11 @@ jobs:
           cd egs/librispeech/ASR
           ./transducer_stateless/pretrained.py \
             --method greedy_search \
-            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
-            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav

       - name: Run beam search decoding
         shell: bash
@@ -101,8 +101,8 @@ jobs:
           ./transducer_stateless/pretrained.py \
             --method beam_search \
             --beam-size 4 \
-            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
-            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav

@@ -84,7 +84,7 @@ The best WER using beam search with beam size 4 is:

 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 2.92       | 7.37       |
+| WER | 2.83       | 7.19       |

 Note: No auxiliary losses are used in the training and no LMs are used
 in the decoding.
@@ -4,7 +4,7 @@

 #### Conformer encoder + embedding decoder

-Using commit `fb6a57e9e01dd8aae2af2a6b4568daad8bc8ab32`.
+Using commit `TODO`.

 Conformer encoder + non-current decoder. The decoder
 contains only an embedding layer and a Conv1d (with kernel size 2).
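The hunk above describes the decoder as nothing more than an embedding layer followed by a Conv1d with kernel size 2. A minimal, hypothetical sketch of that idea (class name, shapes, and the causal left-padding are assumptions for illustration, not the actual icefall `Decoder`):

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyStatelessDecoder(nn.Module):
    """Hypothetical sketch: embedding followed by a causal Conv1d with kernel size 2."""

    def __init__(self, vocab_size: int, embedding_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.context_size = context_size
        self.conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token IDs; returns (N, U, embedding_dim)
        emb = self.embedding(y).permute(0, 2, 1)   # (N, C, U)
        # Left-pad by context_size - 1 so position u only sees tokens u-1 and u.
        emb = F.pad(emb, (self.context_size - 1, 0))
        return self.conv(emb).permute(0, 2, 1)


dec = TinyStatelessDecoder(vocab_size=500, embedding_dim=512)
print(dec(torch.randint(0, 500, (2, 7))).shape)  # torch.Size([2, 7, 512])
```

With kernel size 2 the output at a given position can depend on at most the current and the previous symbol, which is what makes the decoder "stateless".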
@@ -13,12 +13,8 @@ The WERs are

 |                           | test-clean | test-other | comment                                  |
 |---------------------------|------------|------------|------------------------------------------|
-| greedy search             | 2.99       | 7.52       | --epoch 20, --avg 10, --max-duration 100 |
-| beam search (beam size 2) | 2.95       | 7.43       |                                          |
-| beam search (beam size 3) | 2.94       | 7.37       |                                          |
-| beam search (beam size 4) | 2.92       | 7.37       |                                          |
-| beam search (beam size 5) | 2.93       | 7.38       |                                          |
-| beam search (beam size 8) | 2.92       | 7.38       |                                          |
+| greedy search             | 2.85       | 7.30       | --epoch 29, --avg 13, --max-duration 100 |
+| beam search (beam size 4) | 2.83       | 7.19       |                                          |

 The training command for reproducing is given below:

@@ -36,12 +32,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ```

 The tensorboard training log can be found at
-<https://tensorboard.dev/experiment/PsJ3LgkEQfOmzedAlYfVeg/#scalars&_smoothingWeight=0>
+<https://tensorboard.dev/experiment/Mjx7MeTgR3Oyr1yBCwjozw/>

 The decoding command is:
 ```
-epoch=20
-avg=10
+epoch=29
+avg=13

 ## greedy search
 ./transducer_stateless/decode.py \
@@ -64,7 +60,7 @@ avg=10


 #### Conformer encoder + LSTM decoder
-Using commit `TODO`.
+Using commit `8187d6236c2926500da5ee854f758e621df803cc`.

 Conformer encoder + LSTM decoder.

@@ -396,7 +396,7 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()

@@ -194,7 +194,7 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()

@@ -208,7 +208,7 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()

@@ -564,7 +564,7 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()

@@ -56,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -69,7 +68,6 @@ class Conformer(Transformer):
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            use_feat_batchnorm=use_feat_batchnorm,
         )

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -107,11 +105,6 @@ class Conformer(Transformer):
           - logit_lens, a tensor of shape (batch_size,) containing the number
             of frames in `logits` before padding.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-
         x = self.encoder_embed(x)
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        self.norm = nn.LayerNorm(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):

         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)
+
+        x = self.activation(x)

         x = self.pointwise_conv2(x)  # (batch, channel, time)

@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=20,
+        default=29,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=10,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -114,6 +114,13 @@ def get_parser():
         help="Used only when --decoding-method is beam_search",
     )

+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
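For orientation, the new `--context-size` flag (added to several scripts in this commit) simply lands in the parsed options and, per the RESULTS.md description earlier, corresponds to the kernel size of the decoder's Conv1d. A trivial, self-contained check of the argparse wiring (illustrative only, not icefall's full parser):

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--context-size",
    type=int,
    default=2,
    help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)

args = parser.parse_args(["--context-size", "3"])
print(args.context_size)  # 3 -> the decoder would condition on 3 previous symbols
```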
@@ -136,9 +143,6 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "context_size": 2,  # tri-gram
             "env_info": get_env_info(),
         }
     )
@@ -156,7 +160,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder

@@ -393,6 +396,7 @@ def main():
     if params.decoding_method == "beam_search":
         params.suffix += f"-beam-{params.beam_size}"
     else:
+        params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
@@ -20,13 +20,14 @@ import torch.nn.functional as F


 class Decoder(nn.Module):
-    """This class implements the stateless decoder from the following paper:
+    """This class modifies the stateless decoder from the following paper:

         RNN-transducer with stateless prediction network
         https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

     It removes the recurrent connection from the decoder, i.e., the prediction
-    network.
+    network. Different from the above paper, it adds an extra Conv1d
+    right after the embedding layer.

     TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
     """
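The reworded docstring is the heart of the "stateless" design: with only an embedding and a kernel-2 Conv1d, the prediction at a given step can see at most the two most recent symbols. A self-contained check of that limited-context property (hypothetical sizes, not the icefall `Decoder` itself):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
emb = nn.Embedding(10, 8)
conv = nn.Conv1d(8, 8, kernel_size=2)  # context_size = 2


def decode(tokens):
    x = emb(tokens).permute(0, 2, 1)  # (1, C, U)
    x = F.pad(x, (1, 0))              # causal left padding of context_size - 1
    return conv(x).permute(0, 2, 1)   # (1, U, C)


a = torch.tensor([[1, 2, 3, 4, 5]])
b = torch.tensor([[9, 9, 3, 4, 5]])   # same last two symbols, different history

# Outputs at the final position match: the decoder only "remembers" 2 symbols.
print(torch.allclose(decode(a)[0, -1], decode(b)[0, -1]))  # True
```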
@@ -104,6 +104,14 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser


@@ -119,9 +127,6 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "context_size": 2,  # tri-gram
             "env_info": get_env_info(),
         }
     )
@@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder

@@ -16,7 +16,6 @@

 import torch
 import torch.nn as nn
-import torch.nn.functional as F


 class Joiner(nn.Module):
@@ -48,7 +47,7 @@ class Joiner(nn.Module):
         # Now decoder_out is (N, 1, U, C)

         logit = encoder_out + decoder_out
-        logit = F.relu(logit)
+        logit = torch.tanh(logit)

         output = self.output_linear(logit)

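The joiner change above only swaps the non-linearity applied to the summed encoder and decoder outputs from `F.relu` to `torch.tanh`. A minimal sketch of that joiner pattern (dimensions and the broadcast shapes are assumptions for illustration):

```
import torch
import torch.nn as nn


class TinyJoiner(nn.Module):
    """Hypothetical joiner sketch: add encoder and decoder outputs, tanh, project."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out, decoder_out):
        # encoder_out: (N, T, 1, C), decoder_out: (N, 1, U, C) -> broadcast to (N, T, U, C)
        logit = encoder_out + decoder_out
        logit = torch.tanh(logit)          # previously F.relu(logit)
        return self.output_linear(logit)   # (N, T, U, output_dim)


joiner = TinyJoiner(input_dim=512, output_dim=500)
enc = torch.randn(2, 100, 1, 512)
dec = torch.randn(2, 1, 7, 512)
print(joiner(enc, dec).shape)  # torch.Size([2, 100, 7, 500])
```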
@@ -110,6 +110,13 @@ def get_parser():
         help="Used only when --method is beam_search",
     )

+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
@@ -135,9 +142,6 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "context_size": 2,  # tri-gram
             "env_info": get_env_info(),
         }
     )
@@ -154,7 +158,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder

@@ -130,6 +130,14 @@ def get_parser():
         help="The lr_factor for Noam optimizer",
     )

+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser


@@ -171,15 +179,10 @@ def get_params() -> AttributeDict:

         - subsampling_factor: The subsampling factor for the model.

-        - use_feat_batchnorm: Whether to do batch normalization for the
-          input features.
-
         - attention_dim: Hidden dim for multi-head attention model.

         - num_decoder_layers: Number of decoder layer of transformer decoder.

-        - weight_decay: The weight_decay for the optimizer.
-
         - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
@@ -201,11 +204,7 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "context_size": 2,  # tri-gram
             # parameters for Noam
-            "weight_decay": 1e-6,
             "warm_step": 80000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
         }
@@ -225,7 +224,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder

@@ -568,7 +566,7 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()

@@ -593,7 +591,6 @@ def run(rank, world_size, args):
         model_size=params.attention_dim,
         factor=params.lr_factor,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
     )

     if checkpoints and "optimizer" in checkpoints:
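With `weight_decay` gone, the Noam wrapper above is left acting purely as a learning-rate schedule around the underlying optimizer. For reference, the standard Noam schedule from the Transformer paper, which the `model_size`, `factor`, and `warm_step` arguments presumably feed, looks like the sketch below (function name and example values are illustrative, not icefall's API):

```
def noam_lr(step: int, model_size: int, factor: float, warm_step: int) -> float:
    """Standard Noam schedule: linear warmup to warm_step, then inverse-sqrt decay."""
    step = max(step, 1)
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)


# Example values only; warm_step=80000 matches the params dict shown earlier.
for s in (1000, 20000, 80000, 320000):
    print(s, f"{noam_lr(s, model_size=512, factor=5.0, warm_step=80000):.2e}")
```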
@@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         """
         Args:
@@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
             True to use vgg style frontend for subsampling.
-          use_feat_batchnorm:
-            True to use batchnorm for the input layer.
         """
         super().__init__()
-        self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
-            self.feat_batchnorm = nn.BatchNorm1d(num_features)

         self.num_features = num_features
         self.output_dim = output_dim
@@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
           - logit_lens, a tensor of shape (batch_size,) containing the number
             of frames in `logits` before padding.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-
         x = self.encoder_embed(x)
         x = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)