mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
decode and adapt and augmentation methods are added
This commit is contained in:
parent
3445e63962
commit
50e60771da
@ -61,6 +61,9 @@ from icefall.utils import (
|
||||
|
||||
import fairseq
|
||||
|
||||
from optim import Eden, ScaledAdam
|
||||
from copy import deepcopy
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
def get_parser():
|
||||
@ -261,6 +264,80 @@ def get_parser():
|
||||
type=str,
|
||||
)
|
||||
|
||||
# optimizer related
|
||||
parser.add_argument(
|
||||
"--base-lr", type=float, default=6e-5, help="The base learning rate."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prune-range",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The prune range for rnnt loss, it means how many symbols(context)"
|
||||
"we are using to compute the loss",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-scale",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="The scale to smooth the loss with lm "
|
||||
"(output of prediction network) part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network) part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simple-loss-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="To get pruning ranges, we will calculate a simple version"
|
||||
"loss(joiner is just addition), this simple loss also uses for"
|
||||
"training (as a regularization item). We will scale the simple loss"
|
||||
"with this parameter before adding to the final loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ctc-loss-scale",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Scale for CTC loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--subsampling-factor",
|
||||
type=int,
|
||||
default=320,
|
||||
help="shit0",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-double-scores",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="shit0",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm-step",
|
||||
type=int,
|
||||
default=0,
|
||||
help="shit0",
|
||||
)
|
||||
|
||||
# tta related
|
||||
parser.add_argument(
|
||||
"--num_augment",
|
||||
type=int,
|
||||
default=4,
|
||||
help="shit1",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
add_rep_arguments(parser)
|
||||
|
||||
@ -475,6 +552,7 @@ def decode_one_batch(
|
||||
def decode_and_adapt(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
is_training: bool,
|
||||
@ -534,69 +612,40 @@ def decode_and_adapt(
|
||||
|
||||
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
if params.ctc_loss_scale > 0:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
supervision_segments, token_ids = encode_supervisions(
|
||||
supervisions,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
token_ids=token_ids,
|
||||
if params.ctc_loss_scale > 0:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
supervision_segments, token_ids = encode_supervisions(
|
||||
supervisions,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
token_ids=token_ids,
|
||||
)
|
||||
for i in range(params.num_augment):
|
||||
supervision_segments[i][-1] = ctc_output.size(1)
|
||||
|
||||
# Works with a BPE model
|
||||
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
ctc_output,
|
||||
supervision_segments,
|
||||
allow_truncate=params.subsampling_factor - 1,
|
||||
)
|
||||
|
||||
# Works with a BPE model
|
||||
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
ctc_output,
|
||||
supervision_segments,
|
||||
allow_truncate=params.subsampling_factor - 1,
|
||||
)
|
||||
|
||||
ctc_loss = k2.ctc_loss(
|
||||
decoding_graph=decoding_graph,
|
||||
dense_fsa_vec=dense_fsa_vec,
|
||||
output_beam=params.beam_size,
|
||||
reduction="sum",
|
||||
use_double_scores=params.use_double_scores,
|
||||
)
|
||||
assert ctc_loss.requires_grad == is_training
|
||||
loss += params.ctc_loss_scale * ctc_loss
|
||||
ctc_loss = k2.ctc_loss(
|
||||
decoding_graph=decoding_graph,
|
||||
dense_fsa_vec=dense_fsa_vec,
|
||||
output_beam=params.beam_size,
|
||||
reduction="sum",
|
||||
use_double_scores=params.use_double_scores,
|
||||
)
|
||||
assert ctc_loss.requires_grad == is_training
|
||||
loss += params.ctc_loss_scale * ctc_loss
|
||||
|
||||
# self.adapted_model_losses.append(loss.item())
|
||||
# self.adapted_models.append(self.copy_model_and_optimizer(self.models[0]))
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_output = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
|
||||
s = params.simple_loss_scale
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
# to params.simple_loss scale by warm_step.
|
||||
simple_loss_scale = (
|
||||
s
|
||||
if batch_idx_train >= warm_step
|
||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||
)
|
||||
pruned_loss_scale = (
|
||||
1.0
|
||||
if batch_idx_train >= warm_step
|
||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||
)
|
||||
|
||||
|
||||
|
||||
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
optimizer.step()
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
@ -606,30 +655,6 @@ def decode_dataset(
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
@ -643,6 +668,27 @@ def decode_dataset(
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
|
||||
parameters = []
|
||||
parameters_name = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
if ("bias" in n) and ("encoder.layers" in n):
|
||||
logging.info(f"{n} is free!")
|
||||
parameters.append(p)
|
||||
parameters_name.append(n)
|
||||
else:
|
||||
p.requires_grad = False
|
||||
sizes = [p.numel() for p in parameters]
|
||||
logging.info(f"total trainable parameter size : {sum(sizes)}")
|
||||
|
||||
optimizer = ScaledAdam(
|
||||
parameters,
|
||||
lr=params.base_lr,
|
||||
clipping_scale=2.0,
|
||||
parameters_names=[parameters_name],
|
||||
)
|
||||
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
model.eval()
|
||||
texts = batch["supervisions"]["text"]
|
||||
@ -660,13 +706,22 @@ def decode_dataset(
|
||||
)
|
||||
|
||||
# replace the supervision to pseudo labels
|
||||
batch["supervision"]["text"] = "".join(hyps_dict[params.decoding_method] )
|
||||
pseudo_batch = deepcopy(batch)
|
||||
|
||||
assert len(hyps_dict[params.decoding_method]) == 1 # shoud use the single utterance sampler
|
||||
pseudo_batch["supervisions"]["text"] = [" ".join(hyps_dict[params.decoding_method][0]).lower()] * params.num_augment
|
||||
|
||||
# augment the single utterance (augmentation automatically excued in d2v model)
|
||||
batch["intputs"] = batch["intputs"].reapeat(4, 1)
|
||||
pseudo_batch["inputs"] = batch["inputs"].repeat(params.num_augment, 1)
|
||||
pseudo_batch["supervisions"]["sequence_idx"] = batch["supervisions"]["sequence_idx"].repeat(params.num_augment)
|
||||
pseudo_batch["supervisions"]['cut'] = batch["supervisions"]['cut'] * params.num_augment
|
||||
assert "start_frame" not in pseudo_batch["supervisions"].keys()
|
||||
assert "num_frames" not in pseudo_batch["supervisions"].keys()
|
||||
|
||||
decode_and_adapt(params, model, sp, batch, is_training=True, num_iter=10)
|
||||
# model.train() is excuted in the decoder and adpt fucntion
|
||||
decode_and_adapt(params, model, optimizer, sp, pseudo_batch, is_training=True, num_iter=10)
|
||||
|
||||
model.eval()
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user