diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/decode_and_adapt.py b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/decode_and_adapt.py
index 7d6a1a564..87dde922f 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/decode_and_adapt.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless_d2v_v2/decode_and_adapt.py
@@ -61,6 +61,9 @@ from icefall.utils import (
 
 import fairseq
 
+from optim import Eden, ScaledAdam
+from copy import deepcopy
+
 LOG_EPS = math.log(1e-10)
 
 def get_parser():
@@ -261,6 +264,80 @@ def get_parser():
         type=str,
     )
 
+    # optimizer related
+    parser.add_argument(
+        "--base-lr", type=float, default=6e-5, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, i.e. how many symbols (context) "
+        "we use to compute the loss.",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (the joiner is just an addition); this simple loss is also used "
+        "for training (as a regularization term). We will scale the simple loss "
+        "with this parameter before adding it to the final loss.",
+    )
+
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
+    parser.add_argument(
+        "--subsampling-factor",
+        type=int,
+        default=320,
+        help="The subsampling factor of the model.",
+    )
+
+    parser.add_argument(
+        "--use-double-scores",
+        type=bool,
+        default=True,
+        help="Whether to use double precision scores when computing the CTC loss.",
+    )
+
+    parser.add_argument(
+        "--warm-step",
+        type=int,
+        default=0,
+        help="Number of warm-up steps used to scale the simple and pruned losses.",
+    )
+
+    # TTA (test-time adaptation) related
+    parser.add_argument(
+        "--num_augment",
+        type=int,
+        default=4,
+        help="Number of augmented copies of each utterance used for test-time adaptation.",
+    )
+
     add_model_arguments(parser)
     add_rep_arguments(parser)
 
@@ -475,6 +552,7 @@ def decode_one_batch(
 def decode_and_adapt(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
@@ -534,69 +612,40 @@ def decode_and_adapt(
 
         loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
-        if params.ctc_loss_scale > 0:
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-                supervision_segments, token_ids = encode_supervisions(
-                    supervisions,
-                    subsampling_factor=params.subsampling_factor,
-                    token_ids=token_ids,
+        if params.ctc_loss_scale > 0:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                supervision_segments, token_ids = encode_supervisions(
+                    supervisions,
+                    subsampling_factor=params.subsampling_factor,
+                    token_ids=token_ids,
+                )
+                for i in range(params.num_augment):
+                    supervision_segments[i][-1] = ctc_output.size(1)
+
+                # Works with a BPE model
+                decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
+                dense_fsa_vec = k2.DenseFsaVec(
+                    ctc_output,
+                    supervision_segments,
+                    allow_truncate=params.subsampling_factor - 1,
                 )
-
-            # Works with a BPE model
-            decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
-            dense_fsa_vec = k2.DenseFsaVec(
-                ctc_output,
-                supervision_segments,
-                allow_truncate=params.subsampling_factor - 1,
-            )
-            ctc_loss = k2.ctc_loss(
-                decoding_graph=decoding_graph,
-                dense_fsa_vec=dense_fsa_vec,
-                output_beam=params.beam_size,
-                reduction="sum",
-                use_double_scores=params.use_double_scores,
-            )
-            assert ctc_loss.requires_grad == is_training
-            loss += params.ctc_loss_scale * ctc_loss
+                ctc_loss = k2.ctc_loss(
+                    decoding_graph=decoding_graph,
+                    dense_fsa_vec=dense_fsa_vec,
+                    output_beam=params.beam_size,
+                    reduction="sum",
+                    use_double_scores=params.use_double_scores,
+                )
+                assert ctc_loss.requires_grad == is_training
+                loss += params.ctc_loss_scale * ctc_loss
 
-        # self.adapted_model_losses.append(loss.item())
-        # self.adapted_models.append(self.copy_model_and_optimizer(self.models[0]))
-
-        self.optimizer.zero_grad()
+        assert loss.requires_grad == is_training
+
+        optimizer.zero_grad()
         loss.backward()
-        self.optimizer.step()
-
-
-        with torch.set_grad_enabled(is_training):
-            simple_loss, pruned_loss, ctc_output = model(
-                x=feature,
-                x_lens=feature_lens,
-                y=y,
-                prune_range=params.prune_range,
-                am_scale=params.am_scale,
-                lm_scale=params.lm_scale,
-            )
-
-            s = params.simple_loss_scale
-            # take down the scale on the simple loss from 1.0 at the start
-            # to params.simple_loss scale by warm_step.
-            simple_loss_scale = (
-                s
-                if batch_idx_train >= warm_step
-                else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
-            )
-            pruned_loss_scale = (
-                1.0
-                if batch_idx_train >= warm_step
-                else 0.1 + 0.9 * (batch_idx_train / warm_step)
-            )
-
-
-
-            loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
-
+        optimizer.step()
 
 
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
@@ -606,30 +655,6 @@ def decode_dataset(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
-    """Decode dataset.
-
-    Args:
-      dl:
-        PyTorch's dataloader containing the dataset to decode.
-      params:
-        It is returned by :func:`get_params`.
-      model:
-        The neural model.
-      sp:
-        The BPE model.
-      word_table:
-        The word symbol table.
-      decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
-    Returns:
-      Return a dict, whose key may be "greedy_search" if greedy search
-      is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains two elements:
-      The first is the reference transcript, and the second is the
-      predicted result.
- """ num_cuts = 0 try: @@ -643,6 +668,27 @@ def decode_dataset( log_interval = 20 results = defaultdict(list) + + parameters = [] + parameters_name = [] + for n, p in model.named_parameters(): + if p.requires_grad: + if ("bias" in n) and ("encoder.layers" in n): + logging.info(f"{n} is free!") + parameters.append(p) + parameters_name.append(n) + else: + p.requires_grad = False + sizes = [p.numel() for p in parameters] + logging.info(f"total trainable parameter size : {sum(sizes)}") + + optimizer = ScaledAdam( + parameters, + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=[parameters_name], + ) + for batch_idx, batch in enumerate(dl): model.eval() texts = batch["supervisions"]["text"] @@ -660,13 +706,22 @@ def decode_dataset( ) # replace the supervision to pseudo labels - batch["supervision"]["text"] = "".join(hyps_dict[params.decoding_method] ) + pseudo_batch = deepcopy(batch) + + assert len(hyps_dict[params.decoding_method]) == 1 # shoud use the single utterance sampler + pseudo_batch["supervisions"]["text"] = [" ".join(hyps_dict[params.decoding_method][0]).lower()] * params.num_augment # augment the single utterance (augmentation automatically excued in d2v model) - batch["intputs"] = batch["intputs"].reapeat(4, 1) + pseudo_batch["inputs"] = batch["inputs"].repeat(params.num_augment, 1) + pseudo_batch["supervisions"]["sequence_idx"] = batch["supervisions"]["sequence_idx"].repeat(params.num_augment) + pseudo_batch["supervisions"]['cut'] = batch["supervisions"]['cut'] * params.num_augment + assert "start_frame" not in pseudo_batch["supervisions"].keys() + assert "num_frames" not in pseudo_batch["supervisions"].keys() - decode_and_adapt(params, model, sp, batch, is_training=True, num_iter=10) + # model.train() is excuted in the decoder and adpt fucntion + decode_and_adapt(params, model, optimizer, sp, pseudo_batch, is_training=True, num_iter=10) + model.eval() hyps_dict = decode_one_batch( params=params, model=model,