Add decode-and-adapt and augmentation methods for test-time adaptation
This commit is contained in:
parent 3445e63962
commit 50e60771da
@@ -61,6 +61,9 @@ from icefall.utils import (
 
 import fairseq
 
+from optim import Eden, ScaledAdam
+from copy import deepcopy
+
 LOG_EPS = math.log(1e-10)
 
 def get_parser():
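The two new imports back the changes below: ScaledAdam drives the test-time parameter updates, and deepcopy clones the batch before it is overwritten with pseudo labels. A toy illustration (not the real Lhotse batch) of why a deep copy rather than a shallow copy is needed for the nested batch dict:

import copy

# Toy batch: a shallow copy shares the nested "supervisions" dict, so editing
# the copy would also edit the original; deepcopy keeps the pseudo-labelled
# copy independent.
batch = {"inputs": [0.1, 0.2], "supervisions": {"text": ["hello world"]}}

shallow = copy.copy(batch)
shallow["supervisions"]["text"] = ["pseudo label"]
print(batch["supervisions"]["text"])   # ['pseudo label'] -- original mutated

batch["supervisions"]["text"] = ["hello world"]  # reset
deep = copy.deepcopy(batch)
deep["supervisions"]["text"] = ["pseudo label"]
print(batch["supervisions"]["text"])   # ['hello world'] -- original untouched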
@@ -261,6 +264,80 @@ def get_parser():
         type=str,
     )
 
+    # optimizer related
+    parser.add_argument(
+        "--base-lr", type=float, default=6e-5, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols (context) "
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (joiner is just addition); this simple loss is also used for "
+        "training (as a regularization item). We will scale the simple loss "
+        "with this parameter before adding it to the final loss.",
+    )
+
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
+    parser.add_argument(
+        "--subsampling-factor",
+        type=int,
+        default=320,
+        help="The subsampling factor of the model.",
+    )
+
+    parser.add_argument(
+        "--use-double-scores",
+        type=bool,
+        default=True,
+        help="Whether to use double precision scores in k2.",
+    )
+
+    parser.add_argument(
+        "--warm-step",
+        type=int,
+        default=0,
+        help="Number of warm-up steps used to scale the simple and pruned losses.",
+    )
+
+    # tta related
+    parser.add_argument(
+        "--num_augment",
+        type=int,
+        default=4,
+        help="Number of augmented copies of each utterance used for test-time adaptation.",
+    )
+
     add_model_arguments(parser)
     add_rep_arguments(parser)
 
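As a quick sanity check of how the new flags behave, here is a stand-alone toy parser (not the recipe's get_parser) exercising a few of them. Note that --num_augment uses an underscore while the other new flags are dash-separated, and that --use-double-scores is declared with type=bool, which under argparse turns any non-empty string (even "False") into True.

import argparse

# Toy stand-in for the recipe's get_parser(); only a few of the new flags.
parser = argparse.ArgumentParser()
parser.add_argument("--base-lr", type=float, default=6e-5)
parser.add_argument("--ctc-loss-scale", type=float, default=0.2)
parser.add_argument("--use-double-scores", type=bool, default=True)
parser.add_argument("--num_augment", type=int, default=4)

args = parser.parse_args(["--base-lr", "1e-4", "--num_augment", "8"])
print(args.base_lr, args.ctc_loss_scale, args.num_augment)  # 0.0001 0.2 8

# argparse gotcha: bool("False") is truthy, so the flag stays True.
print(parser.parse_args(["--use-double-scores", "False"]).use_double_scores)  # True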
@@ -475,6 +552,7 @@ def decode_one_batch(
 def decode_and_adapt(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
@@ -534,69 +612,40 @@ def decode_and_adapt(
 
         loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
         if params.ctc_loss_scale > 0:
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
                 supervision_segments, token_ids = encode_supervisions(
                     supervisions,
                     subsampling_factor=params.subsampling_factor,
                     token_ids=token_ids,
                 )
+            for i in range(params.num_augment):
+                supervision_segments[i][-1] = ctc_output.size(1)
 
             # Works with a BPE model
             decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
             dense_fsa_vec = k2.DenseFsaVec(
                 ctc_output,
                 supervision_segments,
                 allow_truncate=params.subsampling_factor - 1,
             )
+
             ctc_loss = k2.ctc_loss(
                 decoding_graph=decoding_graph,
                 dense_fsa_vec=dense_fsa_vec,
                 output_beam=params.beam_size,
                 reduction="sum",
                 use_double_scores=params.use_double_scores,
             )
             assert ctc_loss.requires_grad == is_training
             loss += params.ctc_loss_scale * ctc_loss
 
-        # self.adapted_model_losses.append(loss.item())
-        # self.adapted_models.append(self.copy_model_and_optimizer(self.models[0]))
+        assert loss.requires_grad == is_training
 
-        self.optimizer.zero_grad()
+        optimizer.zero_grad()
         loss.backward()
-        self.optimizer.step()
-
-        with torch.set_grad_enabled(is_training):
-            simple_loss, pruned_loss, ctc_output = model(
-                x=feature,
-                x_lens=feature_lens,
-                y=y,
-                prune_range=params.prune_range,
-                am_scale=params.am_scale,
-                lm_scale=params.lm_scale,
-            )
-
-            s = params.simple_loss_scale
-            # take down the scale on the simple loss from 1.0 at the start
-            # to params.simple_loss scale by warm_step.
-            simple_loss_scale = (
-                s
-                if batch_idx_train >= warm_step
-                else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
-            )
-            pruned_loss_scale = (
-                1.0
-                if batch_idx_train >= warm_step
-                else 0.1 + 0.9 * (batch_idx_train / warm_step)
-            )
-
-            loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+        optimizer.step()
 
 
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
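The CTC branch above assembles the adaptation loss from k2 primitives (ctc_graph, DenseFsaVec, ctc_loss). A minimal self-contained sketch of that pattern, kept on CPU; the sizes, token ids, and log-posteriors are all made up for illustration:

import torch
import k2

# Toy sketch of the k2 CTC-loss pattern used above.
vocab_size, T = 10, 50
logits = torch.randn(2, T, vocab_size, requires_grad=True)
log_probs = logits.log_softmax(dim=-1)  # (N, T, C) log-posteriors

# One row per utterance: (sequence_idx, start_frame, num_frames), dtype int32.
supervision_segments = torch.tensor([[0, 0, T], [1, 0, T]], dtype=torch.int32)

token_ids = [[3, 5, 2], [7, 1]]  # pseudo-label token ids per utterance
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=log_probs.device)
dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)

ctc_loss = k2.ctc_loss(
    decoding_graph=decoding_graph,
    dense_fsa_vec=dense_fsa_vec,
    output_beam=10,
    reduction="sum",
    use_double_scores=True,
)
ctc_loss.backward()
print(ctc_loss.item(), logits.grad.shape)  # scalar loss; gradients reach the logits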
@@ -606,30 +655,6 @@ def decode_dataset(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
-    """Decode dataset.
-
-    Args:
-      dl:
-        PyTorch's dataloader containing the dataset to decode.
-      params:
-        It is returned by :func:`get_params`.
-      model:
-        The neural model.
-      sp:
-        The BPE model.
-      word_table:
-        The word symbol table.
-      decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
-    Returns:
-      Return a dict, whose key may be "greedy_search" if greedy search
-      is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
-    """
     num_cuts = 0
 
     try:
@@ -643,6 +668,27 @@ def decode_dataset(
     log_interval = 20
 
     results = defaultdict(list)
+
+    parameters = []
+    parameters_name = []
+    for n, p in model.named_parameters():
+        if p.requires_grad:
+            if ("bias" in n) and ("encoder.layers" in n):
+                logging.info(f"{n} is free!")
+                parameters.append(p)
+                parameters_name.append(n)
+            else:
+                p.requires_grad = False
+    sizes = [p.numel() for p in parameters]
+    logging.info(f"total trainable parameter size : {sum(sizes)}")
+
+    optimizer = ScaledAdam(
+        parameters,
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=[parameters_name],
+    )
+
     for batch_idx, batch in enumerate(dl):
         model.eval()
         texts = batch["supervisions"]["text"]
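The block above freezes everything except the bias terms inside encoder.layers and hands only those to ScaledAdam, i.e. bias-only test-time adaptation. The same selection logic on a made-up toy module, with torch.optim.Adam standing in for icefall's ScaledAdam:

import logging
import torch
import torch.nn as nn

logging.basicConfig(level=logging.INFO)

# Made-up toy module whose parameter names mimic the recipe's
# "encoder.layers.*" paths; it exists only to exercise the selection logic.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Module()
        self.encoder.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
        self.head = nn.Linear(8, 4)

model = TinyModel()

# Keep only the biases inside encoder.layers trainable; freeze everything else.
parameters, parameters_name = [], []
for n, p in model.named_parameters():
    if ("bias" in n) and ("encoder.layers" in n):
        logging.info(f"{n} is free!")
        parameters.append(p)
        parameters_name.append(n)
    else:
        p.requires_grad = False

logging.info(f"total trainable parameter size : {sum(p.numel() for p in parameters)}")

# torch.optim.Adam stands in here for ScaledAdam.
optimizer = torch.optim.Adam(parameters, lr=6e-5)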
@@ -660,13 +706,22 @@ def decode_dataset(
         )
 
         # replace the supervision to pseudo labels
-        batch["supervision"]["text"] = "".join(hyps_dict[params.decoding_method])
+        pseudo_batch = deepcopy(batch)
+        assert len(hyps_dict[params.decoding_method]) == 1  # should use the single-utterance sampler
+        pseudo_batch["supervisions"]["text"] = [" ".join(hyps_dict[params.decoding_method][0]).lower()] * params.num_augment
 
         # augment the single utterance (augmentation is executed automatically in the d2v model)
-        batch["intputs"] = batch["intputs"].reapeat(4, 1)
+        pseudo_batch["inputs"] = batch["inputs"].repeat(params.num_augment, 1)
+        pseudo_batch["supervisions"]["sequence_idx"] = batch["supervisions"]["sequence_idx"].repeat(params.num_augment)
+        pseudo_batch["supervisions"]["cut"] = batch["supervisions"]["cut"] * params.num_augment
+        assert "start_frame" not in pseudo_batch["supervisions"].keys()
+        assert "num_frames" not in pseudo_batch["supervisions"].keys()
 
-        decode_and_adapt(params, model, sp, batch, is_training=True, num_iter=10)
+        # model.train() is executed inside decode_and_adapt
+        decode_and_adapt(params, model, optimizer, sp, pseudo_batch, is_training=True, num_iter=10)
+
+        model.eval()
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
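Taken together, the per-batch flow in decode_dataset is now: decode once to obtain a pseudo label, build an augmented pseudo-labelled copy of the single-utterance batch, adapt the mostly frozen model on it, then decode again. A schematic sketch of that flow; call signatures are abbreviated, and decode_one_batch / decode_and_adapt / params / model / optimizer / sp stand for the recipe's own objects:

from copy import deepcopy

def adapt_and_decode(params, model, optimizer, sp, batch,
                     decode_one_batch, decode_and_adapt):
    """Schematic per-batch test-time adaptation loop (call signatures abbreviated)."""
    model.eval()

    # 1) First pass: decode the single utterance to obtain a pseudo label.
    hyps_dict = decode_one_batch(params=params, model=model, sp=sp, batch=batch)
    assert len(hyps_dict[params.decoding_method]) == 1  # single-utterance sampler

    # 2) Build an augmented copy carrying the pseudo label num_augment times.
    pseudo_batch = deepcopy(batch)
    hyp = " ".join(hyps_dict[params.decoding_method][0]).lower()
    pseudo_batch["supervisions"]["text"] = [hyp] * params.num_augment
    pseudo_batch["inputs"] = batch["inputs"].repeat(params.num_augment, 1)
    pseudo_batch["supervisions"]["sequence_idx"] = (
        batch["supervisions"]["sequence_idx"].repeat(params.num_augment)
    )
    pseudo_batch["supervisions"]["cut"] = batch["supervisions"]["cut"] * params.num_augment

    # 3) Adapt on the pseudo-labelled, augmented batch
    #    (decode_and_adapt switches the model to train mode internally).
    decode_and_adapt(params, model, optimizer, sp, pseudo_batch,
                     is_training=True, num_iter=10)

    # 4) Second pass: decode again with the adapted model.
    model.eval()
    return decode_one_batch(params=params, model=model, sp=sp, batch=batch)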