Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-03 22:24:19 +00:00)

commit bb1c4466e3 (parent d926585b10)

    rename train, train2, add support to fine-tune embedding table
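
The commit renames the recipe's training scripts (train, train2) and adds support for fine-tuning only the text decoder's embedding table: the seamlessM4T decoder embedding and output projection are replaced with freshly initialized modules sized to a custom tokenizer, while the rest of the model can stay frozen. The block below is an illustrative sketch of that setup, not part of the commit; load_unity_model and the module names are taken from the hunks that follow, while the import path, the vocabulary size, and the freezing loop are assumptions (the diff itself leaves the freeze commented out).

import torch
from torch import nn
from seamless_communication.models.unity import load_unity_model  # import path assumed

vocab_size = 6000  # placeholder for CharTokenizer.vocab_size

model = load_unity_model("seamlessM4T_medium", device="cpu", dtype=torch.float32)

# Drop the sub-models that ASR fine-tuning does not need.
del model.t2u_model
del model.text_encoder
del model.text_encoder_frontend

# Replace the multilingual embedding table and output projection with small, task-specific ones.
model.text_decoder_frontend.embed = nn.Embedding(vocab_size, 1024, padding_idx=0)
model.final_proj = nn.Linear(1024, vocab_size, bias=False)

# Train only the new tensors (an assumption about intent; the diff keeps this commented out).
for name, param in model.named_parameters():
    param.requires_grad = name in ("text_decoder_frontend.embed.weight", "final_proj.weight")
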
@@ -191,12 +191,12 @@ def decode_one_batch(
    feature_len = supervisions["num_frames"]
    feature_len = feature_len.to(device, dtype=dtype)

    #text_output = s2t_generator.generate_ex(feature, feature_len)
    text_output = s2t_generator.generate_ex(feature, feature_len)
    #sentences = text_output.sentences
    #hyps = [sentence.bytes().decode("utf-8").split() for sentence in sentences]

    token_ids = text_output.generator_output.results[0][0].seq.cpu().tolist()
    hyps_ids = [setence[0].seq.cpu().tolist() for sentence in token_ids]
    token_ids = text_output.generator_output.results
    hyps_ids = [sentence[0].seq.cpu().tolist() for sentence in token_ids]
    hyps = [params.tokenizer.decode(hyps_id).split() for hyps_id in hyps_ids]

    key = "beam-search"
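
For context on the hunk above (an aside, not part of the commit): generator_output.results is a list with one entry per utterance, and each entry holds that utterance's beam hypotheses sorted by score, best first (see the patched sequence_generator.py below). The corrected comprehension therefore takes the best hypothesis of every utterance rather than indexing only the first one. A short sketch of what the decode path ends up doing, with tokenizer standing in for params.tokenizer:

# results: List[List[Hypothesis]], one inner list per utterance, best hypothesis first.
results = text_output.generator_output.results

hyps = []
for hypotheses in results:
    best_ids = hypotheses[0].seq.cpu().tolist()  # token ids of the best beam
    hyps.append(tokenizer.decode(best_ids).split())
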
@@ -347,8 +347,10 @@ def main():
    del model.t2u_model
    del model.text_encoder
    del model.text_encoder_frontend
    model.text_decoder_frontend.embed = Embedding(num_embeddings=params.tokenzier.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True)
    model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size)
    model.text_decoder_frontend.embed = nn.Embedding(num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024 ,padding_idx=0)
    #model.text_decoder_frontend.embed = Embedding(num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True)
    model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size, bias=False)
    #model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size)
    if params.epoch > 0:
        if params.avg > 1:
            start = params.epoch - params.avg

@@ -371,6 +373,13 @@ def main():
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.to(device)
    model.eval()
    model.half()
    #for param in model.parameters():
    #    if param.dtype == torch.float16:
    #        pass
    #    else:
    #        param.data = param.data.to(torch.float16)
    #print(param)
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

egs/aishell/ASR/seamlessm4t/patch/sequence_generator.py (new file, 694 lines)

@@ -0,0 +1,694 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union, cast

import torch
from torch import Tensor
from torch.nn.functional import log_softmax

from fairseq2.data import Collater, SequenceData, VocabularyInfo
from fairseq2.generation.beam_search import BeamSearch, StandardBeamSearch
from fairseq2.generation.logits_processor import LogitsProcessor
from fairseq2.models.encoder_decoder import Seq2SeqDecoder
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.typing import Device


@dataclass
class SequenceGeneratorOptions:
    """Holds the options to pass to a sequence generator."""

    beam_size: int = 5
    """The beam size."""

    min_seq_len: int = 1
    """The minimum length of generated sequences (including prefix sequence)."""

    soft_max_seq_len: Optional[Tuple[int, int]] = (1, 200)
    """The terms ``a`` and ``b`` of ``ax + b`` where ``x`` is the source
    sequence length. The generated sequences (including prefix sequence) will
    have the maximum length of ``min(hard_max_seq_len, ax + b)``. See also
    ``hard_max_seq_len``."""

    hard_max_seq_len: int = 1024
    """The hard limit on maximum length of generated sequences."""

    len_penalty: float = 1.0
    """The length penalty, where values less than 1.0 favor shorter, values
    greater than 1.0 favor longer sequences."""

    unk_penalty: float = 0.0
    """The unknown symbol penalty, where values less than 0 produce more UNKs,
    values greater than 0 produce fewer UNKs."""

    normalize_scores: bool = True
    """If ``True``, normalizes scores by the length of generated sequences."""

    search: Optional[BeamSearch] = None
    """The beam search algorithm to use."""

    logits_processor: Optional[LogitsProcessor] = None
    """Logits processor called before applying beam search step."""


class Seq2SeqGenerator:
    """Represents a sequence-to-sequence generator."""

    decoder: Seq2SeqDecoder
    opts: SequenceGeneratorOptions
    beam_size: int
    eos_idx: int
    pad_idx: Optional[int]
    unk_idx: Optional[int]
    prefix_seq: Union[int, Tensor]
    prefix_seq_len: int
    search: BeamSearch
    logits_processor: Optional[LogitsProcessor]
    collater: Collater

    def __init__(
        self,
        decoder: Seq2SeqDecoder,
        vocab_info: VocabularyInfo,
        prefix_seq: Optional[Union[int, Tensor]],
        opts: Optional[SequenceGeneratorOptions] = None,
    ) -> None:
        """
        :param decoder:
            The decoder to use.
        :param vocab_info:
            The vocabulary information to use.
        :param prefix_seq:
            The prefix sequence, typically one or more control symbols
            indicating the beginning of a sequence. *Shape:* :math:`()` or
            :math:`(S)`, where :math:`S` is the sequence length. If ``None``,
            the EOS symbol will be used as prefix.
        :param opts:
            The generation options.
        """
        self.decoder = decoder

        self.opts = opts or SequenceGeneratorOptions()

        # Set beam size.
        if vocab_info.pad_idx is None:
            self.beam_size = min(self.opts.beam_size, vocab_info.size)
        else:
            # -1 since we never select PAD.
            self.beam_size = min(self.opts.beam_size, vocab_info.size - 1)

        if vocab_info.eos_idx is None:
            raise ValueError(
                "`vocab_info` must have `eos_idx` set for sequence generation."
            )

        # Set vocab info.
        self.eos_idx = 1
        #self.eos_idx = vocab_info.eos_idx
        self.unk_idx = 2
        #self.unk_idx = vocab_info.unk_idx
        self.pad_idx = 0
        #self.pad_idx = vocab_info.pad_idx

        # Set prefix sequence.
        if 1:
        #if prefix_seq is None:
            # If `None`, we follow fairseq's convention, and use EOS as the
            # prefix.
            self.prefix_seq, self.prefix_seq_len = self.eos_idx, 1
        else:
            self.prefix_seq = prefix_seq

            if isinstance(prefix_seq, Tensor):
                num_dim = prefix_seq.dim()

                if num_dim >= 2:
                    raise ValueError(
                        f"`prefix_seq` must be a scalar or a 1-dimensional tensor, but is {num_dim}-dimensional instead."
                    )

                self.prefix_seq_len = 1 if num_dim == 0 else prefix_seq.size(0)
            else:
                self.prefix_seq_len = 1

        # Set beam search.
        self.search = self.opts.search or StandardBeamSearch()
        self.logits_processor = self.opts.logits_processor

        if vocab_info.pad_idx is None:
            self.collater = Collater()
        else:
            self.collater = Collater(self.pad_idx, pad_to_multiple=2)

    @torch.inference_mode()
    def __call__(
        self,
        encoder_output: Tensor,
        encoder_padding_mask: Optional[Tensor],
        source_seq_len: Optional[int] = None,
    ) -> "SequenceGeneratorOutput":
        opts = self.opts

        num_searches = encoder_output.size(0)

        beam_size = opts.beam_size

        max_seq_len = self._determine_max_seq_len(source_seq_len)

        device = encoder_output.device

        encoder_output, encoder_padding_mask = self._fan_out_encoder_output(
            encoder_output, encoder_padding_mask
        )

        # Each element contains the id of the search corresponding to a single
        # source sequence and its hypotheses.
        active_searches: List[Tuple[int, List[Hypothesis]]] = [
            (search_idx, []) for search_idx in range(num_searches)
        ]

        # Once a source sequence has `beam_size` hypotheses, its search is moved
        # from `active_searches` to `finished_searches`.
        finished_searches: List[List[Hypothesis]] = [[] for i in range(num_searches)]

        num_remaining_searches = num_searches

        # Initialize buffers.
        # (N x B, S)
        seqs = torch.zeros(
            (num_searches * beam_size, max_seq_len), device=device, dtype=torch.int64
        )

        # (N x B, S)
        scores = torch.zeros(
            (num_searches * beam_size, max_seq_len), device=device, dtype=torch.float32
        )

        # A list that indicates beams that should be ignored in the next step.
        ignored_beam_mask = torch.full(
            (num_searches, beam_size), False, device=device, dtype=torch.bool
        )

        # An offset array for converting between batch-wide and search-local
        # beam indices.
        # (B)
        search_offsets = torch.arange(num_searches, device=device) * beam_size

        # (B) -> (B, 1)
        search_offsets.unsqueeze_(-1)

        cand_offsets = torch.arange(2 * beam_size, device=device)

        state_bag = IncrementalStateBag()

        # At this point, the state is fully initialized, kick off the search.
        self._bootstrap_seqs_and_scores(
            seqs, scores, encoder_output, encoder_padding_mask, state_bag
        )

        start_step = self.prefix_seq_len - 1

        # Holds the indices of beams (a beam can occur more than once) that we
        # should continue with in the next step.
        beam_indices: Optional[Tensor] = None

        # Holds the indices of searches that we should continue with in the next
        # step. If not `None`, it means we finalized one or more searches in the
        # last step.
        search_indices: Optional[Tensor] = None

        for step_nr in range(start_step, max_seq_len - 1):
            if beam_indices is not None:
                # If not `None`, it means in the last step we finalized one or
                # more searches. We should ensure that we adjust `beam_indices`
                # before reordering `decoder`'s incremental state.
                if search_indices is not None:
                    num_searches = search_indices.numel()

                    # (N)
                    delta = search_indices - torch.arange(num_searches, device=device)

                    # (N) -> (N, 1)
                    delta.unsqueeze_(-1)

                    # Adjust indices to take into account removed searches.
                    beam_indices.view(num_searches, beam_size).add_(delta * beam_size)

                state_bag.reorder(beam_indices)

            decoder_output, decoder_padding_mask = self.decoder.decode(
                seqs[:, step_nr : step_nr + 1],
                None,  # We never generate PAD.
                encoder_output,
                encoder_padding_mask,
                state_bag,
            )

            state_bag.increment_step()

            model_output = self.decoder.project(decoder_output, decoder_padding_mask)

            # lprobs:       (1, V)
            # model_output: (N, 1, V)
            lprobs = log_softmax(model_output.logits, dim=-1, dtype=torch.float32)

            # Do not allow EOS before reaching the minimum sequence length.
            if step_nr < self.opts.min_seq_len:
                lprobs[:, :, self.eos_idx] = -torch.inf

            # fmt: off
            # If we have reached the maximum length, force the last step to be
            # EOS.
            if step_nr == max_seq_len - 2:
                lprobs[:, :, : self.eos_idx]       = -torch.inf
                lprobs[:, :,   self.eos_idx + 1 :] = -torch.inf
            # fmt: on

            # Never allow PAD.
            if self.pad_idx is not None:
                lprobs[:, :, self.pad_idx] = -torch.inf

            # Apply UNK penalty.
            if self.unk_idx is not None:
                lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty

            # update scores in place using logits_processor
            if self.logits_processor is not None:
                self.logits_processor(
                    seqs.view(num_searches, beam_size, -1)[:, :, : step_nr + 1],
                    lprobs.view(num_searches, beam_size, -1),
                )

            # Determine candidates for the next step.
            # (N, 2 x B)
            cand_scores, cand_indices, cand_beam_indices = self.search.step(
                step_nr,
                step_nr == start_step,
                lprobs.view(num_searches, beam_size, -1),
                scores.view(num_searches, beam_size, -1)[:, :, : step_nr + 1],
            )

            # Convert search-local beam indices to batch-wide beam indices.
            # (N, 2 x B) + (N) -> (N, 2 x B)
            global_cand_beam_indices = cand_beam_indices + search_offsets

            # Finalize beams that reached the minimum length and that end with
            # an EOS.
            # (N, 2 x B)
            eos_mask = (cand_indices == self.eos_idx) & (cand_scores != -math.inf)

            # Do not attempt to finalize beams that should be ignored.
            eos_mask[:, :beam_size][ignored_beam_mask] = False

            # Only consider EOS when it's among the top `beam_size` indices. Now
            # we know what beam(s) to finalize.
            # (N, B)
            eos_beam_indices = torch.masked_select(
                global_cand_beam_indices[:, :beam_size], mask=eos_mask[:, :beam_size]
            )

            if eos_beam_indices.numel() > 0:
                # Select the scores of the finalized beams.
                # (N, B)
                eos_scores = torch.masked_select(
                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
                )

                newly_finished_searches = self._finalize_hypothesis(
                    step_nr,
                    eos_beam_indices,
                    eos_scores,
                    seqs,
                    scores,
                    active_searches,
                    finished_searches,
                )

                num_remaining_searches -= len(newly_finished_searches)

                if num_remaining_searches == 0:
                    break
            else:
                newly_finished_searches = None

            # Remove finished searches (ones for which `beam_size` finalized
            # beams have been generated) from the batch.
            if newly_finished_searches:
                new_num_searches = num_searches - len(newly_finished_searches)

                # Construct `search_indices` which holds indices of searches
                # to keep for the next step.
                search_mask = torch.full((num_searches,), True, device=device)

                search_mask[newly_finished_searches] = False

                search_indices = torch.arange(num_searches, device=device)

                search_indices = search_indices.masked_select(search_mask)

                # fmt: off
                # Filter out removed batches from state variables.
                # (N, B) -> (N - F, B)
                ignored_beam_mask = ignored_beam_mask[search_indices]

                # (N, 2 x B) -> (N - F, 2 x B)
                cand_scores       = cand_scores      [search_indices]
                cand_indices      = cand_indices     [search_indices]
                cand_beam_indices = cand_beam_indices[search_indices]

                # (N) -> (N - F)
                search_offsets.resize_(new_num_searches, 1)

                # (N - F, 2 x B) + (N - F) -> (N - F, 2 x B)
                global_cand_beam_indices = cand_beam_indices + search_offsets

                # (N, 2 x B) -> (N - F, 2 x B)
                eos_mask = eos_mask[search_indices]

                # (N x B, S) -> (N, B, S)
                seqs   = seqs  .view(num_searches, -1)
                scores = scores.view(num_searches, -1)

                # (N, B, S + 1) -> ((N - F) x B, S)
                seqs   = seqs  [search_indices].view(new_num_searches * beam_size, -1)
                scores = scores[search_indices].view(new_num_searches * beam_size, -1)

                # (N x B, S_enc, M) -> (N, B, S_enc, M)
                encoder_output = encoder_output.unflatten(0, (num_searches, -1))

                # (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
                encoder_output = encoder_output[search_indices].flatten(0, 1)

                if encoder_padding_mask is not None:
                    # (N x B, S_enc, M) -> (N, B, S_enc, M)
                    padding_mask = encoder_padding_mask.unflatten(0, (num_searches, -1))

                    # (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
                    encoder_padding_mask = padding_mask[search_indices].flatten(0, 1)
                # fmt: on

                num_searches = new_num_searches
            else:
                search_indices = None

            eos_mask[:, :beam_size][ignored_beam_mask] = True

            # Set `beam_weights` so that values greater than or equal to 2 x
            # `beam_size` indicate finished beams (i.e. end with EOS) and values
            # less than 2 x `beam_size` indicate active beams.
            # (N, 2 x B)
            beam_weights = cand_offsets + (eos_mask * (2 * beam_size))

            # Get the top `beam_size` active beams, which are the beams with the
            # smallest weights in `active_beam_weights`.
            # (N, B)
            active_beam_weights, active_beams = torch.topk(
                beam_weights, k=beam_size, dim=1, largest=False
            )

            # Update to ignore finalized beams in the next step.
            # (N, B)
            ignored_beam_mask = active_beam_weights >= 2 * beam_size

            # We should always have at least one active beam in each search.
            assert (~ignored_beam_mask).any(dim=1).all()

            # Denotes which beams are continued for each new hypothesis (a beam
            # can be selected more than once).
            # (N, B)
            beam_indices = torch.gather(
                global_cand_beam_indices, dim=1, index=active_beams
            )

            # (N, B) -> (N x B)
            beam_indices = beam_indices.view(-1)

            # fmt: off
            # Reorder beams in the `seq` and `score` buffers. The same beam can
            # be selected more than once.
            if step_nr > start_step:
                seqs  [:, : step_nr + 1] = torch.index_select(
                    seqs  [:, : step_nr + 1], dim=0, index=beam_indices
                )
                scores[:, : step_nr + 1] = torch.index_select(
                    scores[:, : step_nr + 1], dim=0, index=beam_indices
                )

            # (N x B, S) -> (N, B, S)
            seqs_view   = seqs  .view(num_searches, beam_size, -1)
            scores_view = scores.view(num_searches, beam_size, -1)

            seqs_view  [:, :, step_nr + 1] = torch.gather(cand_indices, dim=1, index=active_beams)
            scores_view[:, :, step_nr + 1] = torch.gather(cand_scores,  dim=1, index=active_beams)
            # fmt: on

        # Ensure that hypotheses are sorted by their scores before returning.
        for batch in finished_searches:
            batch.sort(key=lambda b: b.score, reverse=True)  # type: ignore[arg-type, return-value]

        return SequenceGeneratorOutput(
            results=finished_searches, device=device, collater=self.collater
        )

    def _determine_max_seq_len(self, source_seq_len: Optional[int]) -> int:
        opts = self.opts

        if source_seq_len is None or opts.soft_max_seq_len is None:
            max_seq_len = opts.hard_max_seq_len
        else:
            at, bt = opts.soft_max_seq_len

            max_seq_len = min(opts.hard_max_seq_len, int(at * source_seq_len + bt))

        if opts.min_seq_len > max_seq_len:
            raise ValueError(
                f"The effective maximum sequence length must be greater than or equal to `min_seq_len` ({opts.min_seq_len}), but is {max_seq_len} instead. Adjust your soft and hard maximum sequence length limits."
            )

        if self.prefix_seq_len >= max_seq_len:
            raise ValueError(
                f"The effective maximum sequence length must be greater than `prefix_seq_len` ({self.prefix_seq_len}), but is {max_seq_len} instead."
            )

        return max_seq_len

    def _fan_out_encoder_output(
        self, encoder_output: Tensor, encoder_padding_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Optional[Tensor]]:
        num_searches = encoder_output.size(0)  # i.e. batch size

        # Fan out `encoder_output` to `num_searches` x `beam_size`.
        # (N)
        fan_out_indices = torch.arange(num_searches, device=encoder_output.device)

        # (N) -> (N x B)
        fan_out_indices = fan_out_indices.repeat_interleave(self.beam_size)

        # (N, S_enc, M) -> (N x B, S_enc, M)
        encoder_output = encoder_output.index_select(dim=0, index=fan_out_indices)

        # (N, S_enc, M) -> (N x B, S_enc, M)
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.index_select(
                dim=0, index=fan_out_indices
            )

        return encoder_output, encoder_padding_mask

    def _bootstrap_seqs_and_scores(
        self,
        seqs: Tensor,
        scores: Tensor,
        encoder_output: Tensor,
        encoder_padding_mask: Optional[Tensor],
        state_bag: IncrementalStateBag,
    ) -> None:
        assert self.prefix_seq_len > 0

        seqs[:, : self.prefix_seq_len] = self.prefix_seq

        if self.prefix_seq_len == 1:
            return

        assert isinstance(self.prefix_seq, Tensor)

        # We have to bootstrap the model with the already fanned-out encoder
        # output to correctly initialize its incremental state. This causes some
        # redundancy as we have to expand `decoder_input` to match the shape of
        # `encoder_output`.
        # (S_pfx) -> (N x B, S_pfx - 1)
        decoder_input = self.prefix_seq[:-1].expand(encoder_output.size(0), -1)

        # Bootstrap the model state with prefix sequence.
        decoder_output, decoder_padding_mask = self.decoder.decode(
            decoder_input,
            None,
            encoder_output,
            encoder_padding_mask,
            state_bag,
        )

        state_bag.increment_step(self.prefix_seq_len - 1)

        model_output = self.decoder.project(decoder_output, decoder_padding_mask)

        # lprobs:       (S_pfx - 1, V)
        # model_output: (N, S_pfx - 1, V) -> (S_pfx - 1, V)
        lprobs = log_softmax(model_output.logits[0], dim=-1, dtype=torch.float32)

        # Fetch scores of next steps.
        # (S_pfx - 1, 1)
        prefix_scores = torch.take_along_dim(
            lprobs, indices=self.prefix_seq[1:].unsqueeze(1), dim=-1
        )

        # (S_pfx - 1, 1) -> (S_pfx - 1)
        prefix_scores.squeeze_(1).cumsum_(dim=0)

        # First step (e.g. EOS)'s score is always 0.
        scores[:, 1 : self.prefix_seq_len] = prefix_scores

    def _finalize_hypothesis(
        self,
        step_nr: int,
        eos_beam_indices: Tensor,
        eos_scores: Tensor,
        seqs: Tensor,
        scores: Tensor,
        active_searches: List[Tuple[int, List["Hypothesis"]]],
        finished_searches: List[List["Hypothesis"]],
    ) -> List[int]:
        # fmt: off
        finalized_seqs   = seqs  .index_select(dim=0, index=eos_beam_indices)
        finalized_scores = scores.index_select(dim=0, index=eos_beam_indices)

        finalized_seqs   = finalized_seqs  [:, : step_nr + 2]
        finalized_scores = finalized_scores[:, : step_nr + 2]

        # Finalize beams.
        finalized_seqs  [:, -1] = self.eos_idx
        finalized_scores[:, -1] = eos_scores
        # fmt: on

        # Convert from cumulative to per-step scores.
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]

        # Skip first EOS since it is always 0 and skews normalization.
        if self.opts.normalize_scores:
            eos_scores /= (step_nr + 1) ** self.opts.len_penalty

        # Holds the ids of finished searches.
        newly_finished: List[int] = []

        active_search_indices = (eos_beam_indices // self.beam_size).tolist()

        for beam_idx, search_idx in enumerate(active_search_indices):
            search_id, hypotheses = active_searches[search_idx]

            # We might have more than one beam finalized in one step that would
            # potentially exceed `beam_size` hypotheses.
            if len(hypotheses) == self.beam_size:
                continue

            hypotheses.append(
                Hypothesis(
                    seq=finalized_seqs[beam_idx],
                    score=eos_scores[beam_idx],
                    step_scores=finalized_scores[beam_idx],
                )
            )

            if len(hypotheses) == self.beam_size:
                # We have `beam_size` hypotheses for this particular search, so
                # we finish it now.
                newly_finished.append(search_idx)

                finished_searches[search_id] = hypotheses

        newly_finished.sort()

        # Remove finished searches from the active list.
        for idx in reversed(newly_finished):
            del active_searches[idx]

        return newly_finished


@dataclass
class SequenceGeneratorOutput:
    """Holds the output of a sequence generator."""

    results: List[List["Hypothesis"]]
    """The list of hypothesis generated per search, ordered by score."""

    device: Device
    """The device on which generated sequences reside."""

    collater: Optional[Collater] = None
    """The collater to use in :meth:`collate`."""

    def collate(
        self, hypo_idx: int = 0, skip_batch: bool = False
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Collate the generated sequences at index ``hypo_idx`` in each search
        result into a single tensor.

        :param hypo_idx:
            The index of hypothesis to extract from each search result.
        :param skip_batch:
            If ``True``, if a search result has no hypothesis at index `hypo_idx`,
            it will be skipped instead of raising an error.

        :returns:
            - The collated sequences. *Shape:* :math:`(N,S)`, where :math:`N` is
              the number of search results and :math:`S` is the sequence length.
            - An array where each element represents the length of the sequence at
              the same index in the first returned value. *Shape:* :math:`(N)`,
              where :math:`N` is the number of search results.
        """
        if self.collater is None:
            raise RuntimeError("The output has no associated `Collater` instance.")

        if not self.results and not skip_batch:
            raise ValueError("The output must contain at least one search result.")

        seqs = []

        for search_idx, result in enumerate(self.results):
            if hypo_idx >= len(result):
                if not skip_batch:
                    raise ValueError(
                        f"Each search result must have at least {hypo_idx + 1} hypotheses, but search {search_idx} has only {len(result)}."
                    )

                continue

            seqs.append(result[hypo_idx].seq)

        if not seqs:
            # Return a zero-dimensional (not scalar!) tensor.
            return torch.empty((0,), device=self.device, dtype=torch.int64), None

        output = cast(SequenceData, self.collater(seqs))

        return output["seqs"], output["seq_lens"] if output["is_ragged"] else None


@dataclass
class Hypothesis:
    """Represents a hypothesis produced by a sequence generator."""

    seq: Tensor
    """The generated sequence."""

    score: Tensor
    """The score of the hypothesis."""

    step_scores: Tensor
    """The score of each individual sequence step."""
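
For reference, a minimal sketch of how this patched generator is typically driven (an illustrative aside, not part of the commit): `decoder` and `vocab_info` are assumed to come from the seamlessM4T model and the custom tokenizer, and `encoder_output` from the speech encoder.

from sequence_generator import Seq2SeqGenerator, SequenceGeneratorOptions

opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))

# The patch hard-wires eos/pad/unk to 1/0/2, so `vocab_info` mainly supplies the vocabulary size.
generator = Seq2SeqGenerator(decoder, vocab_info, prefix_seq=None, opts=opts)

# encoder_output: (N, S_enc, M); the padding mask may be None for a single utterance.
output = generator(encoder_output, encoder_padding_mask, source_seq_len=encoder_output.size(1))

# `results` holds one list of Hypothesis per utterance, best-scoring first.
best_token_ids = [hyps[0].seq.cpu().tolist() for hyps in output.results]
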
@@ -110,10 +110,8 @@ from fairseq2.data.text import (
    TextTokenizer,
    vocabulary_from_sentencepiece,
)
from tokenizer import CharTokenizer

from label_smoothing import LabelSmoothingLoss
from fairseq2.nn.embedding import Embedding
from fairseq2.nn.projection import TiedProjection

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -603,7 +601,7 @@ def save_checkpoint(
def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    text_tokenizer_encoder: CharTokenizer,
    text_tokenizer_encoder: SentencePieceEncoder,
    batch: dict,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:

@@ -661,10 +659,8 @@ def compute_loss(
    batch_idx_train = params.batch_idx_train
    warm_step = params.warm_step

    texts = batch["supervisions"]["text"]
    # remove spaces in the text
    texts = [text.replace(" ", "") for text in texts]
    text_tokens_list = [torch.tensor([params.eos_idx] + text_tokenizer_encoder.encode(text) + [params.eos_idx]) for text in texts]
    texts = batch["supervisions"]["text"]
    text_tokens_list = [text_tokenizer_encoder(text) for text in texts]
    prev_outputs_tokens = _batch_tensors(
        [tokens[:-1] for tokens in text_tokens_list], pad_value=params.pad_idx
    )

@@ -710,7 +706,7 @@ def compute_loss(
def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    text_tokenizer_encoder: CharTokenizer,
    text_tokenizer_encoder: SentencePieceEncoder,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:

@@ -746,7 +742,7 @@ def train_one_epoch(
    model: Union[nn.Module, DDP],
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
    text_tokenizer_encoder: CharTokenizer,
    text_tokenizer_encoder: SentencePieceEncoder,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,

@@ -921,6 +917,7 @@ def train_one_epoch(
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def run(rank, world_size, args):
    """
    Args:

@@ -960,28 +957,22 @@ def run(rank, world_size, args):

    logging.info("About to create model")
    model_name_or_card = "seamlessM4T_medium"
    tokenizer_file = "./seamlessm4t/tokens.txt"
    lang = "cmn"

    # text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
    # text_tokenizer_encoder = SentencePieceEncoder(
    # text_tokenizer.model,
    # prefix_tokens=["</s>", f"__{lang}__"],
    # suffix_tokens=["</s>"],
    # )
    # #params.eos_idx = text_tokenizer.model.eos_idx
    # params.pad_idx = text_tokenizer.model.pad_idx
    text_tokenizer_encoder = CharTokenizer(tokenizer_file)
    params.pad_idx, params.eos_idx = 0, 1
    logging.info(params)

    model = load_unity_model(model_name_or_card, device="cpu", dtype=torch.float32)
    del model.t2u_model
    del model.text_encoder
    del model.text_encoder_frontend
    model.text_decoder_frontend.embed = Embedding(num_embeddings=text_tokenizer_encoder.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True)
    #model.final_proj = TiedProjection(input_dim=1024, output_dim=text_tokenizer_encoder.vocab_size)
    model.final_proj = nn.Linear(1024, text_tokenizer_encoder.vocab_size)
    # print(vars(model))
    # exit(0)
    text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
    text_tokenizer_encoder = SentencePieceEncoder(
        text_tokenizer.model,
        prefix_tokens=["</s>", f"__{lang}__"],
        suffix_tokens=["</s>"],
    )
    #params.eos_idx = text_tokenizer.model.eos_idx
    params.pad_idx = text_tokenizer.model.pad_idx
    logging.info(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -1206,7 +1197,7 @@ def scan_pessimistic_batches_for_oom(
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    params: AttributeDict,
    text_tokenizer_encoder: CharTokenizer,
    text_tokenizer_encoder: SentencePieceEncoder,
):
    from lhotse.dataset import find_pessimistic_batches

@@ -110,8 +110,10 @@ from fairseq2.data.text import (
    TextTokenizer,
    vocabulary_from_sentencepiece,
)

from tokenizer import CharTokenizer
from label_smoothing import LabelSmoothingLoss
from fairseq2.nn.embedding import Embedding
from fairseq2.nn.projection import TiedProjection

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -601,7 +603,7 @@ def save_checkpoint(
def compute_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    text_tokenizer_encoder: SentencePieceEncoder,
    text_tokenizer_encoder: CharTokenizer,
    batch: dict,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:

@@ -659,8 +661,10 @@ def compute_loss(
    batch_idx_train = params.batch_idx_train
    warm_step = params.warm_step

    texts = batch["supervisions"]["text"]
    text_tokens_list = [text_tokenizer_encoder(text) for text in texts]
    texts = batch["supervisions"]["text"]
    # remove spaces in the text
    texts = [text.replace(" ", "") for text in texts]
    text_tokens_list = [torch.tensor([params.eos_idx] + text_tokenizer_encoder.encode(text) + [params.eos_idx]) for text in texts]
    prev_outputs_tokens = _batch_tensors(
        [tokens[:-1] for tokens in text_tokens_list], pad_value=params.pad_idx
    )
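
An aside on the target preparation just above (not part of the commit): each transcript has spaces stripped, is encoded character by character, and is wrapped in EOS on both sides; the decoder input then drops the final EOS for teacher forcing. A small sketch of the resulting tensors, where pad_sequence stands in for the recipe's _batch_tensors, the token ids are made up for illustration, and the tgt_tokens line is an assumption about the part of compute_loss not shown in this hunk:

import torch
from torch.nn.utils.rnn import pad_sequence  # plays the role of _batch_tensors here

eos_idx, pad_idx = 1, 0

# e.g. CharTokenizer.encode("你好") -> [23, 87]  (illustrative ids only)
text_tokens_list = [
    torch.tensor([eos_idx, 23, 87, eos_idx]),
    torch.tensor([eos_idx, 45, eos_idx]),
]

# Decoder input drops the final EOS; the training target would drop the leading EOS.
prev_outputs_tokens = pad_sequence([t[:-1] for t in text_tokens_list],
                                   batch_first=True, padding_value=pad_idx)
tgt_tokens = pad_sequence([t[1:] for t in text_tokens_list],
                          batch_first=True, padding_value=pad_idx)
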
@@ -706,7 +710,7 @@ def compute_loss(
def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    text_tokenizer_encoder: SentencePieceEncoder,
    text_tokenizer_encoder: CharTokenizer,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:

@@ -742,7 +746,7 @@ def train_one_epoch(
    model: Union[nn.Module, DDP],
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
    text_tokenizer_encoder: SentencePieceEncoder,
    text_tokenizer_encoder: CharTokenizer,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,

@@ -917,7 +921,6 @@ def train_one_epoch(
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def run(rank, world_size, args):
    """
    Args:

@@ -957,22 +960,42 @@ def run(rank, world_size, args):

    logging.info("About to create model")
    model_name_or_card = "seamlessM4T_medium"
    tokenizer_file = "./seamlessm4t/tokens.txt"
    lang = "cmn"

    # text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
    # text_tokenizer_encoder = SentencePieceEncoder(
    # text_tokenizer.model,
    # prefix_tokens=["</s>", f"__{lang}__"],
    # suffix_tokens=["</s>"],
    # )
    # #params.eos_idx = text_tokenizer.model.eos_idx
    # params.pad_idx = text_tokenizer.model.pad_idx
    text_tokenizer_encoder = CharTokenizer(tokenizer_file)
    params.pad_idx, params.eos_idx = 0, 1
    logging.info(params)

    model = load_unity_model(model_name_or_card, device="cpu", dtype=torch.float32)
    del model.t2u_model
    del model.text_encoder
    del model.text_encoder_frontend
    # print(vars(model))
    # exit(0)
    text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
    text_tokenizer_encoder = SentencePieceEncoder(
        text_tokenizer.model,
        prefix_tokens=["</s>", f"__{lang}__"],
        suffix_tokens=["</s>"],
    )
    #params.eos_idx = text_tokenizer.model.eos_idx
    params.pad_idx = text_tokenizer.model.pad_idx
    logging.info(params)
    model.text_decoder_frontend.embed = nn.Embedding(num_embeddings=text_tokenizer_encoder.vocab_size, embedding_dim=1024 ,padding_idx=0)
    #model.text_decoder_frontend.embed = Embedding(num_embeddings=text_tokenizer_encoder.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True)
    #model.final_proj = TiedProjection(input_dim=1024, output_dim=text_tokenizer_encoder.vocab_size)
    model.final_proj = nn.Linear(1024, text_tokenizer_encoder.vocab_size, bias=False)
    for name, param in model.named_parameters():
        if name != 'text_decoder_frontend.embed.weight' and name != 'final_proj.weight':
            #param.requires_grad = False
            pass
    model.text_decoder_frontend.embed.requires_grad = True
    model.final_proj.requires_grad = True
    print(model.text_decoder_frontend.embed.requires_grad, model.final_proj.requires_grad)
    for param in model.parameters():
        if param.requires_grad:
            print(233333333333333333333333333333333333333333333333333333333333333333333)
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    #exit(0)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -1197,7 +1220,7 @@ def scan_pessimistic_batches_for_oom(
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    params: AttributeDict,
    text_tokenizer_encoder: SentencePieceEncoder,
    text_tokenizer_encoder: CharTokenizer,
):
    from lhotse.dataset import find_pessimistic_batches
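
One caveat worth noting about the freezing block in the run() hunk above: assigning requires_grad on an nn.Module (model.text_decoder_frontend.embed.requires_grad = True) only sets a plain Python attribute and does not touch the underlying parameters; requires_grad is a tensor flag, so either the parameters must be addressed directly or requires_grad_() must be called on the module. An illustrative check, not part of the commit, assuming the patched model from the hunk above:

# requires_grad_() propagates to every parameter of the module, unlike the bare attribute.
model.requires_grad_(False)
model.text_decoder_frontend.embed.requires_grad_(True)
model.final_proj.requires_grad_(True)

num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_trainable}")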