From bb1c4466e39e3350a8d219b595ae490e2fe51bcf Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Mon, 11 Sep 2023 18:46:38 -0700
Subject: [PATCH] rename train/train2; add support for fine-tuning the
 embedding table

---
 egs/aishell/ASR/seamlessm4t/decode2.py        |  15 +-
 .../seamlessm4t/patch/sequence_generator.py   | 715 ++++++++++++++++++
 egs/aishell/ASR/seamlessm4t/train.py          |  47 +-
 egs/aishell/ASR/seamlessm4t/train2.py         |  57 +-
 4 files changed, 782 insertions(+), 52 deletions(-)
 create mode 100644 egs/aishell/ASR/seamlessm4t/patch/sequence_generator.py

diff --git a/egs/aishell/ASR/seamlessm4t/decode2.py b/egs/aishell/ASR/seamlessm4t/decode2.py
index dc5e5ea95..4607bae07 100644
--- a/egs/aishell/ASR/seamlessm4t/decode2.py
+++ b/egs/aishell/ASR/seamlessm4t/decode2.py
@@ -191,12 +191,12 @@ def decode_one_batch(
     feature_len = supervisions["num_frames"]
     feature_len = feature_len.to(device, dtype=dtype)
 
-    #text_output = s2t_generator.generate_ex(feature, feature_len)
+    text_output = s2t_generator.generate_ex(feature, feature_len)
     #sentences = text_output.sentences
     #hyps = [sentence.bytes().decode("utf-8").split() for sentence in sentences]
-    token_ids = text_output.generator_output.results[0][0].seq.cpu().tolist()
-    hyps_ids = [setence[0].seq.cpu().tolist() for sentence in token_ids]
+    results = text_output.generator_output.results
+    hyps_ids = [sentence[0].seq.cpu().tolist() for sentence in results]
     hyps = [params.tokenizer.decode(hyps_id).split() for hyps_id in hyps_ids]
 
     key = "beam-search"
 
@@ -347,8 +347,10 @@ def main():
     del model.t2u_model
     del model.text_encoder
     del model.text_encoder_frontend
-    model.text_decoder_frontend.embed = Embedding(num_embeddings=params.tokenzier.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True)
-    model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size)
+    model.text_decoder_frontend.embed = nn.Embedding(num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024, padding_idx=0)
+    #model.text_decoder_frontend.embed = Embedding(num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024, pad_idx=0, scaled=True)
+    model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size, bias=False)
+    #model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size)
     if params.epoch > 0:
         if params.avg > 1:
             start = params.epoch - params.avg
@@ -371,6 +373,9 @@
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     model.to(device)
     model.eval()
+    # Decode in fp16; `model.half()` casts every parameter and buffer in
+    # place, so no per-parameter dtype conversion is needed.
+    model.half()
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
diff --git a/egs/aishell/ASR/seamlessm4t/patch/sequence_generator.py b/egs/aishell/ASR/seamlessm4t/patch/sequence_generator.py
new file mode 100644
index 000000000..d13cc08d0
--- /dev/null
+++ b/egs/aishell/ASR/seamlessm4t/patch/sequence_generator.py
@@ -0,0 +1,715 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
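+#
+# NOTE: This is a locally patched copy of fairseq2's sequence generator. The
+# vocabulary indices used below are hard-coded (pad=0, eos=1, unk=2) to match
+# the CharTokenizer vocabulary that train2.py fine-tunes against.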
+ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union, cast + +import torch +from torch import Tensor +from torch.nn.functional import log_softmax + +from fairseq2.data import Collater, SequenceData, VocabularyInfo +from fairseq2.generation.beam_search import BeamSearch, StandardBeamSearch +from fairseq2.generation.logits_processor import LogitsProcessor +from fairseq2.models.encoder_decoder import Seq2SeqDecoder +from fairseq2.nn.incremental_state import IncrementalStateBag +from fairseq2.typing import Device + + +@dataclass +class SequenceGeneratorOptions: + """Holds the options to pass to a sequence generator.""" + + beam_size: int = 5 + """The beam size.""" + + min_seq_len: int = 1 + """The minimum length of generated sequences (including prefix sequence).""" + + soft_max_seq_len: Optional[Tuple[int, int]] = (1, 200) + """The terms ``a`` and ``b`` of ``ax + b`` where ``x`` is the source + sequence length. The generated sequences (including prefix sequence) will + have the maximum length of ``min(hard_max_seq_len, ax + b)``. See also + ``hard_max_seq_len``.""" + + hard_max_seq_len: int = 1024 + """The hard limit on maximum length of generated sequences.""" + + len_penalty: float = 1.0 + """The length penalty, where values less than 1.0 favor shorter, values + greater than 1.0 favor longer sequences.""" + + unk_penalty: float = 0.0 + """The unknown symbol penalty, where values less than 0 produce more UNKs, + values greater than 0 produce fewer UNKs.""" + + normalize_scores: bool = True + """If ``True``, normalizes scores by the length of generated sequences.""" + + search: Optional[BeamSearch] = None + """The beam search algorithm to use.""" + + logits_processor: Optional[LogitsProcessor] = None + """Logits processor called before applying beam search step.""" + + +class Seq2SeqGenerator: + """Represents a sequence-to-sequence generator.""" + + decoder: Seq2SeqDecoder + opts: SequenceGeneratorOptions + beam_size: int + eos_idx: int + pad_idx: Optional[int] + unk_idx: Optional[int] + prefix_seq: Union[int, Tensor] + prefix_seq_len: int + search: BeamSearch + logits_processor: Optional[LogitsProcessor] + collater: Collater + + def __init__( + self, + decoder: Seq2SeqDecoder, + vocab_info: VocabularyInfo, + prefix_seq: Optional[Union[int, Tensor]], + opts: Optional[SequenceGeneratorOptions] = None, + ) -> None: + """ + :param decoder: + The decoder to use. + :param vocab_info: + The vocabulary information to use. + :param prefix_seq: + The prefix sequence, typically one or more control symbols + indicating the beginning of a sequence. *Shape:* :math:`()` or + :math:`(S)`, where :math:`S` is the sequence length. If ``None``, + the EOS symbol will be used as prefix. + :param opts: + The generation options. + """ + self.decoder = decoder + + self.opts = opts or SequenceGeneratorOptions() + + # Set beam size. + if vocab_info.pad_idx is None: + self.beam_size = min(self.opts.beam_size, vocab_info.size) + else: + # -1 since we never select PAD. + self.beam_size = min(self.opts.beam_size, vocab_info.size - 1) + + if vocab_info.eos_idx is None: + raise ValueError( + "`vocab_info` must have `eos_idx` set for sequence generation." + ) + + # Set vocab info. + self.eos_idx = 1 + #self.eos_idx = vocab_info.eos_idx + self.unk_idx = 2 + #self.unk_idx = vocab_info.unk_idx + self.pad_idx = 0 + #self.pad_idx = vocab_info.pad_idx + + # Set prefix sequence. 
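+        # The fine-tuning targets in train2.py are framed as
+        # `[eos] + tokens + [eos]`, so decoding always starts from EOS, the
+        # same symbol the decoder saw at training time.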
+        use_eos_prefix = True  # `prefix_seq` is ignored in this patched copy.
+        if use_eos_prefix:
+            # Follow fairseq's convention and use EOS as the prefix.
+            self.prefix_seq, self.prefix_seq_len = self.eos_idx, 1
+        else:
+            self.prefix_seq = prefix_seq
+
+            if isinstance(prefix_seq, Tensor):
+                num_dim = prefix_seq.dim()
+
+                if num_dim >= 2:
+                    raise ValueError(
+                        f"`prefix_seq` must be a scalar or a 1-dimensional tensor, but is {num_dim}-dimensional instead."
+                    )
+
+                self.prefix_seq_len = 1 if num_dim == 0 else prefix_seq.size(0)
+            else:
+                self.prefix_seq_len = 1
+
+        # Set beam search.
+        self.search = self.opts.search or StandardBeamSearch()
+        self.logits_processor = self.opts.logits_processor
+
+        if vocab_info.pad_idx is None:
+            self.collater = Collater()
+        else:
+            self.collater = Collater(self.pad_idx, pad_to_multiple=2)
+
+    @torch.inference_mode()
+    def __call__(
+        self,
+        encoder_output: Tensor,
+        encoder_padding_mask: Optional[Tensor],
+        source_seq_len: Optional[int] = None,
+    ) -> "SequenceGeneratorOutput":
+        opts = self.opts
+
+        num_searches = encoder_output.size(0)
+
+        beam_size = opts.beam_size
+
+        max_seq_len = self._determine_max_seq_len(source_seq_len)
+
+        device = encoder_output.device
+
+        encoder_output, encoder_padding_mask = self._fan_out_encoder_output(
+            encoder_output, encoder_padding_mask
+        )
+
+        # Each element contains the id of the search corresponding to a single
+        # source sequence and its hypotheses.
+        active_searches: List[Tuple[int, List[Hypothesis]]] = [
+            (search_idx, []) for search_idx in range(num_searches)
+        ]
+
+        # Once a source sequence has `beam_size` hypotheses, its search is moved
+        # from `active_searches` to `finished_searches`.
+        finished_searches: List[List[Hypothesis]] = [[] for i in range(num_searches)]
+
+        num_remaining_searches = num_searches
+
+        # Initialize buffers.
+        # (N x B, S)
+        seqs = torch.zeros(
+            (num_searches * beam_size, max_seq_len), device=device, dtype=torch.int64
+        )
+
+        # (N x B, S)
+        scores = torch.zeros(
+            (num_searches * beam_size, max_seq_len), device=device, dtype=torch.float32
+        )
+
+        # A list that indicates beams that should be ignored in the next step.
+        ignored_beam_mask = torch.full(
+            (num_searches, beam_size), False, device=device, dtype=torch.bool
+        )
+
+        # An offset array for converting between batch-wide and search-local
+        # beam indices.
+        # (B)
+        search_offsets = torch.arange(num_searches, device=device) * beam_size
+
+        # (B) -> (B, 1)
+        search_offsets.unsqueeze_(-1)
+
+        cand_offsets = torch.arange(2 * beam_size, device=device)
+
+        state_bag = IncrementalStateBag()
+
+        # At this point, the state is fully initialized, kick off the search.
+        self._bootstrap_seqs_and_scores(
+            seqs, scores, encoder_output, encoder_padding_mask, state_bag
+        )
+
+        start_step = self.prefix_seq_len - 1
+
+        # Holds the indices of beams (a beam can occur more than once) that we
+        # should continue with in the next step.
+        beam_indices: Optional[Tensor] = None
+
+        # Holds the indices of searches that we should continue with in the next
+        # step. If not `None`, it means we finalized one or more searches in the
+        # last step.
+        search_indices: Optional[Tensor] = None
+
+        for step_nr in range(start_step, max_seq_len - 1):
+            if beam_indices is not None:
+                # If not `None`, it means in the last step we finalized one or
+                # more searches. We should ensure that we adjust `beam_indices`
+                # before reordering `decoder`'s incremental state.
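+                # e.g. with beam_size=2 and searches {0, 1, 2}: if search 1
+                # finished, surviving search 2 occupies rows 2..3 of the
+                # filtered buffers, while the decoder state still holds it at
+                # rows 4..5, hence the `delta * beam_size` adjustment below.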
+ if search_indices is not None: + num_searches = search_indices.numel() + + # (N) + delta = search_indices - torch.arange(num_searches, device=device) + + # (N) -> (N, 1) + delta.unsqueeze_(-1) + + # Adjust indices to take into account removed searches. + beam_indices.view(num_searches, beam_size).add_(delta * beam_size) + + state_bag.reorder(beam_indices) + + decoder_output, decoder_padding_mask = self.decoder.decode( + seqs[:, step_nr : step_nr + 1], + None, # We never generate PAD. + encoder_output, + encoder_padding_mask, + state_bag, + ) + + state_bag.increment_step() + + model_output = self.decoder.project(decoder_output, decoder_padding_mask) + + # lprobs: (1, V) + # model_output: (N, 1, V) + lprobs = log_softmax(model_output.logits, dim=-1, dtype=torch.float32) + + # Do not allow EOS before reaching the minimum sequence length. + if step_nr < self.opts.min_seq_len: + lprobs[:, :, self.eos_idx] = -torch.inf + + # fmt: off + # If we have reached the maximum length, force the last step to be + # EOS. + if step_nr == max_seq_len - 2: + lprobs[:, :, : self.eos_idx] = -torch.inf + lprobs[:, :, self.eos_idx + 1 :] = -torch.inf + # fmt: on + + # Never allow PAD. + if self.pad_idx is not None: + lprobs[:, :, self.pad_idx] = -torch.inf + + # Apply UNK penalty. + if self.unk_idx is not None: + lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty + + # update scores in place using logits_processor + if self.logits_processor is not None: + self.logits_processor( + seqs.view(num_searches, beam_size, -1)[:, :, : step_nr + 1], + lprobs.view(num_searches, beam_size, -1), + ) + + # Determine candidates for the next step. + # (N, 2 x B) + cand_scores, cand_indices, cand_beam_indices = self.search.step( + step_nr, + step_nr == start_step, + lprobs.view(num_searches, beam_size, -1), + scores.view(num_searches, beam_size, -1)[:, :, : step_nr + 1], + ) + + # Convert search-local beam indices to batch-wide beam indices. + # (N, 2 x B) + (N) -> (N, 2 x B) + global_cand_beam_indices = cand_beam_indices + search_offsets + + # Finalize beams that reached the minimum length and that end with + # an EOS. + # (N, 2 x B) + eos_mask = (cand_indices == self.eos_idx) & (cand_scores != -math.inf) + + # Do not attempt to finalize beams that should be ignored. + eos_mask[:, :beam_size][ignored_beam_mask] = False + + # Only consider EOS when it's among the top `beam_size` indices. Now + # we know what beam(s) to finalize. + # (N, B) + eos_beam_indices = torch.masked_select( + global_cand_beam_indices[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + if eos_beam_indices.numel() > 0: + # Select the scores of the finalized beams. + # (N, B) + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + newly_finished_searches = self._finalize_hypothesis( + step_nr, + eos_beam_indices, + eos_scores, + seqs, + scores, + active_searches, + finished_searches, + ) + + num_remaining_searches -= len(newly_finished_searches) + + if num_remaining_searches == 0: + break + else: + newly_finished_searches = None + + # Remove finished searches (ones for which `beam_size` finalized + # beams have been generated) from the batch. + if newly_finished_searches: + new_num_searches = num_searches - len(newly_finished_searches) + + # Construct `search_indices` which holds indices of searches + # to keep for the next step. 
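+                # e.g. if search 1 of {0, 1, 2} just finished, the mask is
+                # [True, False, True] and `search_indices` becomes [0, 2].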
+ search_mask = torch.full((num_searches,), True, device=device) + + search_mask[newly_finished_searches] = False + + search_indices = torch.arange(num_searches, device=device) + + search_indices = search_indices.masked_select(search_mask) + + # fmt: off + # Filter out removed batches from state variables. + # (N, B) -> (N - F, B) + ignored_beam_mask = ignored_beam_mask[search_indices] + + # (N, 2 x B) -> (N - F, 2 x B) + cand_scores = cand_scores [search_indices] + cand_indices = cand_indices [search_indices] + cand_beam_indices = cand_beam_indices[search_indices] + + # (N) -> (N - F) + search_offsets.resize_(new_num_searches, 1) + + # (N - F, 2 x B) + (N - F) -> (N - F, 2 x B) + global_cand_beam_indices = cand_beam_indices + search_offsets + + # (N, 2 x B) -> (N - F, 2 x B) + eos_mask = eos_mask[search_indices] + + # (N x B, S) -> (N, B, S) + seqs = seqs .view(num_searches, -1) + scores = scores.view(num_searches, -1) + + # (N, B, S + 1) -> ((N - F) x B, S) + seqs = seqs [search_indices].view(new_num_searches * beam_size, -1) + scores = scores[search_indices].view(new_num_searches * beam_size, -1) + + # (N x B, S_enc, M) -> (N, B, S_enc, M) + encoder_output = encoder_output.unflatten(0, (num_searches, -1)) + + # (N, B, S_enc, M) -> ((N - F) x B, S_enc, M) + encoder_output = encoder_output[search_indices].flatten(0, 1) + + if encoder_padding_mask is not None: + # (N x B, S_enc, M) -> (N, B, S_enc, M) + padding_mask = encoder_padding_mask.unflatten(0, (num_searches, -1)) + + # (N, B, S_enc, M) -> ((N - F) x B, S_enc, M) + encoder_padding_mask = padding_mask[search_indices].flatten(0, 1) + # fmt: on + + num_searches = new_num_searches + else: + search_indices = None + + eos_mask[:, :beam_size][ignored_beam_mask] = True + + # Set `beam_weights` so that values greater than or equal to 2 x + # `beam_size` indicate finished beams (i.e. end with EOS) and values + # less than 2 x `beam_size` indicate active beams. + # (N, 2 x B) + beam_weights = cand_offsets + (eos_mask * (2 * beam_size)) + + # Get the top `beam_size` active beams, which are the beams with the + # smallest weights in `active_beam_weights`. + # (N, B) + active_beam_weights, active_beams = torch.topk( + beam_weights, k=beam_size, dim=1, largest=False + ) + + # Update to ignore finalized beams in the next step. + # (N, B) + ignored_beam_mask = active_beam_weights >= 2 * beam_size + + # We should always have at least one active beam in each search. + assert (~ignored_beam_mask).any(dim=1).all() + + # Denotes which beams are continued for each new hypothesis (a beam + # can be selected more than once). + # (N, B) + beam_indices = torch.gather( + global_cand_beam_indices, dim=1, index=active_beams + ) + + # (N, B) -> (N x B) + beam_indices = beam_indices.view(-1) + + # fmt: off + # Reorder beams in the `seq` and `score` buffers. The same beam can + # be selected more than once. + if step_nr > start_step: + seqs [:, : step_nr + 1] = torch.index_select( + seqs [:, : step_nr + 1], dim=0, index=beam_indices + ) + scores[:, : step_nr + 1] = torch.index_select( + scores[:, : step_nr + 1], dim=0, index=beam_indices + ) + + # (N x B, S) -> (N, B, S) + seqs_view = seqs .view(num_searches, beam_size, -1) + scores_view = scores.view(num_searches, beam_size, -1) + + seqs_view [:, :, step_nr + 1] = torch.gather(cand_indices, dim=1, index=active_beams) + scores_view[:, :, step_nr + 1] = torch.gather(cand_scores, dim=1, index=active_beams) + # fmt: on + + # Ensure that hypotheses are sorted by their scores before returning. 
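+        # (A hypothesis score is its cumulative log-probability, length-
+        # normalized when `normalize_scores` is set, so higher is better.)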
+ for batch in finished_searches: + batch.sort(key=lambda b: b.score, reverse=True) # type: ignore[arg-type, return-value] + + return SequenceGeneratorOutput( + results=finished_searches, device=device, collater=self.collater + ) + + def _determine_max_seq_len(self, source_seq_len: Optional[int]) -> int: + opts = self.opts + + if source_seq_len is None or opts.soft_max_seq_len is None: + max_seq_len = opts.hard_max_seq_len + else: + at, bt = opts.soft_max_seq_len + + max_seq_len = min(opts.hard_max_seq_len, int(at * source_seq_len + bt)) + + if opts.min_seq_len > max_seq_len: + raise ValueError( + f"The effective maximum sequence length must be greater than or equal to `min_seq_len` ({opts.min_seq_len}), but is {max_seq_len} instead. Adjust your soft and hard maximum sequence length limits." + ) + + if self.prefix_seq_len >= max_seq_len: + raise ValueError( + f"The effective maximum sequence length must be greater than `prefix_seq_len` ({self.prefix_seq_len}), but is {max_seq_len} instead." + ) + + return max_seq_len + + def _fan_out_encoder_output( + self, encoder_output: Tensor, encoder_padding_mask: Optional[Tensor] + ) -> Tuple[Tensor, Optional[Tensor]]: + num_searches = encoder_output.size(0) # i.e. batch size + + # Fan out `encoder_output` to `num_searches` x `beam_size`. + # (N) + fan_out_indices = torch.arange(num_searches, device=encoder_output.device) + + # (N) -> (N x B) + fan_out_indices = fan_out_indices.repeat_interleave(self.beam_size) + + # (N, S_enc, M) -> (N x B, S_enc, M) + encoder_output = encoder_output.index_select(dim=0, index=fan_out_indices) + + # (N, S_enc, M) -> (N x B, S_enc, M) + if encoder_padding_mask is not None: + encoder_padding_mask = encoder_padding_mask.index_select( + dim=0, index=fan_out_indices + ) + + return encoder_output, encoder_padding_mask + + def _bootstrap_seqs_and_scores( + self, + seqs: Tensor, + scores: Tensor, + encoder_output: Tensor, + encoder_padding_mask: Optional[Tensor], + state_bag: IncrementalStateBag, + ) -> None: + assert self.prefix_seq_len > 0 + + seqs[:, : self.prefix_seq_len] = self.prefix_seq + + if self.prefix_seq_len == 1: + return + + assert isinstance(self.prefix_seq, Tensor) + + # We have to bootstrap the model with the already fanned-out encoder + # output to correctly initialize its incremental state. This causes some + # redundancy as we have to expand `decoder_input` to match the shape of + # `encoder_output`. + # (S_pfx) -> (N x B, S_pfx - 1) + decoder_input = self.prefix_seq[:-1].expand(encoder_output.size(0), -1) + + # Bootstrap the model state with prefix sequence. + decoder_output, decoder_padding_mask = self.decoder.decode( + decoder_input, + None, + encoder_output, + encoder_padding_mask, + state_bag, + ) + + state_bag.increment_step(self.prefix_seq_len - 1) + + model_output = self.decoder.project(decoder_output, decoder_padding_mask) + + # lprobs: (S_pfx - 1, V) + # model_output: (N, S_pfx - 1, V) -> (S_pfx - 1, V) + lprobs = log_softmax(model_output.logits[0], dim=-1, dtype=torch.float32) + + # Fetch scores of next steps. + # (S_pfx - 1, 1) + prefix_scores = torch.take_along_dim( + lprobs, indices=self.prefix_seq[1:].unsqueeze(1), dim=-1 + ) + + # (S_pfx - 1, 1) -> (S_pfx - 1) + prefix_scores.squeeze_(1).cumsum_(dim=0) + + # First step (e.g. EOS)'s score is always 0. 
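+        # e.g. for a 3-token prefix, positions 1 and 2 receive the cumulative
+        # scores of the second and third prefix tokens.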
+ scores[:, 1 : self.prefix_seq_len] = prefix_scores + + def _finalize_hypothesis( + self, + step_nr: int, + eos_beam_indices: Tensor, + eos_scores: Tensor, + seqs: Tensor, + scores: Tensor, + active_searches: List[Tuple[int, List["Hypothesis"]]], + finished_searches: List[List["Hypothesis"]], + ) -> List[int]: + # fmt: off + finalized_seqs = seqs .index_select(dim=0, index=eos_beam_indices) + finalized_scores = scores.index_select(dim=0, index=eos_beam_indices) + + finalized_seqs = finalized_seqs [:, : step_nr + 2] + finalized_scores = finalized_scores[:, : step_nr + 2] + + # Finalize beams. + finalized_seqs [:, -1] = self.eos_idx + finalized_scores[:, -1] = eos_scores + # fmt: on + + # Convert from cumulative to per-step scores. + finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1] + + # Skip first EOS since it is always 0 and skews normalization. + if self.opts.normalize_scores: + eos_scores /= (step_nr + 1) ** self.opts.len_penalty + + # Holds the ids of finished searches. + newly_finished: List[int] = [] + + active_search_indices = (eos_beam_indices // self.beam_size).tolist() + + for beam_idx, search_idx in enumerate(active_search_indices): + search_id, hypotheses = active_searches[search_idx] + + # We might have more than one beam finalized in one step that would + # potentially exceed `beam_size` hypotheses. + if len(hypotheses) == self.beam_size: + continue + + hypotheses.append( + Hypothesis( + seq=finalized_seqs[beam_idx], + score=eos_scores[beam_idx], + step_scores=finalized_scores[beam_idx], + ) + ) + + if len(hypotheses) == self.beam_size: + # We have `beam_size` hypotheses for this particular search, so + # we finish it now. + newly_finished.append(search_idx) + + finished_searches[search_id] = hypotheses + + newly_finished.sort() + + # Remove finished searches from the active list. + for idx in reversed(newly_finished): + del active_searches[idx] + + return newly_finished + + +@dataclass +class SequenceGeneratorOutput: + """Holds the output of a sequence generator.""" + + results: List[List["Hypothesis"]] + """The list of hypothesis generated per search, ordered by score.""" + + device: Device + """The device on which generated sequences reside.""" + + collater: Optional[Collater] = None + """The collater to use in :meth:`collate`.""" + + def collate( + self, hypo_idx: int = 0, skip_batch: bool = False + ) -> Tuple[Tensor, Optional[Tensor]]: + """Collate the generated sequences at index ``hypo_idx`` in each search + result into a single tensor. + + :param hypo_idx: + The index of hypothesis to extract from each search result. + :param skip_batch: + If ``True``, if a search result has no hypothesis at index `hypo_idx`, + it will be skipped instead of raising an error. + + :returns: + - The collated sequences. *Shape:* :math:`(N,S)`, where :math:`N` is + the number of search results and :math:`S` is the sequence length. + - An array where each element represents the length of the sequence at + the same index in the first returned value. *Shape:* :math:`(N)`, + where :math:`N` is the number of search results. 
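+
+        Example (illustrative)::
+
+            seqs, seq_lens = output.collate(hypo_idx=0)
+            # seqs: (N, S) token ids; seq_lens: (N) lengths, or None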
+ """ + if self.collater is None: + raise RuntimeError("The output has no associated `Collater` instance.") + + if not self.results and not skip_batch: + raise ValueError("The output must contain at least one search result.") + + seqs = [] + + for search_idx, result in enumerate(self.results): + if hypo_idx >= len(result): + if not skip_batch: + raise ValueError( + f"Each search result must have at least {hypo_idx + 1} hypotheses, but search {search_idx} has only {len(result)}." + ) + + continue + + seqs.append(result[hypo_idx].seq) + + if not seqs: + # Return a zero-dimensional (not scalar!) tensor. + return torch.empty((0,), device=self.device, dtype=torch.int64), None + + output = cast(SequenceData, self.collater(seqs)) + + return output["seqs"], output["seq_lens"] if output["is_ragged"] else None + + +@dataclass +class Hypothesis: + """Represents a hypothesis produced by a sequence generator.""" + + seq: Tensor + """The generated sequence.""" + + score: Tensor + """The score of the hypothesis.""" + + step_scores: Tensor + """The score of each individual sequence step.""" diff --git a/egs/aishell/ASR/seamlessm4t/train.py b/egs/aishell/ASR/seamlessm4t/train.py index ee08cecc6..4802473c9 100644 --- a/egs/aishell/ASR/seamlessm4t/train.py +++ b/egs/aishell/ASR/seamlessm4t/train.py @@ -110,10 +110,8 @@ from fairseq2.data.text import ( TextTokenizer, vocabulary_from_sentencepiece, ) -from tokenizer import CharTokenizer + from label_smoothing import LabelSmoothingLoss -from fairseq2.nn.embedding import Embedding -from fairseq2.nn.projection import TiedProjection LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -603,7 +601,7 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - text_tokenizer_encoder: CharTokenizer, + text_tokenizer_encoder: SentencePieceEncoder, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: @@ -661,10 +659,8 @@ def compute_loss( batch_idx_train = params.batch_idx_train warm_step = params.warm_step - texts = batch["supervisions"]["text"] - # remove spaces in the text - texts = [text.replace(" ", "") for text in texts] - text_tokens_list = [torch.tensor([params.eos_idx] + text_tokenizer_encoder.encode(text) + [params.eos_idx]) for text in texts] + texts = batch["supervisions"]["text"] + text_tokens_list = [text_tokenizer_encoder(text) for text in texts] prev_outputs_tokens = _batch_tensors( [tokens[:-1] for tokens in text_tokens_list], pad_value=params.pad_idx ) @@ -710,7 +706,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], - text_tokenizer_encoder: CharTokenizer, + text_tokenizer_encoder: SentencePieceEncoder, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> MetricsTracker: @@ -746,7 +742,7 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - text_tokenizer_encoder: CharTokenizer, + text_tokenizer_encoder: SentencePieceEncoder, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -921,6 +917,7 @@ def train_one_epoch( params.best_train_epoch = params.cur_epoch params.best_train_loss = params.train_loss + def run(rank, world_size, args): """ Args: @@ -960,28 +957,22 @@ def run(rank, world_size, args): logging.info("About to create model") model_name_or_card = "seamlessM4T_medium" - tokenizer_file = "./seamlessm4t/tokens.txt" lang = "cmn" - - # text_tokenizer = 
load_unity_text_tokenizer(model_name_or_card)
-    # text_tokenizer_encoder = SentencePieceEncoder(
-    #     text_tokenizer.model,
-    #     prefix_tokens=["</s>", f"__{lang}__"],
-    #     suffix_tokens=["</s>"],
-    # )
-    # #params.eos_idx = text_tokenizer.model.eos_idx
-    # params.pad_idx = text_tokenizer.model.pad_idx
-    text_tokenizer_encoder = CharTokenizer(tokenizer_file)
-    params.pad_idx, params.eos_idx = 0, 1
-    logging.info(params)
-
     model = load_unity_model(model_name_or_card, device="cpu", dtype=torch.float32)
     del model.t2u_model
     del model.text_encoder
     del model.text_encoder_frontend
-    model.text_decoder_frontend.embed = Embedding(num_embeddings=text_tokenizer_encoder.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True)
-    #model.final_proj = TiedProjection(input_dim=1024, output_dim=text_tokenizer_encoder.vocab_size)
-    model.final_proj = nn.Linear(1024, text_tokenizer_encoder.vocab_size)
+    text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
+    # Targets are encoded as `</s> __cmn__ ... </s>`, i.e. EOS plus the
+    # language token as prefix and EOS as suffix.
+    text_tokenizer_encoder = SentencePieceEncoder(
+        text_tokenizer.model,
+        prefix_tokens=["</s>", f"__{lang}__"],
+        suffix_tokens=["</s>"],
+    )
+    #params.eos_idx = text_tokenizer.model.eos_idx
+    params.pad_idx = text_tokenizer.model.pad_idx
+    logging.info(params)
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -1206,7 +1197,7 @@ def scan_pessimistic_batches_for_oom(
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     params: AttributeDict,
-    text_tokenizer_encoder: CharTokenizer,
+    text_tokenizer_encoder: SentencePieceEncoder,
 ):
     from lhotse.dataset import find_pessimistic_batches

diff --git a/egs/aishell/ASR/seamlessm4t/train2.py b/egs/aishell/ASR/seamlessm4t/train2.py
index 4802473c9..9d5cf4ab9 100644
--- a/egs/aishell/ASR/seamlessm4t/train2.py
+++ b/egs/aishell/ASR/seamlessm4t/train2.py
@@ -110,8 +110,10 @@ from fairseq2.data.text import (
     TextTokenizer,
     vocabulary_from_sentencepiece,
 )
-
+from tokenizer import CharTokenizer
 from label_smoothing import LabelSmoothingLoss
+from fairseq2.nn.embedding import Embedding
+from fairseq2.nn.projection import TiedProjection
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -601,7 +603,7 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    text_tokenizer_encoder: SentencePieceEncoder,
+    text_tokenizer_encoder: CharTokenizer,
     batch: dict,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
@@ -659,8 +661,10 @@ def compute_loss(
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step
 
-    texts = batch["supervisions"]["text"]
-    text_tokens_list = [text_tokenizer_encoder(text) for text in texts]
+    texts = batch["supervisions"]["text"]
+    # remove spaces in the text
+    texts = [text.replace(" ", "") for text in texts]
+    text_tokens_list = [torch.tensor([params.eos_idx] + text_tokenizer_encoder.encode(text) + [params.eos_idx]) for text in texts]
     prev_outputs_tokens = _batch_tensors(
         [tokens[:-1] for tokens in text_tokens_list], pad_value=params.pad_idx
     )
@@ -706,7 +710,7 @@ def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    text_tokenizer_encoder: SentencePieceEncoder,
+    text_tokenizer_encoder: CharTokenizer,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -742,7 +746,7 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    text_tokenizer_encoder: SentencePieceEncoder,
+    text_tokenizer_encoder: CharTokenizer,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -917,7 +921,6 @@ def train_one_epoch(
             params.best_train_epoch = params.cur_epoch
             params.best_train_loss = params.train_loss
 
-
 def run(rank, world_size, args):
     """
     Args:
@@ -957,22 +960,38 @@ def run(rank, world_size, args):
     logging.info("About to create model")
     model_name_or_card = "seamlessM4T_medium"
+    tokenizer_file = "./seamlessm4t/tokens.txt"
     lang = "cmn"
+
+    # text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
+    # text_tokenizer_encoder = SentencePieceEncoder(
+    #     text_tokenizer.model,
+    #     prefix_tokens=["</s>", f"__{lang}__"],
+    #     suffix_tokens=["</s>"],
+    # )
+    # #params.eos_idx = text_tokenizer.model.eos_idx
+    # params.pad_idx = text_tokenizer.model.pad_idx
+    text_tokenizer_encoder = CharTokenizer(tokenizer_file)
+    params.pad_idx, params.eos_idx = 0, 1
+    logging.info(params)
+
     model = load_unity_model(model_name_or_card, device="cpu", dtype=torch.float32)
     del model.t2u_model
     del model.text_encoder
     del model.text_encoder_frontend
-    # print(vars(model))
-    # exit(0)
-    text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
-    text_tokenizer_encoder = SentencePieceEncoder(
-        text_tokenizer.model,
-        prefix_tokens=["</s>", f"__{lang}__"],
-        suffix_tokens=["</s>"],
-    )
-    #params.eos_idx = text_tokenizer.model.eos_idx
-    params.pad_idx = text_tokenizer.model.pad_idx
-    logging.info(params)
+    # Swap in an embedding table and output projection sized to the
+    # CharTokenizer vocabulary; these are the weights being fine-tuned.
+    model.text_decoder_frontend.embed = nn.Embedding(num_embeddings=text_tokenizer_encoder.vocab_size, embedding_dim=1024, padding_idx=0)
+    #model.text_decoder_frontend.embed = Embedding(num_embeddings=text_tokenizer_encoder.vocab_size, embedding_dim=1024, pad_idx=0, scaled=True)
+    #model.final_proj = TiedProjection(input_dim=1024, output_dim=text_tokenizer_encoder.vocab_size)
+    model.final_proj = nn.Linear(1024, text_tokenizer_encoder.vocab_size, bias=False)
+    # Freeze everything except the new embedding table and output projection;
+    # note that `requires_grad` must be set on parameters, not on modules.
+    for name, param in model.named_parameters():
+        if name not in ("text_decoder_frontend.embed.weight", "final_proj.weight"):
+            param.requires_grad = False
+    for name, param in model.named_parameters():
+        logging.info(f"{name}: requires_grad={param.requires_grad}")
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -1197,7 +1220,7 @@ def scan_pessimistic_batches_for_oom(
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     params: AttributeDict,
-    text_tokenizer_encoder: SentencePieceEncoder,
+    text_tokenizer_encoder: CharTokenizer,
 ):
     from lhotse.dataset import find_pessimistic_batches