Disable causal training; add balancers in decoder.

Daniel Povey 2023-02-11 22:51:09 +08:00
parent f9f546968c
commit dc481ca419
3 changed files with 225 additions and 3 deletions

decoder.py

@@ -18,6 +18,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from scaling import Balancer
 
 
 class Decoder(nn.Module):
     """This class modifies the stateless decoder from the following paper:
@@ -58,11 +59,19 @@ class Decoder(nn.Module):
             embedding_dim=decoder_dim,
             padding_idx=blank_id,
         )
+        # the balancers are to avoid any drift in the magnitude of the
+        # embeddings, which would interact badly with parameter averaging.
+        self.balancer = Balancer(decoder_dim, channel_dim=-1,
+                                 min_positive=0.0, max_positive=1.0,
+                                 min_abs=0.5, max_abs=1.0,
+                                 prob=0.05)
         self.blank_id = blank_id
 
         assert context_size >= 1, context_size
         self.context_size = context_size
         self.vocab_size = vocab_size
         if context_size > 1:
             self.conv = nn.Conv1d(
                 in_channels=decoder_dim,
@@ -72,6 +81,11 @@ class Decoder(nn.Module):
                 groups=decoder_dim//4,  # group size == 4
                 bias=False,
             )
+            self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
+                                      min_positive=0.0, max_positive=1.0,
+                                      min_abs=0.5, max_abs=1.0,
+                                      prob=0.05)
 
     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
@@ -88,6 +102,9 @@ class Decoder(nn.Module):
         # this stuff about clamp() is a temporary fix for a mismatch
         # at utterance start, we use negative ids in beam_search.py
         embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
+        embedding_out = self.balancer(embedding_out)
+
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
@@ -100,5 +117,7 @@ class Decoder(nn.Module):
                 assert embedding_out.size(-1) == self.context_size
             embedding_out = self.conv(embedding_out)
             embedding_out = embedding_out.permute(0, 2, 1)
         embedding_out = F.relu(embedding_out)
+        embedding_out = self.balancer2(embedding_out)
         return embedding_out

generate_averaged_model.py

@@ -0,0 +1,203 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) use the checkpoint exp_dir/epoch-xxx.pt
./pruned_transducer_stateless7/generate_averaged_model.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
(2) use the checkpoint exp_dir/checkpoint-iter.pt
./pruned_transducer_stateless7/generate_averaged_model.py \
--iter 22000 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless7/exp
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
"""
import argparse
from pathlib import Path
from typing import Dict, List

import sentencepiece as spm
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
    average_checkpoints_with_averaged_model,
    find_checkpoints,
)
def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=9,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless7/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    add_model_arguments(parser)

    return parser
@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    print("Script started")

    device = torch.device("cpu")
    print(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    print("About to create model")
    model = get_transducer_model(params)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg + 1
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg + 1:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        filename_start = filenames[-1]
        filename_end = filenames[0]
        print(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
        filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
        torch.save({"model": model.state_dict()}, filename)
    else:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        print(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )
        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
        filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
        torch.save({"model": model.state_dict()}, filename)

    num_param = sum([p.numel() for p in model.parameters()])
    print(f"Number of model parameters: {num_param}")

    print("Done!")


if __name__ == "__main__":
    main()
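
Both branches report the start checkpoint as "(excluded)". That falls out of how the interval average is recovered: assuming each checkpoint stores a cumulative running average of the weights together with a count of batches seen so far (which is what icefall's averaged-model checkpoints provide to average_checkpoints_with_averaged_model), the average over just the interval is a combination of two cumulative averages. A minimal sketch of that arithmetic, with illustrative names; the real logic lives in icefall/checkpoint.py:

import torch

def interval_average(avg_start: torch.Tensor, n_start: int,
                     avg_end: torch.Tensor, n_end: int) -> torch.Tensor:
    # avg_start: mean of a weight over batches 1..n_start
    # avg_end:   mean of the same weight over batches 1..n_end (n_end > n_start)
    # Returns the mean over batches n_start+1..n_end, i.e. the start
    # checkpoint itself is excluded from the result.
    weight = n_end / (n_end - n_start)
    return avg_end * weight + avg_start * (1.0 - weight)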

train.py

@@ -230,14 +230,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--causal",
         type=str2bool,
-        default=True,
+        default=False,
         help="If True, use causal version of model.",
     )
 
     parser.add_argument(
         "--chunk-size",
         type=str,
-        default="16,32,64,-1",
+        default="-1",  # "16,32,64,-1",
         help="Chunk sizes will be chosen randomly from this list during training. "
         " Must be just -1 if --causal=False",
     )
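
For context on the second change: --chunk-size is a comma-separated list from which one value is drawn at random per training batch, and -1 means full attention context, so pinning the default to "-1" matches the new --causal=False default (the help text requires exactly -1 in that case). A minimal sketch of that convention; the parsing below is illustrative, not train.py's exact code:

import random

def pick_chunk_size(spec: str, causal: bool) -> int:
    # Parse a spec like "16,32,64,-1"; -1 stands for full attention context.
    sizes = [int(s) for s in spec.split(",")]
    if not causal:
        # per the help text: must be just -1 if --causal=False
        assert sizes == [-1], sizes
    return random.choice(sizes)  # one chunk size drawn per training batch

assert pick_chunk_size("-1", causal=False) == -1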