Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-05 15:14:18 +00:00

remove seamless for next PR

This commit is contained in:
parent ac53222054
commit e883bb60d4
@@ -1,9 +0,0 @@
#export CUDA_VISIBLE_DEVICES="2,3"
#pip install -r seamlessm4t/requirements.txt
#pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
python3 seamlessm4t/decode.py --epoch 5 --exp-dir seamlessm4t/exp
python3 seamlessm4t/decode.py --epoch 5 --avg 2 --exp-dir seamlessm4t/exp
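The two decode commands above differ only in `--avg`. As a minimal sketch of what averaging two epochs means here, assuming icefall's usual convention that each `epoch-N.pt` checkpoint stores its weights under a `"model"` key (the real logic lives in `icefall.checkpoint.average_checkpoints` / `average_checkpoints_with_averaged_model`):

```python
import torch

def average_state_dicts(filenames):
    # Element-wise mean of the "model" state_dicts stored in `filenames`.
    # Illustration only; icefall's helpers handle more bookkeeping.
    avg = None
    for f in filenames:
        sd = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in sd.items()}
        else:
            for k in avg:
                avg[k] += sd[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}

# e.g. "--epoch 5 --avg 2" roughly corresponds to
# average_state_dicts(["seamlessm4t/exp/epoch-4.pt", "seamlessm4t/exp/epoch-5.pt"])
```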
@@ -1,8 +0,0 @@
#export CUDA_VISIBLE_DEVICES="1"
#pip install -r whisper/requirements.txt
#pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
#export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall/

python3 whisper/decode.py --exp-dir whisper/exp --max-duration 100
@@ -1,8 +0,0 @@
#export CUDA_VISIBLE_DEVICES="2,3"
pip install -r seamlessm4t/requirements.txt
pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
torchrun --nproc-per-node 8 seamlessm4t/train.py --use-fp16 1 --max-duration 300 --base-lr 1e-5 --exp-dir seamlessm4t/exp_new_vocab --start-epoch 1
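`--max-duration 300` bounds the total seconds of audio per batch rather than the number of utterances. A minimal sketch of how such a limit is typically enforced with a lhotse sampler; the recipe's own data pipeline is in `asr_datamodule.py`, so this is only an illustration, not the recipe's code:

```python
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

def make_train_sampler(cuts: CutSet, max_duration: float = 300.0):
    # Each mini-batch holds at most `max_duration` seconds of audio,
    # so the number of cuts per batch varies with utterance length.
    return DynamicBucketingSampler(cuts, max_duration=max_duration, shuffle=True)
```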
@@ -1,9 +0,0 @@

pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
pip install -r whisper/requirements.txt
export PYTHONPATH=$PYTHONPATH:/workspace/icefall
#export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
#export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall

torchrun --nproc-per-node 8 whisper/train.py --use-fp16 1 --max-duration 20 --base-lr 1e-5 --exp-dir whisper/exp_medimum --start-epoch 1
@@ -1 +0,0 @@
../tdnn_lstm_ctc/asr_datamodule.py
@@ -1,415 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
#                                            Fangjun Kuang,
#                                            Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule

# from conformer import Conformer

from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
from icefall.decode import get_lattice, nbest_decoding, nbest_oracle, one_best_decoding, rescore_with_attention_decoder
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, get_texts, setup_logger, store_transcripts, write_error_stats

from seamless_communication.models.unity import UnitYModel, load_unity_model, load_unity_text_tokenizer
from fairseq2.generation import SequenceGeneratorOptions, SequenceToTextGenerator
from seamless_communication.models.unity.model import UnitYX2TModel


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=-1,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'.",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="beam-search",
        help="""Decoding method.
        Supported values are:
        - (0) ctc-decoding. Use CTC decoding. It maps the token IDs to
          tokens using the token symbol table directly.
        - (1) 1best. Extract the best path from the decoding lattice as the
          decoding result.
        - (2) nbest. Extract n paths from the decoding lattice; the path
          with the highest score is the decoding result.
        - (3) attention-decoder. Extract n paths from the lattice,
          the path with the highest score is the decoding result.
        - (4) nbest-oracle. Its WER is the lower bound that any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring methods.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="seamlessm4t/exp",
        help="The experiment dir",
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            # parameters for conformer
            "subsampling_factor": 4,
            "feature_dim": 80,
            "nhead": 4,
            "attention_dim": 512,
            "num_encoder_layers": 12,
            "num_decoder_layers": 6,
            "vgg_frontend": False,
            "use_feat_batchnorm": True,
            # parameters for decoder
            "search_beam": 20,
            "output_beam": 7,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            "env_info": get_env_info(),
        }
    )
    return params


def decode_one_batch(
    params: AttributeDict,
    s2t_generator: SequenceToTextGenerator,
    batch: dict,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for decoding. For this script the
      key is always the string `beam-search`.
    - value: It contains the decoding result. `len(value)` equals the
      batch size. `value[i]` is the decoding result for the i-th
      utterance in the given batch.

    Args:
      params:
        It is the return value of :func:`get_params`.
      s2t_generator:
        The speech-to-text generator built from the SeamlessM4T model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    dtype = torch.float16
    device = torch.device("cuda", 3)

    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device, dtype=dtype)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_len = supervisions["num_frames"]
    feature_len = feature_len.to(device, dtype=dtype)

    text_output = s2t_generator.generate_ex(feature, feature_len)
    sentences = text_output.sentences
    hyps = [sentence.bytes().decode("utf-8").split() for sentence in sentences]
    key = "beam-search"

    return {key: hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    s2t_generator: SequenceToTextGenerator,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode the dataset.

    Args:
      dl:
        The PyTorch dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      s2t_generator:
        The speech-to-text generator built from the SeamlessM4T model.
    Returns:
      Return a dict whose key is the decoding setting (here `beam-search`)
      and whose value is a list of tuples. Each tuple contains three
      elements: the cut ID, the reference transcript, and the predicted
      result.
    """
    results = []

    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(params=params, s2t_generator=s2t_generator, batch=batch)

        for lm_scale, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, hyp_words))

            results[lm_scale].extend(this_batch)

        num_cuts += len(batch["supervisions"]["text"])

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    enable_log = True
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        # We compute CER for the Aishell dataset.
        results_char = []
        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
            )
            test_set_wers[key] = wer

        if enable_log:
            logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tCER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    AishellAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))
    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 3)

    logging.info(f"device: {device}")
    dtype = torch.float16

    model_name_or_card = "seamlessM4T_medium"
    # model_name_or_card = "seamlessM4T_large"
    model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
    del model.t2u_model
    del model.text_encoder
    del model.text_encoder_frontend
    if params.epoch > 0:
        if params.avg > 1:
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start, filename_end=filename_end, device=device
                )
            )
        else:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
            model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    text_tokenizer = load_unity_text_tokenizer(model_name_or_card)

    text_max_len_a = 1
    text_max_len_b = 200
    target_lang = "cmn"

    text_opts = SequenceGeneratorOptions(
        beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
    )

    s2t_model = UnitYX2TModel(
        encoder_frontend=model.speech_encoder_frontend,
        encoder=model.speech_encoder,
        decoder_frontend=model.text_decoder_frontend,
        decoder=model.text_decoder,
        final_proj=model.final_proj,
        pad_idx=model.pad_idx,
    )
    s2t_generator = SequenceToTextGenerator(
        s2t_model, text_tokenizer, target_lang, text_opts
    )
    # We need cut IDs to display recognition results.
    args.return_cuts = True
    aishell = AishellAsrDataModule(args)
    test_cuts = aishell.test_cuts()
    test_dl = aishell.test_dataloaders(test_cuts)

    test_sets = ["test"]
    test_dls = [test_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(dl=test_dl, params=params, s2t_generator=s2t_generator)

        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

    logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
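Although the helper above is named for WER (`write_error_stats`), `save_results` first splits every reference and hypothesis into characters, so the numbers it reports are character error rates. A minimal self-contained illustration of that metric (not icefall's implementation):

```python
def cer(ref: str, hyp: str) -> float:
    # Character error rate: edit distance over characters, normalized by the
    # reference length. This mirrors why decode.py converts word lists to
    # character lists before calling write_error_stats.
    r, h = list(ref), list(hyp)
    d = [[i + j if i * j == 0 else 0 for j in range(len(h) + 1)] for i in range(len(r) + 1)]
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            sub = d[i - 1][j - 1] + (r[i - 1] != h[j - 1])
            d[i][j] = min(sub, d[i - 1][j] + 1, d[i][j - 1] + 1)
    return d[len(r)][len(h)] / max(len(r), 1)

# cer("其中广州深圳", "其中广州深") == 1/6  (one deletion over six characters)
```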
@@ -1,432 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
#                                            Fangjun Kuang,
#                                            Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule

# from conformer import Conformer
from tokenizer import CharTokenizer

from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
from icefall.decode import get_lattice, nbest_decoding, nbest_oracle, one_best_decoding, rescore_with_attention_decoder
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, get_texts, setup_logger, store_transcripts, write_error_stats

from seamless_communication.models.unity import UnitYModel, load_unity_model, load_unity_text_tokenizer
from fairseq2.generation import SequenceGeneratorOptions, SequenceToTextGenerator
from seamless_communication.models.unity.model import UnitYX2TModel
from fairseq2.nn.embedding import Embedding


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=-1,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'.",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="beam-search",
        help="""Decoding method.
        Supported values are:
        - (0) ctc-decoding. Use CTC decoding. It maps the token IDs to
          tokens using the token symbol table directly.
        - (1) 1best. Extract the best path from the decoding lattice as the
          decoding result.
        - (2) nbest. Extract n paths from the decoding lattice; the path
          with the highest score is the decoding result.
        - (3) attention-decoder. Extract n paths from the lattice,
          the path with the highest score is the decoding result.
        - (4) nbest-oracle. Its WER is the lower bound that any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring methods.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="seamlessm4t/exp",
        help="The experiment dir",
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            # parameters for conformer
            "subsampling_factor": 4,
            "feature_dim": 80,
            "nhead": 4,
            "attention_dim": 512,
            "num_encoder_layers": 12,
            "num_decoder_layers": 6,
            "vgg_frontend": False,
            "use_feat_batchnorm": True,
            # parameters for decoder
            "search_beam": 20,
            "output_beam": 7,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            "env_info": get_env_info(),
        }
    )
    return params


def decode_one_batch(
    params: AttributeDict,
    s2t_generator: SequenceToTextGenerator,
    batch: dict,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for decoding. For this script the
      key is always the string `beam-search`.
    - value: It contains the decoding result. `len(value)` equals the
      batch size. `value[i]` is the decoding result for the i-th
      utterance in the given batch.

    Args:
      params:
        It is the return value of :func:`get_params`.
      s2t_generator:
        The speech-to-text generator built from the SeamlessM4T model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    dtype = torch.float16
    device = torch.device("cuda", 3)

    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device, dtype=dtype)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_len = supervisions["num_frames"]
    feature_len = feature_len.to(device, dtype=dtype)

    text_output = s2t_generator.generate_ex(feature, feature_len)
    # sentences = text_output.sentences
    # hyps = [sentence.bytes().decode("utf-8").split() for sentence in sentences]

    token_ids = text_output.generator_output.results
    hyps_ids = [sentence[0].seq.cpu().tolist() for sentence in token_ids]
    hyps = [params.tokenizer.decode(hyps_id).split() for hyps_id in hyps_ids]

    key = "beam-search"

    return {key: hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    s2t_generator: SequenceToTextGenerator,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode the dataset.

    Args:
      dl:
        The PyTorch dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      s2t_generator:
        The speech-to-text generator built from the SeamlessM4T model.
    Returns:
      Return a dict whose key is the decoding setting (here `beam-search`)
      and whose value is a list of tuples. Each tuple contains three
      elements: the cut ID, the reference transcript, and the predicted
      result.
    """
    results = []

    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(params=params, s2t_generator=s2t_generator, batch=batch)

        for lm_scale, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, hyp_words))

            results[lm_scale].extend(this_batch)

        num_cuts += len(batch["supervisions"]["text"])

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    enable_log = True
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        # We compute CER for the Aishell dataset.
        results_char = []
        for res in results:
            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
            )
            test_set_wers[key] = wer

        if enable_log:
            logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tCER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    AishellAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.tokenizer = CharTokenizer('./seamlessm4t/tokens.txt')
    params.update(vars(args))
    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 3)

    logging.info(f"device: {device}")
    dtype = torch.float16

    model_name_or_card = "seamlessM4T_medium"
    # model_name_or_card = "seamlessM4T_large"
    model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
    del model.t2u_model
    del model.text_encoder
    del model.text_encoder_frontend
    model.text_decoder_frontend.embed = nn.Embedding(
        num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024, padding_idx=0
    )
    # model.text_decoder_frontend.embed = Embedding(
    #     num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024, pad_idx=0, scaled=True
    # )
    model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size, bias=False)
    # model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size)
    if params.epoch > 0:
        if params.avg > 1:
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start, filename_end=filename_end, device=device
                )
            )
        else:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
            model.to(device)
    model.eval()
    model.half()
    # for param in model.parameters():
    #     if param.dtype == torch.float16:
    #         pass
    #     else:
    #         param.data = param.data.to(torch.float16)
    #     print(param)
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    text_tokenizer = load_unity_text_tokenizer(model_name_or_card)

    text_max_len_a = 1
    text_max_len_b = 200
    target_lang = "cmn"

    text_opts = SequenceGeneratorOptions(
        beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
    )

    s2t_model = UnitYX2TModel(
        encoder_frontend=model.speech_encoder_frontend,
        encoder=model.speech_encoder,
        decoder_frontend=model.text_decoder_frontend,
        decoder=model.text_decoder,
        final_proj=model.final_proj,
        pad_idx=model.pad_idx,
    )
    s2t_generator = SequenceToTextGenerator(
        s2t_model, text_tokenizer, target_lang, text_opts
    )
    # We need cut IDs to display recognition results.
    args.return_cuts = True
    aishell = AishellAsrDataModule(args)
    test_cuts = aishell.test_cuts()
    test_dl = aishell.test_dataloaders(test_cuts)

    test_sets = ["test"]
    test_dls = [test_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(dl=test_dl, params=params, s2t_generator=s2t_generator)

        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

    logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
@@ -1 +0,0 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py
@@ -1,133 +0,0 @@
import torch
import torch.nn as nn
from fairseq2.nn.embedding import Embedding
from seamless_communication.models.inference import Translator
from seamless_communication.models.unity import UnitTokenizer, UnitYModel, load_unity_model, load_unity_text_tokenizer, load_unity_unit_tokenizer
from fairseq2.generation import Seq2SeqGenerator, SequenceGeneratorOptions, SequenceGeneratorOutput, SequenceToTextGenerator, SequenceToTextOutput
from seamless_communication.models.unity.model import UnitYModel, UnitYX2TModel

import torchaudio
import torchaudio.compliance.kaldi as ta_kaldi

audio_file = "/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_en/wav/1089-134686-0001.wav"
src_lang = "cmn"

audio_file = "/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_zh/wav/long.wav"
src_lang = "eng"
target_lang = "cmn"

audio_input = torchaudio.load(audio_file)[0]
feature = ta_kaldi.fbank(audio_input, num_mel_bins=80)
# feature shape is (T, F); convert it to (B, T, F); source_seq_lens tracks T
source_seqs = feature.unsqueeze(0)
source_seq_lens = torch.tensor([feature.shape[0]])

# Initialize a Translator object with a multitask model, vocoder on the GPU.

# translator = Translator("seamlessM4T_medium", vocoder_name_or_card="vocoder_36langs", device=torch.device("cuda:2"), dtype=torch.float16)

# transcribed_text, _, _ = translator.predict(audio_file, "asr", src_lang)

# print(transcribed_text)


model_name_or_card = "seamlessM4T_medium"
device = torch.device("cuda:3")

# Cast source_seq_lens, source_seqs to device, dtype to torch.float16.
source_seq_lens = source_seq_lens.to(device=device, dtype=torch.float16)
source_seqs = source_seqs.to(device=device, dtype=torch.float16)


dtype = torch.float16
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
model.eval()
model.text_decoder_frontend.embed = Embedding(num_embeddings=6257, embedding_dim=1024, pad_idx=0, scaled=True)
model.final_proj = nn.Linear(1024, 6257)
model.half()
print(model.text_decoder_frontend.embed, model.text_encoder_frontend.embed.weight.dtype, type(model.text_encoder_frontend.embed), type(model.text_encoder_frontend.embed.weight))
print(model.final_proj, model.final_proj.weight.dtype, type(model.final_proj), type(model.final_proj.weight))
# input()
exit(0)
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
# print(text_tokenizer.model.eos_idx, text_tokenizer.model.pad_idx)
# text_tokenizer_encoder = text_tokenizer.create_encoder(lang=target_lang, mode="target")
# text_tokenizer_decoder = text_tokenizer.create_decoder()
# Print attributes of text_tokenizer_encoder:
# print(text_tokenizer.vocab_info)
# print(text_tokenizer_encoder("其中广州深圳甚至出现了多个日光盘"))
# print(text_tokenizer_decoder(torch.tensor([3,256200,137139,252603,250476,250590,1,84778,148897,249568,249352,249947,249050,250520,254508])))

# Store all vocab in a file:
# with open("vocab.txt", "w") as f:
#     for i in range(256206):
#         f.write(f"{i}: " + text_tokenizer_decoder(torch.tensor([i]))[0].bytes().decode("utf-8") + "\n")
#     f.close()
# exit(0)


# def decode(
#     self,
#     seqs: Tensor,
#     seq_lens: Optional[Tensor],
#     encoder_output: Tensor,
#     encoder_padding_mask: Optional[Tensor],
#     state_bag: Optional[IncrementalStateBag] = None,
# ) -> Tuple[Tensor, Optional[Tensor]]:
#     seqs, padding_mask = self.text_decoder_frontend(seqs, seq_lens, state_bag)
#
#     return self.text_decoder(  # type: ignore[no-any-return]
#         seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
#     )

# def decoding(model, feature):
#     seqs, padding_mask = model.speech_encoder_frontend(seqs, seq_lens)
#     speech_encoder(seqs, padding_mask)
#
#     decoder_output, decoder_padding_mask = self.decode(
#         batch.target_seqs,
#         batch.target_seq_lens,
#         encoder_output,
#         encoder_padding_mask,
#     )
#
#     text_logits = model.final_project(decoder_output, decoder_padding_mask)

text_max_len_a = 1
text_max_len_b = 200

text_opts = SequenceGeneratorOptions(
    beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
)

s2t_model = UnitYX2TModel(
    encoder_frontend=model.speech_encoder_frontend,
    encoder=model.speech_encoder,
    decoder_frontend=model.text_decoder_frontend,
    decoder=model.text_decoder,
    final_proj=model.final_proj,
    pad_idx=model.pad_idx,
)
s2t_generator = SequenceToTextGenerator(
    s2t_model, text_tokenizer, target_lang, text_opts
)

text_output = s2t_generator.generate_ex(source_seqs, source_seq_lens)
print(text_output.generator_output.results[0][0].seq.cpu().tolist())
# sentence = text_output.sentences[0]
# print(sentence, type(sentence))
# sentence = sentence.bytes().decode("utf-8")
File diff suppressed because it is too large
@ -1,694 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import math
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional, Tuple, Union, cast
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.nn.functional import log_softmax
|
|
||||||
|
|
||||||
from fairseq2.data import Collater, SequenceData, VocabularyInfo
|
|
||||||
from fairseq2.generation.beam_search import BeamSearch, StandardBeamSearch
|
|
||||||
from fairseq2.generation.logits_processor import LogitsProcessor
|
|
||||||
from fairseq2.models.encoder_decoder import Seq2SeqDecoder
|
|
||||||
from fairseq2.nn.incremental_state import IncrementalStateBag
|
|
||||||
from fairseq2.typing import Device
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SequenceGeneratorOptions:
|
|
||||||
"""Holds the options to pass to a sequence generator."""
|
|
||||||
|
|
||||||
beam_size: int = 5
|
|
||||||
"""The beam size."""
|
|
||||||
|
|
||||||
min_seq_len: int = 1
|
|
||||||
"""The minimum length of generated sequences (including prefix sequence)."""
|
|
||||||
|
|
||||||
soft_max_seq_len: Optional[Tuple[int, int]] = (1, 200)
|
|
||||||
"""The terms ``a`` and ``b`` of ``ax + b`` where ``x`` is the source
|
|
||||||
sequence length. The generated sequences (including prefix sequence) will
|
|
||||||
have the maximum length of ``min(hard_max_seq_len, ax + b)``. See also
|
|
||||||
``hard_max_seq_len``."""
|
|
||||||
|
|
||||||
hard_max_seq_len: int = 1024
|
|
||||||
"""The hard limit on maximum length of generated sequences."""
|
|
||||||
|
|
||||||
len_penalty: float = 1.0
|
|
||||||
"""The length penalty, where values less than 1.0 favor shorter, values
|
|
||||||
greater than 1.0 favor longer sequences."""
|
|
||||||
|
|
||||||
unk_penalty: float = 0.0
|
|
||||||
"""The unknown symbol penalty, where values less than 0 produce more UNKs,
|
|
||||||
values greater than 0 produce fewer UNKs."""
|
|
||||||
|
|
||||||
normalize_scores: bool = True
|
|
||||||
"""If ``True``, normalizes scores by the length of generated sequences."""
|
|
||||||
|
|
||||||
search: Optional[BeamSearch] = None
|
|
||||||
"""The beam search algorithm to use."""
|
|
||||||
|
|
||||||
logits_processor: Optional[LogitsProcessor] = None
|
|
||||||
"""Logits processor called before applying beam search step."""
|
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqGenerator:
|
|
||||||
"""Represents a sequence-to-sequence generator."""
|
|
||||||
|
|
||||||
decoder: Seq2SeqDecoder
|
|
||||||
opts: SequenceGeneratorOptions
|
|
||||||
beam_size: int
|
|
||||||
eos_idx: int
|
|
||||||
pad_idx: Optional[int]
|
|
||||||
unk_idx: Optional[int]
|
|
||||||
prefix_seq: Union[int, Tensor]
|
|
||||||
prefix_seq_len: int
|
|
||||||
search: BeamSearch
|
|
||||||
logits_processor: Optional[LogitsProcessor]
|
|
||||||
collater: Collater
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
decoder: Seq2SeqDecoder,
|
|
||||||
vocab_info: VocabularyInfo,
|
|
||||||
prefix_seq: Optional[Union[int, Tensor]],
|
|
||||||
opts: Optional[SequenceGeneratorOptions] = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
:param decoder:
|
|
||||||
The decoder to use.
|
|
||||||
:param vocab_info:
|
|
||||||
The vocabulary information to use.
|
|
||||||
:param prefix_seq:
|
|
||||||
The prefix sequence, typically one or more control symbols
|
|
||||||
indicating the beginning of a sequence. *Shape:* :math:`()` or
|
|
||||||
:math:`(S)`, where :math:`S` is the sequence length. If ``None``,
|
|
||||||
the EOS symbol will be used as prefix.
|
|
||||||
:param opts:
|
|
||||||
The generation options.
|
|
||||||
"""
|
|
||||||
self.decoder = decoder
|
|
||||||
|
|
||||||
self.opts = opts or SequenceGeneratorOptions()
|
|
||||||
|
|
||||||
# Set beam size.
|
|
||||||
if vocab_info.pad_idx is None:
|
|
||||||
self.beam_size = min(self.opts.beam_size, vocab_info.size)
|
|
||||||
else:
|
|
||||||
# -1 since we never select PAD.
|
|
||||||
self.beam_size = min(self.opts.beam_size, vocab_info.size - 1)
|
|
||||||
|
|
||||||
if vocab_info.eos_idx is None:
|
|
||||||
raise ValueError(
|
|
||||||
"`vocab_info` must have `eos_idx` set for sequence generation."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set vocab info.
|
|
||||||
self.eos_idx = 1
|
|
||||||
#self.eos_idx = vocab_info.eos_idx
|
|
||||||
self.unk_idx = 2
|
|
||||||
#self.unk_idx = vocab_info.unk_idx
|
|
||||||
self.pad_idx = 0
|
|
||||||
#self.pad_idx = vocab_info.pad_idx
|
|
||||||
|
|
||||||
# Set prefix sequence.
|
|
||||||
if 1:
|
|
||||||
#if prefix_seq is None:
|
|
||||||
# If `None`, we follow fairseq's convention, and use EOS as the
|
|
||||||
# prefix.
|
|
||||||
self.prefix_seq, self.prefix_seq_len = self.eos_idx, 1
|
|
||||||
else:
|
|
||||||
self.prefix_seq = prefix_seq
|
|
||||||
|
|
||||||
if isinstance(prefix_seq, Tensor):
|
|
||||||
num_dim = prefix_seq.dim()
|
|
||||||
|
|
||||||
if num_dim >= 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"`prefix_seq` must be a scalar or a 1-dimensional tensor, but is {num_dim}-dimensional instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.prefix_seq_len = 1 if num_dim == 0 else prefix_seq.size(0)
|
|
||||||
else:
|
|
||||||
self.prefix_seq_len = 1
|
|
||||||
|
|
||||||
# Set beam search.
|
|
||||||
self.search = self.opts.search or StandardBeamSearch()
|
|
||||||
self.logits_processor = self.opts.logits_processor
|
|
||||||
|
|
||||||
if vocab_info.pad_idx is None:
|
|
||||||
self.collater = Collater()
|
|
||||||
else:
|
|
||||||
self.collater = Collater(self.pad_idx, pad_to_multiple=2)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
encoder_output: Tensor,
|
|
||||||
encoder_padding_mask: Optional[Tensor],
|
|
||||||
source_seq_len: Optional[int] = None,
|
|
||||||
) -> "SequenceGeneratorOutput":
|
|
||||||
opts = self.opts
|
|
||||||
|
|
||||||
num_searches = encoder_output.size(0)
|
|
||||||
|
|
||||||
beam_size = opts.beam_size
|
|
||||||
|
|
||||||
max_seq_len = self._determine_max_seq_len(source_seq_len)
|
|
||||||
|
|
||||||
device = encoder_output.device
|
|
||||||
|
|
||||||
encoder_output, encoder_padding_mask = self._fan_out_encoder_output(
|
|
||||||
encoder_output, encoder_padding_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
# Each element contains the id of the search corresponding to a single
|
|
||||||
# source sequence and its hypotheses.
|
|
||||||
active_searches: List[Tuple[int, List[Hypothesis]]] = [
|
|
||||||
(search_idx, []) for search_idx in range(num_searches)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Once a source sequence has `beam_size` hypotheses, its search is moved
|
|
||||||
# from `active_searches` to `finished_searches`.
|
|
||||||
finished_searches: List[List[Hypothesis]] = [[] for i in range(num_searches)]
|
|
||||||
|
|
||||||
num_remaining_searches = num_searches
|
|
||||||
|
|
||||||
# Initialize buffers.
|
|
||||||
# (N x B, S)
|
|
||||||
seqs = torch.zeros(
|
|
||||||
(num_searches * beam_size, max_seq_len), device=device, dtype=torch.int64
|
|
||||||
)
|
|
||||||
|
|
||||||
# (N x B, S)
|
|
||||||
scores = torch.zeros(
|
|
||||||
(num_searches * beam_size, max_seq_len), device=device, dtype=torch.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
# A list that indicates beams that should be ignored in the next step.
|
|
||||||
ignored_beam_mask = torch.full(
|
|
||||||
(num_searches, beam_size), False, device=device, dtype=torch.bool
|
|
||||||
)
|
|
||||||
|
|
||||||
# An offset array for converting between batch-wide and search-local
|
|
||||||
# beam indices.
|
|
||||||
# (B)
|
|
||||||
search_offsets = torch.arange(num_searches, device=device) * beam_size
|
|
||||||
|
|
||||||
# (B) -> (B, 1)
|
|
||||||
search_offsets.unsqueeze_(-1)
|
|
||||||
|
|
||||||
cand_offsets = torch.arange(2 * beam_size, device=device)
|
|
||||||
|
|
||||||
state_bag = IncrementalStateBag()
|
|
||||||
|
|
||||||
# At this point, the state is fully initialized, kick off the search.
|
|
||||||
self._bootstrap_seqs_and_scores(
|
|
||||||
seqs, scores, encoder_output, encoder_padding_mask, state_bag
|
|
||||||
)
|
|
||||||
|
|
||||||
start_step = self.prefix_seq_len - 1
|
|
||||||
|
|
||||||
# Holds the indices of beams (a beam can occur more than once) that we
|
|
||||||
# should continue with in the next step.
|
|
||||||
beam_indices: Optional[Tensor] = None
|
|
||||||
|
|
||||||
# Holds the indices of searches that we should continue with in the next
|
|
||||||
# step. If not `None`, it means we finalized one or more searches in the
|
|
||||||
# last step.
|
|
||||||
search_indices: Optional[Tensor] = None
|
|
||||||
|
|
||||||
for step_nr in range(start_step, max_seq_len - 1):
|
|
||||||
if beam_indices is not None:
|
|
||||||
# If not `None`, it means in the last step we finalized one or
|
|
||||||
# more searches. We should ensure that we adjust `beam_indices`
|
|
||||||
# before reordering `decoder`'s incremental state.
|
|
||||||
if search_indices is not None:
|
|
||||||
num_searches = search_indices.numel()
|
|
||||||
|
|
||||||
# (N)
|
|
||||||
delta = search_indices - torch.arange(num_searches, device=device)
|
|
||||||
|
|
||||||
# (N) -> (N, 1)
|
|
||||||
delta.unsqueeze_(-1)
|
|
||||||
|
|
||||||
# Adjust indices to take into account removed searches.
|
|
||||||
beam_indices.view(num_searches, beam_size).add_(delta * beam_size)
|
|
||||||
|
|
||||||
state_bag.reorder(beam_indices)
|
|
||||||
|
|
||||||
decoder_output, decoder_padding_mask = self.decoder.decode(
|
|
||||||
seqs[:, step_nr : step_nr + 1],
|
|
||||||
None, # We never generate PAD.
|
|
||||||
encoder_output,
|
|
||||||
encoder_padding_mask,
|
|
||||||
state_bag,
|
|
||||||
)
|
|
||||||
|
|
||||||
state_bag.increment_step()
|
|
||||||
|
|
||||||
model_output = self.decoder.project(decoder_output, decoder_padding_mask)
|
|
||||||
|
|
||||||
# lprobs: (1, V)
|
|
||||||
# model_output: (N, 1, V)
|
|
||||||
lprobs = log_softmax(model_output.logits, dim=-1, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Do not allow EOS before reaching the minimum sequence length.
|
|
||||||
if step_nr < self.opts.min_seq_len:
|
|
||||||
lprobs[:, :, self.eos_idx] = -torch.inf
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
# If we have reached the maximum length, force the last step to be
|
|
||||||
# EOS.
|
|
||||||
if step_nr == max_seq_len - 2:
|
|
||||||
lprobs[:, :, : self.eos_idx] = -torch.inf
|
|
||||||
lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
# Never allow PAD.
|
|
||||||
if self.pad_idx is not None:
|
|
||||||
lprobs[:, :, self.pad_idx] = -torch.inf
|
|
||||||
|
|
||||||
# Apply UNK penalty.
|
|
||||||
if self.unk_idx is not None:
|
|
||||||
lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
|
|
||||||
|
|
||||||
# update scores in place using logits_processor
|
|
||||||
if self.logits_processor is not None:
|
|
||||||
self.logits_processor(
|
|
||||||
seqs.view(num_searches, beam_size, -1)[:, :, : step_nr + 1],
|
|
||||||
lprobs.view(num_searches, beam_size, -1),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine candidates for the next step.
|
|
||||||
# (N, 2 x B)
|
|
||||||
cand_scores, cand_indices, cand_beam_indices = self.search.step(
|
|
||||||
step_nr,
|
|
||||||
step_nr == start_step,
|
|
||||||
lprobs.view(num_searches, beam_size, -1),
|
|
||||||
scores.view(num_searches, beam_size, -1)[:, :, : step_nr + 1],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert search-local beam indices to batch-wide beam indices.
|
|
||||||
# (N, 2 x B) + (N) -> (N, 2 x B)
|
|
||||||
            global_cand_beam_indices = cand_beam_indices + search_offsets

            # Finalize beams that reached the minimum length and that end with
            # an EOS.
            # (N, 2 x B)
            eos_mask = (cand_indices == self.eos_idx) & (cand_scores != -math.inf)

            # Do not attempt to finalize beams that should be ignored.
            eos_mask[:, :beam_size][ignored_beam_mask] = False

            # Only consider EOS when it's among the top `beam_size` indices. Now
            # we know what beam(s) to finalize.
            # (N, B)
            eos_beam_indices = torch.masked_select(
                global_cand_beam_indices[:, :beam_size], mask=eos_mask[:, :beam_size]
            )

            if eos_beam_indices.numel() > 0:
                # Select the scores of the finalized beams.
                # (N, B)
                eos_scores = torch.masked_select(
                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
                )

                newly_finished_searches = self._finalize_hypothesis(
                    step_nr,
                    eos_beam_indices,
                    eos_scores,
                    seqs,
                    scores,
                    active_searches,
                    finished_searches,
                )

                num_remaining_searches -= len(newly_finished_searches)

                if num_remaining_searches == 0:
                    break
            else:
                newly_finished_searches = None

            # Remove finished searches (ones for which `beam_size` finalized
            # beams have been generated) from the batch.
            if newly_finished_searches:
                new_num_searches = num_searches - len(newly_finished_searches)

                # Construct `search_indices` which holds indices of searches
                # to keep for the next step.
                search_mask = torch.full((num_searches,), True, device=device)

                search_mask[newly_finished_searches] = False

                search_indices = torch.arange(num_searches, device=device)

                search_indices = search_indices.masked_select(search_mask)

                # fmt: off
                # Filter out removed batches from state variables.
                # (N, B) -> (N - F, B)
                ignored_beam_mask = ignored_beam_mask[search_indices]

                # (N, 2 x B) -> (N - F, 2 x B)
                cand_scores       = cand_scores      [search_indices]
                cand_indices      = cand_indices     [search_indices]
                cand_beam_indices = cand_beam_indices[search_indices]

                # (N) -> (N - F)
                search_offsets.resize_(new_num_searches, 1)

                # (N - F, 2 x B) + (N - F) -> (N - F, 2 x B)
                global_cand_beam_indices = cand_beam_indices + search_offsets

                # (N, 2 x B) -> (N - F, 2 x B)
                eos_mask = eos_mask[search_indices]

                # (N x B, S) -> (N, B, S)
                seqs   = seqs  .view(num_searches, -1)
                scores = scores.view(num_searches, -1)

                # (N, B, S + 1) -> ((N - F) x B, S)
                seqs   = seqs  [search_indices].view(new_num_searches * beam_size, -1)
                scores = scores[search_indices].view(new_num_searches * beam_size, -1)

                # (N x B, S_enc, M) -> (N, B, S_enc, M)
                encoder_output = encoder_output.unflatten(0, (num_searches, -1))

                # (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
                encoder_output = encoder_output[search_indices].flatten(0, 1)

                if encoder_padding_mask is not None:
                    # (N x B, S_enc, M) -> (N, B, S_enc, M)
                    padding_mask = encoder_padding_mask.unflatten(0, (num_searches, -1))

                    # (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
                    encoder_padding_mask = padding_mask[search_indices].flatten(0, 1)
                # fmt: on

                num_searches = new_num_searches
            else:
                search_indices = None

            eos_mask[:, :beam_size][ignored_beam_mask] = True

            # Set `beam_weights` so that values greater than or equal to 2 x
            # `beam_size` indicate finished beams (i.e. end with EOS) and values
            # less than 2 x `beam_size` indicate active beams.
            # (N, 2 x B)
            beam_weights = cand_offsets + (eos_mask * (2 * beam_size))
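            # Illustrative note (not part of the original file): with, say,
            # beam_size B = 2, `cand_offsets` is [0, 1, 2, 3]; a candidate whose
            # `eos_mask` entry is True gets 2 * B = 4 added, e.g. weights
            # [0, 5, 2, 3]. The `topk` with the smallest weights below therefore
            # prefers non-EOS candidates, so already-finalized beams are not
            # selected to continue the search.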
            # Get the top `beam_size` active beams, which are the beams with the
            # smallest weights in `active_beam_weights`.
            # (N, B)
            active_beam_weights, active_beams = torch.topk(
                beam_weights, k=beam_size, dim=1, largest=False
            )

            # Update to ignore finalized beams in the next step.
            # (N, B)
            ignored_beam_mask = active_beam_weights >= 2 * beam_size

            # We should always have at least one active beam in each search.
            assert (~ignored_beam_mask).any(dim=1).all()

            # Denotes which beams are continued for each new hypothesis (a beam
            # can be selected more than once).
            # (N, B)
            beam_indices = torch.gather(
                global_cand_beam_indices, dim=1, index=active_beams
            )

            # (N, B) -> (N x B)
            beam_indices = beam_indices.view(-1)

            # fmt: off
            # Reorder beams in the `seq` and `score` buffers. The same beam can
            # be selected more than once.
            if step_nr > start_step:
                seqs  [:, : step_nr + 1] = torch.index_select(
                    seqs  [:, : step_nr + 1], dim=0, index=beam_indices
                )
                scores[:, : step_nr + 1] = torch.index_select(
                    scores[:, : step_nr + 1], dim=0, index=beam_indices
                )

            # (N x B, S) -> (N, B, S)
            seqs_view   = seqs  .view(num_searches, beam_size, -1)
            scores_view = scores.view(num_searches, beam_size, -1)

            seqs_view  [:, :, step_nr + 1] = torch.gather(cand_indices, dim=1, index=active_beams)
            scores_view[:, :, step_nr + 1] = torch.gather(cand_scores,  dim=1, index=active_beams)
            # fmt: on

        # Ensure that hypotheses are sorted by their scores before returning.
        for batch in finished_searches:
            batch.sort(key=lambda b: b.score, reverse=True)  # type: ignore[arg-type, return-value]

        return SequenceGeneratorOutput(
            results=finished_searches, device=device, collater=self.collater
        )

    def _determine_max_seq_len(self, source_seq_len: Optional[int]) -> int:
        opts = self.opts

        if source_seq_len is None or opts.soft_max_seq_len is None:
            max_seq_len = opts.hard_max_seq_len
        else:
            at, bt = opts.soft_max_seq_len

            max_seq_len = min(opts.hard_max_seq_len, int(at * source_seq_len + bt))

        if opts.min_seq_len > max_seq_len:
            raise ValueError(
                f"The effective maximum sequence length must be greater than or equal to `min_seq_len` ({opts.min_seq_len}), but is {max_seq_len} instead. Adjust your soft and hard maximum sequence length limits."
            )

        if self.prefix_seq_len >= max_seq_len:
            raise ValueError(
                f"The effective maximum sequence length must be greater than `prefix_seq_len` ({self.prefix_seq_len}), but is {max_seq_len} instead."
            )

        return max_seq_len
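    # Illustrative note (not part of the original file): the soft cap is an
    # affine function of the source length. Assuming, for example,
    # soft_max_seq_len = (1, 50), hard_max_seq_len = 1024 and a source of
    # 200 frames, the effective limit is min(1024, int(1 * 200 + 50)) = 250.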
    def _fan_out_encoder_output(
        self, encoder_output: Tensor, encoder_padding_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Optional[Tensor]]:
        num_searches = encoder_output.size(0)  # i.e. batch size

        # Fan out `encoder_output` to `num_searches` x `beam_size`.
        # (N)
        fan_out_indices = torch.arange(num_searches, device=encoder_output.device)

        # (N) -> (N x B)
        fan_out_indices = fan_out_indices.repeat_interleave(self.beam_size)

        # (N, S_enc, M) -> (N x B, S_enc, M)
        encoder_output = encoder_output.index_select(dim=0, index=fan_out_indices)

        # (N, S_enc, M) -> (N x B, S_enc, M)
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.index_select(
                dim=0, index=fan_out_indices
            )

        return encoder_output, encoder_padding_mask

    def _bootstrap_seqs_and_scores(
        self,
        seqs: Tensor,
        scores: Tensor,
        encoder_output: Tensor,
        encoder_padding_mask: Optional[Tensor],
        state_bag: IncrementalStateBag,
    ) -> None:
        assert self.prefix_seq_len > 0

        seqs[:, : self.prefix_seq_len] = self.prefix_seq

        if self.prefix_seq_len == 1:
            return

        assert isinstance(self.prefix_seq, Tensor)

        # We have to bootstrap the model with the already fanned-out encoder
        # output to correctly initialize its incremental state. This causes some
        # redundancy as we have to expand `decoder_input` to match the shape of
        # `encoder_output`.
        # (S_pfx) -> (N x B, S_pfx - 1)
        decoder_input = self.prefix_seq[:-1].expand(encoder_output.size(0), -1)

        # Bootstrap the model state with prefix sequence.
        decoder_output, decoder_padding_mask = self.decoder.decode(
            decoder_input,
            None,
            encoder_output,
            encoder_padding_mask,
            state_bag,
        )

        state_bag.increment_step(self.prefix_seq_len - 1)

        model_output = self.decoder.project(decoder_output, decoder_padding_mask)

        # lprobs:       (S_pfx - 1, V)
        # model_output: (N, S_pfx - 1, V) -> (S_pfx - 1, V)
        lprobs = log_softmax(model_output.logits[0], dim=-1, dtype=torch.float32)

        # Fetch scores of next steps.
        # (S_pfx - 1, 1)
        prefix_scores = torch.take_along_dim(
            lprobs, indices=self.prefix_seq[1:].unsqueeze(1), dim=-1
        )

        # (S_pfx - 1, 1) -> (S_pfx - 1)
        prefix_scores.squeeze_(1).cumsum_(dim=0)

        # First step (e.g. EOS)'s score is always 0.
        scores[:, 1 : self.prefix_seq_len] = prefix_scores

    def _finalize_hypothesis(
        self,
        step_nr: int,
        eos_beam_indices: Tensor,
        eos_scores: Tensor,
        seqs: Tensor,
        scores: Tensor,
        active_searches: List[Tuple[int, List["Hypothesis"]]],
        finished_searches: List[List["Hypothesis"]],
    ) -> List[int]:
        # fmt: off
        finalized_seqs   = seqs  .index_select(dim=0, index=eos_beam_indices)
        finalized_scores = scores.index_select(dim=0, index=eos_beam_indices)

        finalized_seqs   = finalized_seqs  [:, : step_nr + 2]
        finalized_scores = finalized_scores[:, : step_nr + 2]

        # Finalize beams.
        finalized_seqs  [:, -1] = self.eos_idx
        finalized_scores[:, -1] = eos_scores
        # fmt: on

        # Convert from cumulative to per-step scores.
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]
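        # Illustrative note (not part of the original file): e.g. cumulative
        # scores [-0.5, -1.2, -2.0] become per-step scores [-0.5, -0.7, -0.8]
        # after this in-place subtraction, i.e. each entry is the
        # log-probability contributed by that step alone.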

        # Skip first EOS since it is always 0 and skews normalization.
        if self.opts.normalize_scores:
            eos_scores /= (step_nr + 1) ** self.opts.len_penalty

        # Holds the ids of finished searches.
        newly_finished: List[int] = []

        active_search_indices = (eos_beam_indices // self.beam_size).tolist()

        for beam_idx, search_idx in enumerate(active_search_indices):
            search_id, hypotheses = active_searches[search_idx]

            # We might have more than one beam finalized in one step that would
            # potentially exceed `beam_size` hypotheses.
            if len(hypotheses) == self.beam_size:
                continue

            hypotheses.append(
                Hypothesis(
                    seq=finalized_seqs[beam_idx],
                    score=eos_scores[beam_idx],
                    step_scores=finalized_scores[beam_idx],
                )
            )

            if len(hypotheses) == self.beam_size:
                # We have `beam_size` hypotheses for this particular search, so
                # we finish it now.
                newly_finished.append(search_idx)

                finished_searches[search_id] = hypotheses

        newly_finished.sort()

        # Remove finished searches from the active list.
        for idx in reversed(newly_finished):
            del active_searches[idx]

        return newly_finished

@dataclass
class SequenceGeneratorOutput:
    """Holds the output of a sequence generator."""

    results: List[List["Hypothesis"]]
    """The list of hypotheses generated per search, ordered by score."""

    device: Device
    """The device on which generated sequences reside."""

    collater: Optional[Collater] = None
    """The collater to use in :meth:`collate`."""

    def collate(
        self, hypo_idx: int = 0, skip_batch: bool = False
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Collate the generated sequences at index ``hypo_idx`` in each search
        result into a single tensor.

        :param hypo_idx:
            The index of the hypothesis to extract from each search result.
        :param skip_batch:
            If ``True``, a search result that has no hypothesis at index
            ``hypo_idx`` will be skipped instead of raising an error.

        :returns:
            - The collated sequences. *Shape:* :math:`(N,S)`, where :math:`N` is
              the number of search results and :math:`S` is the sequence length.
            - An array where each element represents the length of the sequence at
              the same index in the first returned value. *Shape:* :math:`(N)`,
              where :math:`N` is the number of search results.
        """
        if self.collater is None:
            raise RuntimeError("The output has no associated `Collater` instance.")

        if not self.results and not skip_batch:
            raise ValueError("The output must contain at least one search result.")

        seqs = []

        for search_idx, result in enumerate(self.results):
            if hypo_idx >= len(result):
                if not skip_batch:
                    raise ValueError(
                        f"Each search result must have at least {hypo_idx + 1} hypotheses, but search {search_idx} has only {len(result)}."
                    )

                continue

            seqs.append(result[hypo_idx].seq)

        if not seqs:
            # Return a zero-dimensional (not scalar!) tensor.
            return torch.empty((0,), device=self.device, dtype=torch.int64), None

        output = cast(SequenceData, self.collater(seqs))

        return output["seqs"], output["seq_lens"] if output["is_ragged"] else None


@dataclass
class Hypothesis:
    """Represents a hypothesis produced by a sequence generator."""

    seq: Tensor
    """The generated sequence."""

    score: Tensor
    """The score of the hypothesis."""

    step_scores: Tensor
    """The score of each individual sequence step."""
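A minimal usage sketch (not part of the diff): assuming `output` is a
`SequenceGeneratorOutput` produced by such a generator, the best hypothesis of
each search can be read off as follows; hypotheses are already sorted by score,
so index 0 is the best one per search.

    for search_results in output.results:
        best = search_results[0]
        print(float(best.score), best.seq.tolist())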
@@ -1,6 +0,0 @@
#k2
kaldialign
lhotse
sentencepiece
tensorboard
fairseq2
@@ -1,43 +0,0 @@

#import sentencepiece as spm

class CharTokenizer(object):
    def __init__(self, tokenizer_file):
        self.id2symbol = {}
        self.symbol2id = {}
        with open(tokenizer_file, 'r') as f:
            for line in f:
                line = line.strip()
                if line:
                    symbol, id = line.split()
                    id = int(id)
                    self.id2symbol[id] = symbol
                    self.symbol2id[symbol] = id
        self.vocab_size = len(self.id2symbol)

    def encode(self, text):
        # if a symbol is not in self.symbol2id, fall back to <unk>'s id (2)
        return [self.symbol2id.get(symbol, 2) for symbol in text]

    def decode(self, ids):
        return ''.join([self.id2symbol[id] for id in ids])

if __name__ == '__main__':
    # config_file = './config.yaml'
    # config = read_yaml(config_file)
    # converter = TokenIDConverter(config['token_list'])
    # ids = converter.tokens2ids(['<s>', '你', '好', '吗', '</s>', 'microsoft', 'world'])
    # print(ids)
    # print(converter.ids2tokens(ids))

    tokenizer = CharTokenizer('./tokens.txt')
    ids = tokenizer.encode('今天 天气不错')
    print(ids)
    print(tokenizer.decode(ids + [1]))
    # sp = spm.SentencePieceProcessor()
    # sp.Load('../../../librispeech/ASR/k2fsa-zipformer-chinese-english-mixed/data/lang_char_bpe/bpe.model')
    # texts = ['MICROSOFT WORLD']
    # y = sp.encode(texts, out_type=int)
    # x = sp.decode(y)
    # print(y, x)
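A small sketch (not part of the diff) of the tokens file this class assumes: one
`symbol id` pair per line, whitespace separated. The `encode` fallback suggests
id 2 is `<unk>`; the other reserved ids below are an assumption for illustration
only. Writing a toy file and round-tripping:

    with open('toy_tokens.txt', 'w') as f:
        f.write('<blk> 0\n<sos/eos> 1\n<unk> 2\n今 3\n天 4\n')

    tok = CharTokenizer('toy_tokens.txt')
    assert tok.decode(tok.encode('今天')) == '今天'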
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -30,7 +30,7 @@ from lhotse.dataset import (
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
-    SimpleCutSampler,
+    SingleCutSampler,
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
@@ -176,13 +176,13 @@ class AishellAsrDataModule:
         group.add_argument(
             "--enable-musan",
             type=str2bool,
-            default=False,
+            default=True,
             help="When enabled, select noise from MUSAN and mix it"
             "with training dataset. ",
         )

     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, rank = None, world_size = None
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
     ) -> DataLoader:
         """
         Args:
@@ -192,13 +192,13 @@ class AishellAsrDataModule:
           The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")

         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -276,12 +276,10 @@ class AishellAsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=self.args.drop_last,
-                world_size=world_size,
-                rank=rank,
             )
         else:
-            logging.info("Using SimpleCutSampler.")
-            train_sampler = SimpleCutSampler(
+            logging.info("Using SingleCutSampler.")
+            train_sampler = SingleCutSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
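Side note (not part of the diff): `SingleCutSampler` and `SimpleCutSampler` refer
to the same lhotse sampler; the class was renamed in newer lhotse releases, so
which name imports cleanly depends on the installed version. A hedged
compatibility shim, if one wanted to support both:

    try:
        from lhotse.dataset import SimpleCutSampler
    except ImportError:  # older lhotse only ships the old name
        from lhotse.dataset import SingleCutSampler as SimpleCutSampler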
@@ -302,7 +300,7 @@ class AishellAsrDataModule:

         return train_dl

-    def valid_dataloaders(self, cuts_valid: CutSet, rank = None, world_size = None) -> DataLoader:
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -327,8 +325,6 @@ class AishellAsrDataModule:
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
-            rank=rank,
-            world_size=world_size,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
@@ -473,10 +473,11 @@ def main():
     aishell = AishellAsrDataModule(args)
     test_cuts = aishell.test_cuts()
     test_dl = aishell.test_dataloaders(test_cuts)
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
-    test_sets = ["test"]
-    test_dls = [test_dl]
+    #test_sets = ["test"]
+    #test_dls = [test_dl]
+    test_sets = ["valid"]
+    test_dls = [valid_dl]
     for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
@@ -27,7 +27,7 @@
     "params": {
       "warmup_min_lr": 0,
       "warmup_max_lr": 1e-5,
-      "warmup_num_steps": 1000
+      "warmup_num_steps": 100
     }
   },
   "gradient_accumulation_steps": 1,
@@ -126,7 +126,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=5,
+        default=10,
         help="Number of epochs to train.",
     )
