mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
change context size to 1
This commit is contained in:
parent
e7a2decfe4
commit
dc40220951
@ -74,7 +74,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
|||||||
log "Stage 1: Prepare aishell2 manifest"
|
log "Stage 1: Prepare aishell2 manifest"
|
||||||
# We assume that you have downloaded and unzip the aishell2 corpus
|
# We assume that you have downloaded and unzip the aishell2 corpus
|
||||||
# to $dl_dir/aishell2
|
# to $dl_dir/aishell2
|
||||||
if [ ! -f data/manifests/.aishell_manifests.done ]; then
|
if [ ! -f data/manifests/.aishell2_manifests.done ]; then
|
||||||
mkdir -p data/manifests
|
mkdir -p data/manifests
|
||||||
lhotse prepare aishell2 $dl_dir/aishell2 data/manifests -j $nj
|
lhotse prepare aishell2 $dl_dir/aishell2 data/manifests -j $nj
|
||||||
touch data/manifests/.aishell2_manifests.done
|
touch data/manifests/.aishell2_manifests.done
|
||||||
@ -94,7 +94,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
log "Stage 3: Compute fbank for aishell"
|
log "Stage 3: Compute fbank for aishell2"
|
||||||
if [ ! -f data/fbank/.aishell2.done ]; then
|
if [ ! -f data/fbank/.aishell2.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aishell2.py
|
./local/compute_fbank_aishell2.py
|
||||||
@ -129,7 +129,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
|
|
||||||
# The implementation of chinese word segmentation for text,
|
# The implementation of chinese word segmentation for text,
|
||||||
# and it will take about 15 minutes.
|
# and it will take about 15 minutes.
|
||||||
# If can't install paddle-tiny with python 3.8, please refer
|
# If you can't install paddle-tiny with python 3.8, please refer to
|
||||||
# https://github.com/fxsjy/jieba/issues/920
|
# https://github.com/fxsjy/jieba/issues/920
|
||||||
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
|
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
|
||||||
python3 ./local/text2segments.py \
|
python3 ./local/text2segments.py \
|
||||||
@ -149,4 +149,4 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
|
||||||
python3 ./local/prepare_char.py
|
python3 ./local/prepare_char.py
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
File diff suppressed because it is too large
Load Diff
1
egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py
Symbolic link
1
egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
|
@ -1,65 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
To run this file, do:
|
|
||||||
|
|
||||||
cd icefall/egs/librispeech/ASR
|
|
||||||
python ./pruned_transducer_stateless4/test_model.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
from train import get_params, get_transducer_model
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_1():
|
|
||||||
params = get_params()
|
|
||||||
params.vocab_size = 500
|
|
||||||
params.blank_id = 0
|
|
||||||
params.context_size = 2
|
|
||||||
params.num_encoder_layers = 24
|
|
||||||
params.dim_feedforward = 1536 # 384 * 4
|
|
||||||
params.encoder_dim = 384
|
|
||||||
model = get_transducer_model(params)
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
|
||||||
print(f"Number of model parameters: {num_param}")
|
|
||||||
|
|
||||||
|
|
||||||
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
|
|
||||||
def test_model_M():
|
|
||||||
params = get_params()
|
|
||||||
params.vocab_size = 500
|
|
||||||
params.blank_id = 0
|
|
||||||
params.context_size = 2
|
|
||||||
params.num_encoder_layers = 18
|
|
||||||
params.dim_feedforward = 1024
|
|
||||||
params.encoder_dim = 256
|
|
||||||
params.nhead = 4
|
|
||||||
params.decoder_dim = 512
|
|
||||||
params.joiner_dim = 512
|
|
||||||
model = get_transducer_model(params)
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
|
||||||
print(f"Number of model parameters: {num_param}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# test_model_1()
|
|
||||||
test_model_M()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -243,7 +243,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=1,
|
||||||
help="The context size in the decoder. 1 means bigram; "
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
@ -398,7 +398,7 @@ def get_params() -> AttributeDict:
|
|||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 50,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 3000, # For the 100h subset, use 800
|
"valid_interval": 3000,
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
@ -524,9 +524,6 @@ def load_checkpoint_if_available(
|
|||||||
if "cur_epoch" in saved_params:
|
if "cur_epoch" in saved_params:
|
||||||
params["start_epoch"] = saved_params["cur_epoch"]
|
params["start_epoch"] = saved_params["cur_epoch"]
|
||||||
|
|
||||||
if "cur_batch_idx" in saved_params:
|
|
||||||
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
|
|
||||||
|
|
||||||
return saved_params
|
return saved_params
|
||||||
|
|
||||||
|
|
||||||
@ -748,12 +745,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
if batch_idx < cur_batch_idx:
|
|
||||||
continue
|
|
||||||
cur_batch_idx = batch_idx
|
|
||||||
|
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
@ -779,8 +771,9 @@ def train_one_epoch(
|
|||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
except: # noqa
|
except: # noqa
|
||||||
display_and_save_batch(batch, params=params,
|
display_and_save_batch(
|
||||||
graph_compiler=graph_compiler)
|
batch, params=params, graph_compiler=graph_compiler
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
@ -801,7 +794,6 @@ def train_one_epoch(
|
|||||||
params.batch_idx_train > 0
|
params.batch_idx_train > 0
|
||||||
and params.batch_idx_train % params.save_every_n == 0
|
and params.batch_idx_train % params.save_every_n == 0
|
||||||
):
|
):
|
||||||
params.cur_batch_idx = batch_idx
|
|
||||||
save_checkpoint_with_global_batch_idx(
|
save_checkpoint_with_global_batch_idx(
|
||||||
out_dir=params.exp_dir,
|
out_dir=params.exp_dir,
|
||||||
global_batch_idx=params.batch_idx_train,
|
global_batch_idx=params.batch_idx_train,
|
||||||
@ -814,7 +806,6 @@ def train_one_epoch(
|
|||||||
scaler=scaler,
|
scaler=scaler,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
del params.cur_batch_idx
|
|
||||||
remove_checkpoints(
|
remove_checkpoints(
|
||||||
out_dir=params.exp_dir,
|
out_dir=params.exp_dir,
|
||||||
topk=params.keep_last_k,
|
topk=params.keep_last_k,
|
||||||
@ -1113,8 +1104,9 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
f"Failing criterion: {criterion} "
|
f"Failing criterion: {criterion} "
|
||||||
f"(={crit_values[criterion]}) ..."
|
f"(={crit_values[criterion]}) ..."
|
||||||
)
|
)
|
||||||
display_and_save_batch(batch, params=params,
|
display_and_save_batch(
|
||||||
graph_compiler=graph_compiler)
|
batch, params=params, graph_compiler=graph_compiler
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user