Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-03 22:24:19 +00:00)

Commit e883bb60d4: remove seamless for next PR
Parent: ac53222054
@@ -1,9 +0,0 @@

#export CUDA_VISIBLE_DEVICES="2,3"
#pip install -r seamlessm4t/requirements.txt
#pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
python3 seamlessm4t/decode.py --epoch 5 --exp-dir seamlessm4t/exp
python3 seamlessm4t/decode.py --epoch 5 --avg 2 --exp-dir seamlessm4t/exp
@@ -1,8 +0,0 @@

#export CUDA_VISIBLE_DEVICES="1"
#pip install -r whisper/requirements.txt
#pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
#export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall/

python3 whisper/decode.py --exp-dir whisper/exp --max-duration 100
@@ -1,8 +0,0 @@

#export CUDA_VISIBLE_DEVICES="2,3"
pip install -r seamlessm4t/requirements.txt
pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
torchrun --nproc-per-node 8 seamlessm4t/train.py --use-fp16 1 --max-duration 300 --base-lr 1e-5 --exp-dir seamlessm4t/exp_new_vocab --start-epoch 1
@@ -1,9 +0,0 @@


pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
pip install -r whisper/requirements.txt
export PYTHONPATH=$PYTHONPATH:/workspace/icefall
#export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
#export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall

torchrun --nproc-per-node 8 whisper/train.py --use-fp16 1 --max-duration 20 --base-lr 1e-5 --exp-dir whisper/exp_medimum --start-epoch 1
@@ -1 +0,0 @@
../tdnn_lstm_ctc/asr_datamodule.py
@@ -1,415 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
|
||||
# Fangjun Kuang,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
#from conformer import Conformer
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
)
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from seamless_communication.models.unity import (
|
||||
UnitYModel,
|
||||
load_unity_model,
|
||||
load_unity_text_tokenizer,
|
||||
)
|
||||
from fairseq2.generation import (
|
||||
SequenceGeneratorOptions,
|
||||
SequenceToTextGenerator,
|
||||
)
|
||||
from seamless_communication.models.unity.model import UnitYX2TModel
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="beam-search",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (0) ctc-decoding. Use CTC decoding. It maps the tokens ids to
|
||||
tokens using token symbol tabel directly.
|
||||
- (1) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (2) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (3) attention-decoder. Extract n paths from the lattice,
|
||||
the path with the highest score is the decoding result.
|
||||
- (4) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="seamlessm4t/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
# parameters for conformer
|
||||
"subsampling_factor": 4,
|
||||
"feature_dim": 80,
|
||||
"nhead": 4,
|
||||
"attention_dim": 512,
|
||||
"num_encoder_layers": 12,
|
||||
"num_decoder_layers": 6,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"search_beam": 20,
|
||||
"output_beam": 7,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
s2t_generator: SequenceToTextGenerator,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if decoding method is 1best, the key is the string `no_rescore`.
|
||||
If attention rescoring is used, the key is the string
|
||||
`ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
|
||||
value of `lm_scale` and `attention_scale`. An example key is
|
||||
`ngram_lm_scale_0.7_attention_scale_0.5`
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
|
||||
- params.method is "1best", it uses 1best decoding without LM rescoring.
|
||||
- params.method is "nbest", it uses nbest decoding without LM rescoring.
|
||||
- params.method is "attention-decoder", it uses attention rescoring.
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
lexicon:
|
||||
It contains the token symbol table and the word symbol table.
|
||||
sos_id:
|
||||
The token ID of the SOS.
|
||||
eos_id:
|
||||
The token ID of the EOS.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda", 3)
|
||||
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device, dtype=dtype)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_len = supervisions["num_frames"]
|
||||
feature_len = feature_len.to(device, dtype=dtype)
|
||||
|
||||
text_output = s2t_generator.generate_ex(feature, feature_len)
|
||||
sentences = text_output.sentences
|
||||
hyps = [sentence.bytes().decode("utf-8").split() for sentence in sentences]
|
||||
key = "beam-search"
|
||||
|
||||
return {key: hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
s2t_generator: SequenceToTextGenerator,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
  dl:
    PyTorch's dataloader containing the dataset to decode.
  params:
    It is returned by :func:`get_params`.
  s2t_generator:
    The SeamlessM4T sequence-to-text generator used for beam search.
Returns:
  Return a dict whose key is the decoding setting (here always
  "beam-search"). Its value is a list of tuples. Each tuple contains
  three elements: the cut ID, the reference transcript, and the
  predicted result.
"""
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
s2t_generator=s2t_generator,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
num_cuts += len(batch["supervisions"]["text"])
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
|
||||
enable_log = True
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AishellAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 3)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
dtype = torch.float16
|
||||
|
||||
model_name_or_card = "seamlessM4T_medium"
|
||||
#model_name_or_card = "seamlessM4T_large"
|
||||
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
|
||||
del model.t2u_model
|
||||
del model.text_encoder
|
||||
del model.text_encoder_frontend
|
||||
if params.epoch > 0:
|
||||
if params.avg > 1:
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
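# Illustrative note (added comment, not part of the original file): with the
# values used in the accompanying decode script (--epoch 5 --avg 2),
# start = 5 - 2 = 3, so the averaged-model branch above covers the epoch
# range from 3 (excluded) to 5, matching the log message it prints.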
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
|
||||
|
||||
text_max_len_a = 1
|
||||
text_max_len_b = 200
|
||||
target_lang = "cmn"
|
||||
|
||||
text_opts = SequenceGeneratorOptions(
|
||||
beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
|
||||
)
|
||||
|
||||
s2t_model = UnitYX2TModel(
|
||||
encoder_frontend=model.speech_encoder_frontend,
|
||||
encoder=model.speech_encoder,
|
||||
decoder_frontend=model.text_decoder_frontend,
|
||||
decoder=model.text_decoder,
|
||||
final_proj=model.final_proj,
|
||||
pad_idx=model.pad_idx,
|
||||
)
|
||||
s2t_generator = SequenceToTextGenerator(
|
||||
s2t_model, text_tokenizer, target_lang, text_opts
|
||||
)
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dls = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
s2t_generator=s2t_generator,
|
||||
)
|
||||
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,432 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
|
||||
# Fangjun Kuang,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
#from conformer import Conformer
|
||||
from tokenizer import CharTokenizer
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
)
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
from seamless_communication.models.unity import (
|
||||
UnitYModel,
|
||||
load_unity_model,
|
||||
load_unity_text_tokenizer,
|
||||
)
|
||||
from fairseq2.generation import (
|
||||
SequenceGeneratorOptions,
|
||||
SequenceToTextGenerator,
|
||||
)
|
||||
from seamless_communication.models.unity.model import UnitYX2TModel
|
||||
from fairseq2.nn.embedding import Embedding
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="beam-search",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (0) ctc-decoding. Use CTC decoding. It maps the tokens ids to
|
||||
tokens using token symbol tabel directly.
|
||||
- (1) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (2) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (3) attention-decoder. Extract n paths from the lattice,
|
||||
the path with the highest score is the decoding result.
|
||||
- (4) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="seamlessm4t/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
# parameters for conformer
|
||||
"subsampling_factor": 4,
|
||||
"feature_dim": 80,
|
||||
"nhead": 4,
|
||||
"attention_dim": 512,
|
||||
"num_encoder_layers": 12,
|
||||
"num_decoder_layers": 6,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"search_beam": 20,
|
||||
"output_beam": 7,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
s2t_generator: SequenceToTextGenerator,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if decoding method is 1best, the key is the string `no_rescore`.
|
||||
If attention rescoring is used, the key is the string
|
||||
`ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
|
||||
value of `lm_scale` and `attention_scale`. An example key is
|
||||
`ngram_lm_scale_0.7_attention_scale_0.5`
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
|
||||
- params.method is "1best", it uses 1best decoding without LM rescoring.
|
||||
- params.method is "nbest", it uses nbest decoding without LM rescoring.
|
||||
- params.method is "attention-decoder", it uses attention rescoring.
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
lexicon:
|
||||
It contains the token symbol table and the word symbol table.
|
||||
sos_id:
|
||||
The token ID of the SOS.
|
||||
eos_id:
|
||||
The token ID of the EOS.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda", 3)
|
||||
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device, dtype=dtype)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_len = supervisions["num_frames"]
|
||||
feature_len = feature_len.to(device, dtype=dtype)
|
||||
|
||||
text_output = s2t_generator.generate_ex(feature, feature_len)
|
||||
#sentences = text_output.sentences
|
||||
#hyps = [sentence.bytes().decode("utf-8").split() for sentence in sentences]
|
||||
|
||||
token_ids = text_output.generator_output.results
|
||||
hyps_ids = [sentence[0].seq.cpu().tolist() for sentence in token_ids]
|
||||
hyps = [params.tokenizer.decode(hyps_id).split() for hyps_id in hyps_ids]
|
||||
|
||||
key = "beam-search"
|
||||
|
||||
return {key: hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
s2t_generator: SequenceToTextGenerator,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
  dl:
    PyTorch's dataloader containing the dataset to decode.
  params:
    It is returned by :func:`get_params`.
  s2t_generator:
    The SeamlessM4T sequence-to-text generator used for beam search.
Returns:
  Return a dict whose key is the decoding setting (here always
  "beam-search"). Its value is a list of tuples. Each tuple contains
  three elements: the cut ID, the reference transcript, and the
  predicted result.
"""
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
s2t_generator=s2t_generator,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
num_cuts += len(batch["supervisions"]["text"])
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
|
||||
enable_log = True
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AishellAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.tokenizer = CharTokenizer('./seamlessm4t/tokens.txt')
|
||||
params.update(vars(args))
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 3)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
dtype = torch.float16
|
||||
|
||||
model_name_or_card = "seamlessM4T_medium"
|
||||
#model_name_or_card = "seamlessM4T_large"
|
||||
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
|
||||
del model.t2u_model
|
||||
del model.text_encoder
|
||||
del model.text_encoder_frontend
|
||||
model.text_decoder_frontend.embed = nn.Embedding(num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024 ,padding_idx=0)
|
||||
#model.text_decoder_frontend.embed = Embedding(num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True)
|
||||
model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size, bias=False)
|
||||
#model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size)
|
||||
if params.epoch > 0:
|
||||
if params.avg > 1:
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.half()
|
||||
#for param in model.parameters():
|
||||
# if param.dtype == torch.float16:
|
||||
# pass
|
||||
# else:
|
||||
# param.data = param.data.to(torch.float16)
|
||||
#print(param)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
|
||||
|
||||
text_max_len_a = 1
|
||||
text_max_len_b = 200
|
||||
target_lang = "cmn"
|
||||
|
||||
text_opts = SequenceGeneratorOptions(
|
||||
beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
|
||||
)
|
||||
|
||||
s2t_model = UnitYX2TModel(
|
||||
encoder_frontend=model.speech_encoder_frontend,
|
||||
encoder=model.speech_encoder,
|
||||
decoder_frontend=model.text_decoder_frontend,
|
||||
decoder=model.text_decoder,
|
||||
final_proj=model.final_proj,
|
||||
pad_idx=model.pad_idx,
|
||||
)
|
||||
s2t_generator = SequenceToTextGenerator(
|
||||
s2t_model, text_tokenizer, target_lang, text_opts
|
||||
)
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["test"]
|
||||
test_dls = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
s2t_generator=s2t_generator,
|
||||
)
|
||||
|
||||
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1 +0,0 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py
@@ -1,133 +0,0 @@
import torch
import torch.nn as nn
from fairseq2.nn.embedding import Embedding
from seamless_communication.models.inference import Translator
from seamless_communication.models.unity import (
    UnitTokenizer,
    UnitYModel,
    load_unity_model,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)
from fairseq2.generation import (
    Seq2SeqGenerator,
    SequenceGeneratorOptions,
    SequenceGeneratorOutput,
    SequenceToTextGenerator,
    SequenceToTextOutput,
)
from seamless_communication.models.unity.model import UnitYModel, UnitYX2TModel

import torchaudio
import torchaudio.compliance.kaldi as ta_kaldi

audio_file = "/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_en/wav/1089-134686-0001.wav"
src_lang = "cmn"

audio_file = "/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_zh/wav/long.wav"
src_lang = "eng"
target_lang = "cmn"

audio_input = torchaudio.load(audio_file)[0]
feature = ta_kaldi.fbank(audio_input, num_mel_bins=80)
# feature shape is (T, F), convert it to (B, T, F); source_seq_lens tracks T
source_seqs = feature.unsqueeze(0)
source_seq_lens = torch.tensor([feature.shape[0]])

# Initialize a Translator object with a multitask model, vocoder on the GPU.

# translator = Translator("seamlessM4T_medium", vocoder_name_or_card="vocoder_36langs", device=torch.device("cuda:2"), dtype=torch.float16)

# transcribed_text, _, _ = translator.predict(audio_file, "asr", src_lang)

# print(transcribed_text)


model_name_or_card = "seamlessM4T_medium"
device = torch.device("cuda:3")

# cast source_seq_lens, source_seqs to device, dtype to torch.float16
source_seq_lens = source_seq_lens.to(device=device, dtype=torch.float16)
source_seqs = source_seqs.to(device=device, dtype=torch.float16)


dtype = torch.float16
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
model.eval()
model.text_decoder_frontend.embed = Embedding(num_embeddings=6257, embedding_dim=1024, pad_idx=0, scaled=True)
model.final_proj = nn.Linear(1024, 6257)
model.half()
print(model.text_decoder_frontend.embed, model.text_encoder_frontend.embed.weight.dtype, type(model.text_encoder_frontend.embed), type(model.text_encoder_frontend.embed.weight))
print(model.final_proj, model.final_proj.weight.dtype, type(model.final_proj), type(model.final_proj.weight))
# input()
exit(0)
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
# print(text_tokenizer.model.eos_idx, text_tokenizer.model.pad_idx)
# text_tokenizer_encoder = text_tokenizer.create_encoder(lang=target_lang, mode="target")
# text_tokenizer_decoder = text_tokenizer.create_decoder()
# print attributes of text_tokenizer_encoder
# print(text_tokenizer.vocab_info)
# print(text_tokenizer_encoder("其中广州深圳甚至出现了多个日光盘"))
# print(text_tokenizer_decoder(torch.tensor([3,256200,137139,252603,250476,250590,1,84778,148897,249568,249352,249947,249050,250520,254508])))

# store all vocab in a file
# with open("vocab.txt", "w") as f:
#     for i in range(256206):
#         f.write(f"{i}: " + text_tokenizer_decoder(torch.tensor([i]))[0].bytes().decode("utf-8") + "\n")
# f.close()
# exit(0)


# def decode(
#     self,
#     seqs: Tensor,
#     seq_lens: Optional[Tensor],
#     encoder_output: Tensor,
#     encoder_padding_mask: Optional[Tensor],
#     state_bag: Optional[IncrementalStateBag] = None,
# ) -> Tuple[Tensor, Optional[Tensor]]:
#     seqs, padding_mask = self.text_decoder_frontend(seqs, seq_lens, state_bag)
#
#     return self.text_decoder(  # type: ignore[no-any-return]
#         seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
#     )

# def decoding(model, feature):
#     seqs, padding_mask = model.speech_encoder_frontend(seqs, seq_lens)
#     speech_encoder(seqs, padding_mask)
#
#     decoder_output, decoder_padding_mask = self.decode(
#         batch.target_seqs,
#         batch.target_seq_lens,
#         encoder_output,
#         encoder_padding_mask,
#     )
#
#     text_logits = model.final_project(decoder_output, decoder_padding_mask)

text_max_len_a = 1
text_max_len_b = 200

text_opts = SequenceGeneratorOptions(
    beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
)

s2t_model = UnitYX2TModel(
    encoder_frontend=model.speech_encoder_frontend,
    encoder=model.speech_encoder,
    decoder_frontend=model.text_decoder_frontend,
    decoder=model.text_decoder,
    final_proj=model.final_proj,
    pad_idx=model.pad_idx,
)
s2t_generator = SequenceToTextGenerator(
    s2t_model, text_tokenizer, target_lang, text_opts
)

text_output = s2t_generator.generate_ex(source_seqs, source_seq_lens)
print(text_output.generator_output.results[0][0].seq.cpu().tolist())
# sentence = text_output.sentences[0]
# print(sentence, type(sentence))
# sentence = sentence.bytes().decode("utf-8")
(File diff suppressed because it is too large.)
@@ -1,694 +0,0 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn.functional import log_softmax
|
||||
|
||||
from fairseq2.data import Collater, SequenceData, VocabularyInfo
|
||||
from fairseq2.generation.beam_search import BeamSearch, StandardBeamSearch
|
||||
from fairseq2.generation.logits_processor import LogitsProcessor
|
||||
from fairseq2.models.encoder_decoder import Seq2SeqDecoder
|
||||
from fairseq2.nn.incremental_state import IncrementalStateBag
|
||||
from fairseq2.typing import Device
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceGeneratorOptions:
|
||||
"""Holds the options to pass to a sequence generator."""
|
||||
|
||||
beam_size: int = 5
|
||||
"""The beam size."""
|
||||
|
||||
min_seq_len: int = 1
|
||||
"""The minimum length of generated sequences (including prefix sequence)."""
|
||||
|
||||
soft_max_seq_len: Optional[Tuple[int, int]] = (1, 200)
|
||||
"""The terms ``a`` and ``b`` of ``ax + b`` where ``x`` is the source
|
||||
sequence length. The generated sequences (including prefix sequence) will
|
||||
have the maximum length of ``min(hard_max_seq_len, ax + b)``. See also
|
||||
``hard_max_seq_len``."""
|
||||
|
||||
hard_max_seq_len: int = 1024
|
||||
"""The hard limit on maximum length of generated sequences."""
|
||||
|
||||
len_penalty: float = 1.0
|
||||
"""The length penalty, where values less than 1.0 favor shorter, values
|
||||
greater than 1.0 favor longer sequences."""
|
||||
|
||||
unk_penalty: float = 0.0
|
||||
"""The unknown symbol penalty, where values less than 0 produce more UNKs,
|
||||
values greater than 0 produce fewer UNKs."""
|
||||
|
||||
normalize_scores: bool = True
|
||||
"""If ``True``, normalizes scores by the length of generated sequences."""
|
||||
|
||||
search: Optional[BeamSearch] = None
|
||||
"""The beam search algorithm to use."""
|
||||
|
||||
logits_processor: Optional[LogitsProcessor] = None
|
||||
"""Logits processor called before applying beam search step."""
|
||||
|
||||
|
||||
class Seq2SeqGenerator:
|
||||
"""Represents a sequence-to-sequence generator."""
|
||||
|
||||
decoder: Seq2SeqDecoder
|
||||
opts: SequenceGeneratorOptions
|
||||
beam_size: int
|
||||
eos_idx: int
|
||||
pad_idx: Optional[int]
|
||||
unk_idx: Optional[int]
|
||||
prefix_seq: Union[int, Tensor]
|
||||
prefix_seq_len: int
|
||||
search: BeamSearch
|
||||
logits_processor: Optional[LogitsProcessor]
|
||||
collater: Collater
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder: Seq2SeqDecoder,
|
||||
vocab_info: VocabularyInfo,
|
||||
prefix_seq: Optional[Union[int, Tensor]],
|
||||
opts: Optional[SequenceGeneratorOptions] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param decoder:
|
||||
The decoder to use.
|
||||
:param vocab_info:
|
||||
The vocabulary information to use.
|
||||
:param prefix_seq:
|
||||
The prefix sequence, typically one or more control symbols
|
||||
indicating the beginning of a sequence. *Shape:* :math:`()` or
|
||||
:math:`(S)`, where :math:`S` is the sequence length. If ``None``,
|
||||
the EOS symbol will be used as prefix.
|
||||
:param opts:
|
||||
The generation options.
|
||||
"""
|
||||
self.decoder = decoder
|
||||
|
||||
self.opts = opts or SequenceGeneratorOptions()
|
||||
|
||||
# Set beam size.
|
||||
if vocab_info.pad_idx is None:
|
||||
self.beam_size = min(self.opts.beam_size, vocab_info.size)
|
||||
else:
|
||||
# -1 since we never select PAD.
|
||||
self.beam_size = min(self.opts.beam_size, vocab_info.size - 1)
|
||||
|
||||
if vocab_info.eos_idx is None:
|
||||
raise ValueError(
|
||||
"`vocab_info` must have `eos_idx` set for sequence generation."
|
||||
)
|
||||
|
||||
# Set vocab info.
|
||||
self.eos_idx = 1
|
||||
#self.eos_idx = vocab_info.eos_idx
|
||||
self.unk_idx = 2
|
||||
#self.unk_idx = vocab_info.unk_idx
|
||||
self.pad_idx = 0
|
||||
#self.pad_idx = vocab_info.pad_idx
|
||||
|
||||
# Set prefix sequence.
|
||||
if 1:
|
||||
#if prefix_seq is None:
|
||||
# If `None`, we follow fairseq's convention, and use EOS as the
|
||||
# prefix.
|
||||
self.prefix_seq, self.prefix_seq_len = self.eos_idx, 1
|
||||
else:
|
||||
self.prefix_seq = prefix_seq
|
||||
|
||||
if isinstance(prefix_seq, Tensor):
|
||||
num_dim = prefix_seq.dim()
|
||||
|
||||
if num_dim >= 2:
|
||||
raise ValueError(
|
||||
f"`prefix_seq` must be a scalar or a 1-dimensional tensor, but is {num_dim}-dimensional instead."
|
||||
)
|
||||
|
||||
self.prefix_seq_len = 1 if num_dim == 0 else prefix_seq.size(0)
|
||||
else:
|
||||
self.prefix_seq_len = 1
|
||||
|
||||
# Set beam search.
|
||||
self.search = self.opts.search or StandardBeamSearch()
|
||||
self.logits_processor = self.opts.logits_processor
|
||||
|
||||
if vocab_info.pad_idx is None:
|
||||
self.collater = Collater()
|
||||
else:
|
||||
self.collater = Collater(self.pad_idx, pad_to_multiple=2)
|
||||
|
||||
@torch.inference_mode()
|
||||
def __call__(
|
||||
self,
|
||||
encoder_output: Tensor,
|
||||
encoder_padding_mask: Optional[Tensor],
|
||||
source_seq_len: Optional[int] = None,
|
||||
) -> "SequenceGeneratorOutput":
|
||||
opts = self.opts
|
||||
|
||||
num_searches = encoder_output.size(0)
|
||||
|
||||
beam_size = opts.beam_size
|
||||
|
||||
max_seq_len = self._determine_max_seq_len(source_seq_len)
|
||||
|
||||
device = encoder_output.device
|
||||
|
||||
encoder_output, encoder_padding_mask = self._fan_out_encoder_output(
|
||||
encoder_output, encoder_padding_mask
|
||||
)
|
||||
|
||||
# Each element contains the id of the search corresponding to a single
|
||||
# source sequence and its hypotheses.
|
||||
active_searches: List[Tuple[int, List[Hypothesis]]] = [
|
||||
(search_idx, []) for search_idx in range(num_searches)
|
||||
]
|
||||
|
||||
# Once a source sequence has `beam_size` hypotheses, its search is moved
|
||||
# from `active_searches` to `finished_searches`.
|
||||
finished_searches: List[List[Hypothesis]] = [[] for i in range(num_searches)]
|
||||
|
||||
num_remaining_searches = num_searches
|
||||
|
||||
# Initialize buffers.
|
||||
# (N x B, S)
|
||||
seqs = torch.zeros(
|
||||
(num_searches * beam_size, max_seq_len), device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
# (N x B, S)
|
||||
scores = torch.zeros(
|
||||
(num_searches * beam_size, max_seq_len), device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
# A list that indicates beams that should be ignored in the next step.
|
||||
ignored_beam_mask = torch.full(
|
||||
(num_searches, beam_size), False, device=device, dtype=torch.bool
|
||||
)
|
||||
|
||||
# An offset array for converting between batch-wide and search-local
|
||||
# beam indices.
|
||||
# (B)
|
||||
search_offsets = torch.arange(num_searches, device=device) * beam_size
|
||||
|
||||
# (B) -> (B, 1)
|
||||
search_offsets.unsqueeze_(-1)
|
||||
|
||||
cand_offsets = torch.arange(2 * beam_size, device=device)
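# Note (added comment): the search tracks 2 * beam_size candidates per step so
# that even if up to beam_size of them end with EOS and get finalized, there
# are still beam_size active beams left to continue with; see the
# `beam_weights` top-k selection further below.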
|
||||
|
||||
state_bag = IncrementalStateBag()
|
||||
|
||||
# At this point, the state is fully initialized, kick off the search.
|
||||
self._bootstrap_seqs_and_scores(
|
||||
seqs, scores, encoder_output, encoder_padding_mask, state_bag
|
||||
)
|
||||
|
||||
start_step = self.prefix_seq_len - 1
|
||||
|
||||
# Holds the indices of beams (a beam can occur more than once) that we
|
||||
# should continue with in the next step.
|
||||
beam_indices: Optional[Tensor] = None
|
||||
|
||||
# Holds the indices of searches that we should continue with in the next
|
||||
# step. If not `None`, it means we finalized one or more searches in the
|
||||
# last step.
|
||||
search_indices: Optional[Tensor] = None
|
||||
|
||||
for step_nr in range(start_step, max_seq_len - 1):
|
||||
if beam_indices is not None:
|
||||
# If not `None`, it means in the last step we finalized one or
|
||||
# more searches. We should ensure that we adjust `beam_indices`
|
||||
# before reordering `decoder`'s incremental state.
|
||||
if search_indices is not None:
|
||||
num_searches = search_indices.numel()
|
||||
|
||||
# (N)
|
||||
delta = search_indices - torch.arange(num_searches, device=device)
|
||||
|
||||
# (N) -> (N, 1)
|
||||
delta.unsqueeze_(-1)
|
||||
|
||||
# Adjust indices to take into account removed searches.
|
||||
beam_indices.view(num_searches, beam_size).add_(delta * beam_size)
|
||||
|
||||
state_bag.reorder(beam_indices)
|
||||
|
||||
decoder_output, decoder_padding_mask = self.decoder.decode(
|
||||
seqs[:, step_nr : step_nr + 1],
|
||||
None, # We never generate PAD.
|
||||
encoder_output,
|
||||
encoder_padding_mask,
|
||||
state_bag,
|
||||
)
|
||||
|
||||
state_bag.increment_step()
|
||||
|
||||
model_output = self.decoder.project(decoder_output, decoder_padding_mask)
|
||||
|
||||
# lprobs: (1, V)
|
||||
# model_output: (N, 1, V)
|
||||
lprobs = log_softmax(model_output.logits, dim=-1, dtype=torch.float32)
|
||||
|
||||
# Do not allow EOS before reaching the minimum sequence length.
|
||||
if step_nr < self.opts.min_seq_len:
|
||||
lprobs[:, :, self.eos_idx] = -torch.inf
|
||||
|
||||
# fmt: off
|
||||
# If we have reached the maximum length, force the last step to be
|
||||
# EOS.
|
||||
if step_nr == max_seq_len - 2:
|
||||
lprobs[:, :, : self.eos_idx] = -torch.inf
|
||||
lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
|
||||
# fmt: on
|
||||
|
||||
# Never allow PAD.
|
||||
if self.pad_idx is not None:
|
||||
lprobs[:, :, self.pad_idx] = -torch.inf
|
||||
|
||||
# Apply UNK penalty.
|
||||
if self.unk_idx is not None:
|
||||
lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
|
||||
|
||||
# update scores in place using logits_processor
|
||||
if self.logits_processor is not None:
|
||||
self.logits_processor(
|
||||
seqs.view(num_searches, beam_size, -1)[:, :, : step_nr + 1],
|
||||
lprobs.view(num_searches, beam_size, -1),
|
||||
)
|
||||
|
||||
# Determine candidates for the next step.
|
||||
# (N, 2 x B)
|
||||
cand_scores, cand_indices, cand_beam_indices = self.search.step(
|
||||
step_nr,
|
||||
step_nr == start_step,
|
||||
lprobs.view(num_searches, beam_size, -1),
|
||||
scores.view(num_searches, beam_size, -1)[:, :, : step_nr + 1],
|
||||
)
|
||||
|
||||
# Convert search-local beam indices to batch-wide beam indices.
|
||||
# (N, 2 x B) + (N) -> (N, 2 x B)
|
||||
global_cand_beam_indices = cand_beam_indices + search_offsets
|
||||
|
||||
# Finalize beams that reached the minimum length and that end with
|
||||
# an EOS.
|
||||
# (N, 2 x B)
|
||||
eos_mask = (cand_indices == self.eos_idx) & (cand_scores != -math.inf)
|
||||
|
||||
# Do not attempt to finalize beams that should be ignored.
|
||||
eos_mask[:, :beam_size][ignored_beam_mask] = False
|
||||
|
||||
# Only consider EOS when it's among the top `beam_size` indices. Now
|
||||
# we know what beam(s) to finalize.
|
||||
# (N, B)
|
||||
eos_beam_indices = torch.masked_select(
|
||||
global_cand_beam_indices[:, :beam_size], mask=eos_mask[:, :beam_size]
|
||||
)
|
||||
|
||||
if eos_beam_indices.numel() > 0:
|
||||
# Select the scores of the finalized beams.
|
||||
# (N, B)
|
||||
eos_scores = torch.masked_select(
|
||||
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
|
||||
)
|
||||
|
||||
newly_finished_searches = self._finalize_hypothesis(
|
||||
step_nr,
|
||||
eos_beam_indices,
|
||||
eos_scores,
|
||||
seqs,
|
||||
scores,
|
||||
active_searches,
|
||||
finished_searches,
|
||||
)
|
||||
|
||||
num_remaining_searches -= len(newly_finished_searches)
|
||||
|
||||
if num_remaining_searches == 0:
|
||||
break
|
||||
else:
|
||||
newly_finished_searches = None
|
||||
|
||||
# Remove finished searches (ones for which `beam_size` finalized
|
||||
# beams have been generated) from the batch.
|
||||
if newly_finished_searches:
|
||||
new_num_searches = num_searches - len(newly_finished_searches)
|
||||
|
||||
# Construct `search_indices` which holds indices of searches
|
||||
# to keep for the next step.
|
||||
search_mask = torch.full((num_searches,), True, device=device)
|
||||
|
||||
search_mask[newly_finished_searches] = False
|
||||
|
||||
search_indices = torch.arange(num_searches, device=device)
|
||||
|
||||
search_indices = search_indices.masked_select(search_mask)
|
||||
|
||||
# fmt: off
|
||||
# Filter out removed batches from state variables.
|
||||
# (N, B) -> (N - F, B)
|
||||
ignored_beam_mask = ignored_beam_mask[search_indices]
|
||||
|
||||
# (N, 2 x B) -> (N - F, 2 x B)
|
||||
cand_scores = cand_scores [search_indices]
|
||||
cand_indices = cand_indices [search_indices]
|
||||
cand_beam_indices = cand_beam_indices[search_indices]
|
||||
|
||||
# (N) -> (N - F)
|
||||
search_offsets.resize_(new_num_searches, 1)
|
||||
|
||||
# (N - F, 2 x B) + (N - F) -> (N - F, 2 x B)
|
||||
global_cand_beam_indices = cand_beam_indices + search_offsets
|
||||
|
||||
# (N, 2 x B) -> (N - F, 2 x B)
|
||||
eos_mask = eos_mask[search_indices]
|
||||
|
||||
# (N x B, S) -> (N, B, S)
|
||||
seqs = seqs .view(num_searches, -1)
|
||||
scores = scores.view(num_searches, -1)
|
||||
|
||||
# (N, B, S + 1) -> ((N - F) x B, S)
|
||||
seqs = seqs [search_indices].view(new_num_searches * beam_size, -1)
|
||||
scores = scores[search_indices].view(new_num_searches * beam_size, -1)
|
||||
|
||||
# (N x B, S_enc, M) -> (N, B, S_enc, M)
|
||||
encoder_output = encoder_output.unflatten(0, (num_searches, -1))
|
||||
|
||||
# (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
|
||||
encoder_output = encoder_output[search_indices].flatten(0, 1)
|
||||
|
||||
if encoder_padding_mask is not None:
|
||||
# (N x B, S_enc, M) -> (N, B, S_enc, M)
|
||||
padding_mask = encoder_padding_mask.unflatten(0, (num_searches, -1))
|
||||
|
||||
# (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
|
||||
encoder_padding_mask = padding_mask[search_indices].flatten(0, 1)
|
||||
# fmt: on
|
||||
|
||||
num_searches = new_num_searches
|
||||
else:
|
||||
search_indices = None
|
||||
|
||||
eos_mask[:, :beam_size][ignored_beam_mask] = True
|
||||
|
||||
# Set `beam_weights` so that values greater than or equal to 2 x
|
||||
# `beam_size` indicate finished beams (i.e. end with EOS) and values
|
||||
# less than 2 x `beam_size` indicate active beams.
|
||||
# (N, 2 x B)
|
||||
beam_weights = cand_offsets + (eos_mask * (2 * beam_size))
|
||||
|
||||
# Get the top `beam_size` active beams, which are the beams with the
|
||||
# smallest weights in `active_beam_weights`.
|
||||
# (N, B)
|
||||
active_beam_weights, active_beams = torch.topk(
|
||||
beam_weights, k=beam_size, dim=1, largest=False
|
||||
)
|
||||
|
||||
# Update to ignore finalized beams in the next step.
|
||||
# (N, B)
|
||||
ignored_beam_mask = active_beam_weights >= 2 * beam_size
|
||||
|
||||
# We should always have at least one active beam in each search.
|
||||
assert (~ignored_beam_mask).any(dim=1).all()
|
||||
|
||||
# Denotes which beams are continued for each new hypothesis (a beam
|
||||
# can be selected more than once).
|
||||
# (N, B)
|
||||
beam_indices = torch.gather(
|
||||
global_cand_beam_indices, dim=1, index=active_beams
|
||||
)
|
||||
|
||||
# (N, B) -> (N x B)
|
||||
beam_indices = beam_indices.view(-1)
|
||||
|
||||
# fmt: off
|
||||
# Reorder beams in the `seq` and `score` buffers. The same beam can
|
||||
# be selected more than once.
|
||||
if step_nr > start_step:
|
||||
seqs [:, : step_nr + 1] = torch.index_select(
|
||||
seqs [:, : step_nr + 1], dim=0, index=beam_indices
|
||||
)
|
||||
scores[:, : step_nr + 1] = torch.index_select(
|
||||
scores[:, : step_nr + 1], dim=0, index=beam_indices
|
||||
)
|
||||
|
||||
# (N x B, S) -> (N, B, S)
|
||||
seqs_view = seqs .view(num_searches, beam_size, -1)
|
||||
scores_view = scores.view(num_searches, beam_size, -1)
|
||||
|
||||
seqs_view [:, :, step_nr + 1] = torch.gather(cand_indices, dim=1, index=active_beams)
|
||||
scores_view[:, :, step_nr + 1] = torch.gather(cand_scores, dim=1, index=active_beams)
|
||||
# fmt: on
|
||||
|
||||
# Ensure that hypotheses are sorted by their scores before returning.
|
||||
for batch in finished_searches:
|
||||
batch.sort(key=lambda b: b.score, reverse=True) # type: ignore[arg-type, return-value]
|
||||
|
||||
return SequenceGeneratorOutput(
|
||||
results=finished_searches, device=device, collater=self.collater
|
||||
)
|
||||
|
||||
def _determine_max_seq_len(self, source_seq_len: Optional[int]) -> int:
|
||||
opts = self.opts
|
||||
|
||||
if source_seq_len is None or opts.soft_max_seq_len is None:
|
||||
max_seq_len = opts.hard_max_seq_len
|
||||
else:
|
||||
at, bt = opts.soft_max_seq_len
|
||||
|
||||
max_seq_len = min(opts.hard_max_seq_len, int(at * source_seq_len + bt))
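# Illustrative example (added comment): with the defaults
# soft_max_seq_len=(1, 200) and hard_max_seq_len=1024, a source of 300
# frames gives max_seq_len = min(1024, 1 * 300 + 200) = 500, while a source
# of 900 frames is capped at the hard limit of 1024.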
|
||||
|
||||
if opts.min_seq_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The effective maximum sequence length must be greater than or equal to `min_seq_len` ({opts.min_seq_len}), but is {max_seq_len} instead. Adjust your soft and hard maximum sequence length limits."
|
||||
)
|
||||
|
||||
if self.prefix_seq_len >= max_seq_len:
|
||||
raise ValueError(
|
||||
f"The effective maximum sequence length must be greater than `prefix_seq_len` ({self.prefix_seq_len}), but is {max_seq_len} instead."
|
||||
)
|
||||
|
||||
return max_seq_len
|
||||
|
||||
def _fan_out_encoder_output(
|
||||
self, encoder_output: Tensor, encoder_padding_mask: Optional[Tensor]
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
num_searches = encoder_output.size(0) # i.e. batch size
|
||||
|
||||
# Fan out `encoder_output` to `num_searches` x `beam_size`.
|
||||
# (N)
|
||||
fan_out_indices = torch.arange(num_searches, device=encoder_output.device)
|
||||
|
||||
# (N) -> (N x B)
|
||||
fan_out_indices = fan_out_indices.repeat_interleave(self.beam_size)
|
||||
|
||||
# (N, S_enc, M) -> (N x B, S_enc, M)
|
||||
encoder_output = encoder_output.index_select(dim=0, index=fan_out_indices)
|
||||
|
||||
# (N, S_enc, M) -> (N x B, S_enc, M)
|
||||
if encoder_padding_mask is not None:
|
||||
encoder_padding_mask = encoder_padding_mask.index_select(
|
||||
dim=0, index=fan_out_indices
|
||||
)
|
||||
|
||||
return encoder_output, encoder_padding_mask
|
||||
|
||||
    def _bootstrap_seqs_and_scores(
        self,
        seqs: Tensor,
        scores: Tensor,
        encoder_output: Tensor,
        encoder_padding_mask: Optional[Tensor],
        state_bag: IncrementalStateBag,
    ) -> None:
        assert self.prefix_seq_len > 0

        seqs[:, : self.prefix_seq_len] = self.prefix_seq

        if self.prefix_seq_len == 1:
            return

        assert isinstance(self.prefix_seq, Tensor)

        # We have to bootstrap the model with the already fanned-out encoder
        # output to correctly initialize its incremental state. This causes some
        # redundancy as we have to expand `decoder_input` to match the shape of
        # `encoder_output`.
        # (S_pfx) -> (N x B, S_pfx - 1)
        decoder_input = self.prefix_seq[:-1].expand(encoder_output.size(0), -1)

        # Bootstrap the model state with prefix sequence.
        decoder_output, decoder_padding_mask = self.decoder.decode(
            decoder_input,
            None,
            encoder_output,
            encoder_padding_mask,
            state_bag,
        )

        state_bag.increment_step(self.prefix_seq_len - 1)

        model_output = self.decoder.project(decoder_output, decoder_padding_mask)

        # lprobs:       (S_pfx - 1, V)
        # model_output: (N, S_pfx - 1, V) -> (S_pfx - 1, V)
        lprobs = log_softmax(model_output.logits[0], dim=-1, dtype=torch.float32)

        # Fetch scores of next steps.
        # (S_pfx - 1, 1)
        prefix_scores = torch.take_along_dim(
            lprobs, indices=self.prefix_seq[1:].unsqueeze(1), dim=-1
        )

        # (S_pfx - 1, 1) -> (S_pfx - 1)
        prefix_scores.squeeze_(1).cumsum_(dim=0)

        # First step (e.g. EOS)'s score is always 0.
        scores[:, 1 : self.prefix_seq_len] = prefix_scores
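
The bootstrap step scores each forced prefix token with the decoder's log-probabilities and accumulates them, which is what `take_along_dim` followed by `cumsum_` does above. A toy sketch with a random stand-in for the model output (not the actual decoder API):

import torch
from torch.nn.functional import log_softmax

prefix = torch.tensor([2, 5, 7])           # e.g. EOS, language tag, task tag
logits = torch.randn(len(prefix) - 1, 16)  # one distribution per predicted prefix step
lprobs = log_softmax(logits, dim=-1)

prefix_scores = torch.take_along_dim(lprobs, indices=prefix[1:].unsqueeze(1), dim=-1)
prefix_scores = prefix_scores.squeeze(1).cumsum(dim=0)
# The first position (EOS) keeps score 0, as in `scores[:, 1 : prefix_seq_len]` above.
print(prefix_scores)
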
    def _finalize_hypothesis(
        self,
        step_nr: int,
        eos_beam_indices: Tensor,
        eos_scores: Tensor,
        seqs: Tensor,
        scores: Tensor,
        active_searches: List[Tuple[int, List["Hypothesis"]]],
        finished_searches: List[List["Hypothesis"]],
    ) -> List[int]:
        # fmt: off
        finalized_seqs   = seqs  .index_select(dim=0, index=eos_beam_indices)
        finalized_scores = scores.index_select(dim=0, index=eos_beam_indices)

        finalized_seqs   = finalized_seqs  [:, : step_nr + 2]
        finalized_scores = finalized_scores[:, : step_nr + 2]

        # Finalize beams.
        finalized_seqs  [:, -1] = self.eos_idx
        finalized_scores[:, -1] = eos_scores
        # fmt: on

        # Convert from cumulative to per-step scores.
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]

        # Skip first EOS since it is always 0 and skews normalization.
        if self.opts.normalize_scores:
            eos_scores /= (step_nr + 1) ** self.opts.len_penalty

        # Holds the ids of finished searches.
        newly_finished: List[int] = []

        active_search_indices = (eos_beam_indices // self.beam_size).tolist()

        for beam_idx, search_idx in enumerate(active_search_indices):
            search_id, hypotheses = active_searches[search_idx]

            # We might have more than one beam finalized in one step that would
            # potentially exceed `beam_size` hypotheses.
            if len(hypotheses) == self.beam_size:
                continue

            hypotheses.append(
                Hypothesis(
                    seq=finalized_seqs[beam_idx],
                    score=eos_scores[beam_idx],
                    step_scores=finalized_scores[beam_idx],
                )
            )

            if len(hypotheses) == self.beam_size:
                # We have `beam_size` hypotheses for this particular search, so
                # we finish it now.
                newly_finished.append(search_idx)

                finished_searches[search_id] = hypotheses

        newly_finished.sort()

        # Remove finished searches from the active list.
        for idx in reversed(newly_finished):
            del active_searches[idx]

        return newly_finished
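
Two details above are easy to miss: per-step scores are recovered by differencing the cumulative scores, and the final hypothesis score is length-normalized by `(step_nr + 1) ** len_penalty`. A small numeric sketch with made-up values:

import torch

cumulative = torch.tensor([0.0, -1.2, -1.9, -2.4])    # running sum of log-probs
per_step = cumulative.clone()
per_step[1:] = cumulative[1:] - cumulative[:-1]
print(per_step)                                        # ~[0.0, -1.2, -0.7, -0.5]

step_nr, len_penalty = 3, 1.0
print(cumulative[-1] / (step_nr + 1) ** len_penalty)   # ~-0.6, the normalized score
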
@dataclass
class SequenceGeneratorOutput:
    """Holds the output of a sequence generator."""

    results: List[List["Hypothesis"]]
    """The list of hypotheses generated per search, ordered by score."""

    device: Device
    """The device on which generated sequences reside."""

    collater: Optional[Collater] = None
    """The collater to use in :meth:`collate`."""

    def collate(
        self, hypo_idx: int = 0, skip_batch: bool = False
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Collate the generated sequences at index ``hypo_idx`` in each search
        result into a single tensor.

        :param hypo_idx:
            The index of the hypothesis to extract from each search result.
        :param skip_batch:
            If ``True``, search results that have no hypothesis at index
            ``hypo_idx`` are skipped instead of raising an error.

        :returns:
            - The collated sequences. *Shape:* :math:`(N,S)`, where :math:`N` is
              the number of search results and :math:`S` is the sequence length.
            - An array where each element represents the length of the sequence at
              the same index in the first returned value. *Shape:* :math:`(N)`,
              where :math:`N` is the number of search results.
        """
        if self.collater is None:
            raise RuntimeError("The output has no associated `Collater` instance.")

        if not self.results and not skip_batch:
            raise ValueError("The output must contain at least one search result.")

        seqs = []

        for search_idx, result in enumerate(self.results):
            if hypo_idx >= len(result):
                if not skip_batch:
                    raise ValueError(
                        f"Each search result must have at least {hypo_idx + 1} hypotheses, but search {search_idx} has only {len(result)}."
                    )

                continue

            seqs.append(result[hypo_idx].seq)

        if not seqs:
            # Return a zero-dimensional (not scalar!) tensor.
            return torch.empty((0,), device=self.device, dtype=torch.int64), None

        output = cast(SequenceData, self.collater(seqs))

        return output["seqs"], output["seq_lens"] if output["is_ragged"] else None

@dataclass
class Hypothesis:
    """Represents a hypothesis produced by a sequence generator."""

    seq: Tensor
    """The generated sequence."""

    score: Tensor
    """The score of the hypothesis."""

    step_scores: Tensor
    """The score of each individual sequence step."""
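
A minimal usage sketch for the two dataclasses above, assuming it runs in (or imports from) this module; the token ids and scores are made up, and no `Collater` is attached, so `collate()` would raise:

import torch

hyp = Hypothesis(
    seq=torch.tensor([2, 10, 11, 2]),                   # e.g. EOS, two tokens, EOS
    score=torch.tensor(-0.6),
    step_scores=torch.tensor([0.0, -1.2, -0.7, -0.5]),
)
output = SequenceGeneratorOutput(results=[[hyp]], device=torch.device("cpu"))

best = output.results[0][0]  # hypotheses are kept best-first per search
print(best.score, best.seq)
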
@ -1,6 +0,0 @@
#k2
kaldialign
lhotse
sentencepiece
tensorboard
fairseq2
@ -1,43 +0,0 @@

#import sentencepiece as spm

class CharTokenizer(object):
    def __init__(self, tokenizer_file):
        self.id2symbol = {}
        self.symbol2id = {}
        with open(tokenizer_file, 'r') as f:
            for line in f:
                line = line.strip()
                if line:
                    symbol, id = line.split()
                    id = int(id)
                    self.id2symbol[id] = symbol
                    self.symbol2id[symbol] = id
        self.vocab_size = len(self.id2symbol)

    def encode(self, text):
        # If a symbol is not in self.symbol2id, use <unk>'s id (2).
        return [self.symbol2id.get(symbol, 2) for symbol in text]

    def decode(self, ids):
        return ''.join([self.id2symbol[id] for id in ids])


if __name__ == '__main__':
    # config_file = './config.yaml'
    # config = read_yaml(config_file)
    # converter = TokenIDConverter(config['token_list'])
    # ids = converter.tokens2ids(['<s>', '你', '好', '吗', '</s>', 'microsoft', 'world'])
    # print(ids)
    # print(converter.ids2tokens(ids))

    tokenizer = CharTokenizer('./tokens.txt')
    ids = tokenizer.encode('今天 天气不错')
    print(ids)
    print(tokenizer.decode(ids + [1]))
    # sp = spm.SentencePieceProcessor()
    # sp.Load('../../../librispeech/ASR/k2fsa-zipformer-chinese-english-mixed/data/lang_char_bpe/bpe.model')
    # texts = ['MICROSOFT WORLD']
    # y = sp.encode(texts, out_type=int)
    # x = sp.decode(y)
    # print(y, x)
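
The tokenizer above expects a `tokens.txt` with one "symbol id" pair per line and falls back to id 2 (assumed here to be `<unk>`) for unseen symbols. A hypothetical vocabulary and round-trip that match the parsing logic above (assuming a UTF-8 locale and that `CharTokenizer` is importable from this file):

sample = "<blk> 0\n<sos/eos> 1\n<unk> 2\n今 3\n天 4\n好 5\n"
with open("tokens.txt", "w", encoding="utf-8") as f:
    f.write(sample)

tokenizer = CharTokenizer("tokens.txt")
ids = tokenizer.encode("今天好 x")  # the space and "x" are unseen, so both map to id 2
print(ids)                          # [3, 4, 5, 2, 2]
print(tokenizer.decode([3, 4, 5]))  # 今天好
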
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -30,7 +30,7 @@ from lhotse.dataset import (
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SingleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -176,13 +176,13 @@ class AishellAsrDataModule:
        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=False,
            default=True,
            help="When enabled, select noise from MUSAN and mix it"
            "with training dataset. ",
        )

    def train_dataloaders(
        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, rank = None, world_size = None
        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
    ) -> DataLoader:
        """
        Args:
@ -192,13 +192,13 @@ class AishellAsrDataModule:
            The state dict for the training sampler.
        """
        logging.info("About to get Musan cuts")
        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")

        transforms = []
        if self.args.enable_musan:
            logging.info("Enable MUSAN")
            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
            transforms.append(
                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")
@ -276,12 +276,10 @@ class AishellAsrDataModule:
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
                world_size=world_size,
                rank=rank,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
            logging.info("Using SingleCutSampler.")
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
@ -302,7 +300,7 @@ class AishellAsrDataModule:

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet, rank = None, world_size = None) -> DataLoader:
    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        transforms = []
        if self.args.concatenate_cuts:
            transforms = [
@ -327,8 +325,6 @@ class AishellAsrDataModule:
                cuts_valid,
                max_duration=self.args.max_duration,
                shuffle=False,
                rank=rank,
                world_size=world_size,
            )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
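
The two `CutMix` lines in the hunk above differ only in the keyword used for the mixing probability (`prob` in older lhotse releases, `p` in newer ones). A defensive sketch that tolerates either API; the manifest path is illustrative and the imports mirror the usual icefall data module:

import inspect

from lhotse import load_manifest
from lhotse.dataset import CutMix

cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")  # path is illustrative

# Pick whichever keyword the installed lhotse version exposes.
kwarg = "p" if "p" in inspect.signature(CutMix.__init__).parameters else "prob"
transform = CutMix(cuts=cuts_musan, snr=(10, 20), preserve_id=True, **{kwarg: 0.5})
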
@ -473,10 +473,11 @@ def main():
    aishell = AishellAsrDataModule(args)
    test_cuts = aishell.test_cuts()
    test_dl = aishell.test_dataloaders(test_cuts)

    test_sets = ["test"]
    test_dls = [test_dl]

    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
    #test_sets = ["test"]
    #test_dls = [test_dl]
    test_sets = ["valid"]
    test_dls = [valid_dl]
    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
@ -27,7 +27,7 @@
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 1e-5,
            "warmup_num_steps": 1000
            "warmup_num_steps": 100
        }
    },
    "gradient_accumulation_steps": 1,
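
The hunk above only shortens the scheduler's warmup horizon from 1000 to 100 steps. For reference, a warmup schedule of this shape ramps the learning rate linearly from `warmup_min_lr` to `warmup_max_lr` over `warmup_num_steps` and then holds it; a small sketch of that formula (not necessarily the exact implementation behind this config):

def warmup_lr(step, warmup_min_lr=0.0, warmup_max_lr=1e-5, warmup_num_steps=100):
    # Linear warmup, then hold at the maximum learning rate.
    if step >= warmup_num_steps:
        return warmup_max_lr
    return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * step / warmup_num_steps

print(warmup_lr(50))  # 5e-06, halfway through the shortened 100-step warmup
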
@ -126,7 +126,7 @@ def get_parser():
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=5,
        default=10,
        help="Number of epochs to train.",
    )