Add decode-and-adapt and augmentation methods

j-pong 2023-06-13 17:20:21 +09:00
parent 3445e63962
commit 50e60771da


@@ -61,6 +61,9 @@ from icefall.utils import (
import fairseq
from optim import Eden, ScaledAdam
from copy import deepcopy
LOG_EPS = math.log(1e-10)
def get_parser():
@@ -261,6 +264,80 @@ def get_parser():
        type=str,
    )
    # optimizer related
    parser.add_argument(
        "--base-lr", type=float, default=6e-5, help="The base learning rate."
    )
    parser.add_argument(
        "--prune-range",
        type=int,
        default=5,
        help="The prune range for rnnt loss, i.e., how many symbols of "
        "context we use to compute the loss.",
    )
    parser.add_argument(
        "--lm-scale",
        type=float,
        default=0.25,
        help="The scale to smooth the loss with the lm "
        "(output of prediction network) part.",
    )
    parser.add_argument(
        "--am-scale",
        type=float,
        default=0.0,
        help="The scale to smooth the loss with the am "
        "(output of encoder network) part.",
    )
    parser.add_argument(
        "--simple-loss-scale",
        type=float,
        default=0.5,
        help="To get pruning ranges, we compute a simple version of the "
        "loss (the joiner is just an addition); this simple loss is also "
        "used in training as a regularization term. We scale the simple "
        "loss by this parameter before adding it to the final loss.",
    )
    parser.add_argument(
        "--ctc-loss-scale",
        type=float,
        default=0.2,
        help="Scale for CTC loss.",
    )
    parser.add_argument(
        "--subsampling-factor",
        type=int,
        default=320,
        help="Subsampling factor of the encoder; used to build the CTC "
        "supervision segments.",
    )
    parser.add_argument(
        "--use-double-scores",
        # argparse's plain `bool` treats any non-empty string as True;
        # icefall's str2bool parses "true"/"false" correctly.
        type=str2bool,
        default=True,
        help="Whether to use double precision for scores in k2.ctc_loss.",
    )
    parser.add_argument(
        "--warm-step",
        type=int,
        default=0,
        help="Number of warm-up steps over which the simple/pruned loss "
        "scales are ramped (see --simple-loss-scale).",
    )
    # tta related
    parser.add_argument(
        "--num-augment",
        type=int,
        default=4,
        help="Number of augmented copies of each utterance used for "
        "test-time adaptation.",
    )
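    # --num-augment drives the test-time adaptation below: each decoded
    # utterance is replicated that many times with its pseudo-label before
    # decode_and_adapt() updates the model (see decode_dataset).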
    add_model_arguments(parser)
    add_rep_arguments(parser)
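For reference, a minimal sketch (not part of the diff itself) of the warm-up interpolation that --simple-loss-scale and --warm-step control, reconstructed from the scheduling logic this commit removes from decode_and_adapt: the simple-loss weight decays linearly from 1.0 down to s while the pruned-loss weight ramps from 0.1 up to 1.0.

def loss_scales(batch_idx_train: int, warm_step: int, s: float):
    # Past the warm-up (or with warm_step == 0, the default), use the
    # final scales directly.
    if warm_step <= 0 or batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac

print(loss_scales(0, 100, 0.5))    # (1.0, 0.1)
print(loss_scales(100, 100, 0.5))  # (0.5, 1.0)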
@@ -475,6 +552,7 @@ def decode_one_batch(
def decode_and_adapt(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    is_training: bool,
@@ -534,69 +612,40 @@ def decode_and_adapt(
    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
    if params.ctc_loss_scale > 0:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            supervision_segments, token_ids = encode_supervisions(
                supervisions,
                subsampling_factor=params.subsampling_factor,
                token_ids=token_ids,
            )
        # Every augmented copy spans the full encoder output.
        for i in range(params.num_augment):
            supervision_segments[i][-1] = ctc_output.size(1)
        # Works with a BPE model
        decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
        dense_fsa_vec = k2.DenseFsaVec(
            ctc_output,
            supervision_segments,
            allow_truncate=params.subsampling_factor - 1,
        )
        ctc_loss = k2.ctc_loss(
            decoding_graph=decoding_graph,
            dense_fsa_vec=dense_fsa_vec,
            output_beam=params.beam_size,
            reduction="sum",
            use_double_scores=params.use_double_scores,
        )
        assert ctc_loss.requires_grad == is_training
        loss += params.ctc_loss_scale * ctc_loss
    assert loss.requires_grad == is_training
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
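To make the supervision_segments fix-up above concrete, here is a self-contained sketch with toy shapes (not the recipe's real tensors): the last column of each row is overwritten so that all augmented copies cover the full encoder output length.

import torch

# Toy stand-ins: 4 augmented copies, 123 encoder frames, 500-way vocab.
ctc_output = torch.zeros(4, 123, 500)
# Rows are (sequence_idx, start_frame, num_frames); num_frames starts stale.
supervision_segments = torch.tensor([[i, 0, 0] for i in range(4)], dtype=torch.int32)
for i in range(4):
    supervision_segments[i][-1] = ctc_output.size(1)
print(supervision_segments[:, -1].tolist())  # [123, 123, 123, 123]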
def decode_dataset(
    dl: torch.utils.data.DataLoader,
@@ -606,30 +655,6 @@ def decode_dataset(
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
    num_cuts = 0
    try:
@@ -643,6 +668,27 @@ def decode_dataset(
    log_interval = 20
    results = defaultdict(list)

    parameters = []
    parameters_name = []
    for n, p in model.named_parameters():
        if p.requires_grad:
            if ("bias" in n) and ("encoder.layers" in n):
                logging.info(f"{n} is free!")
                parameters.append(p)
                parameters_name.append(n)
            else:
                p.requires_grad = False
    sizes = [p.numel() for p in parameters]
    logging.info(f"total trainable parameter size: {sum(sizes)}")
    optimizer = ScaledAdam(
        parameters,
        lr=params.base_lr,
        clipping_scale=2.0,
        parameters_names=[parameters_name],
    )
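    # Only encoder-layer biases stay trainable above (a lightweight,
    # BitFit-style choice), so each adaptation step updates a tiny
    # fraction of the model and is cheap enough to run per utterance.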
    for batch_idx, batch in enumerate(dl):
        model.eval()
        texts = batch["supervisions"]["text"]
@@ -660,13 +706,22 @@ def decode_dataset(
        )

        # Replace the supervisions with pseudo-labels.
        pseudo_batch = deepcopy(batch)
        # Should use the single-utterance sampler.
        assert len(hyps_dict[params.decoding_method]) == 1
        pseudo_batch["supervisions"]["text"] = [
            " ".join(hyps_dict[params.decoding_method][0]).lower()
        ] * params.num_augment
        # Augment the single utterance (the augmentation itself is applied
        # automatically inside the d2v model); replicate the batch entries
        # so every copy carries the same pseudo-label.
        pseudo_batch["inputs"] = batch["inputs"].repeat(params.num_augment, 1)
        pseudo_batch["supervisions"]["sequence_idx"] = batch["supervisions"][
            "sequence_idx"
        ].repeat(params.num_augment)
        pseudo_batch["supervisions"]["cut"] = batch["supervisions"]["cut"] * params.num_augment
        assert "start_frame" not in pseudo_batch["supervisions"].keys()
        assert "num_frames" not in pseudo_batch["supervisions"].keys()

        # model.train() is executed inside decode_and_adapt().
        decode_and_adapt(
            params, model, optimizer, sp, pseudo_batch, is_training=True, num_iter=10
        )
        model.eval()
        hyps_dict = decode_one_batch(
            params=params,
            model=model,