remove seamless for next PR

Yuekai Zhang 2024-01-15 19:34:03 +08:00
parent ac53222054
commit e883bb60d4
20 changed files with 15 additions and 11738 deletions

View File

@@ -1,9 +0,0 @@
#export CUDA_VISIBLE_DEVICES="2,3"
#pip install -r seamlessm4t/requirements.txt
#pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
python3 seamlessm4t/decode.py --epoch 5 --exp-dir seamlessm4t/exp
python3 seamlessm4t/decode.py --epoch 5 --avg 2 --exp-dir seamlessm4t/exp

View File

@@ -1,8 +0,0 @@
#export CUDA_VISIBLE_DEVICES="1"
#pip install -r whisper/requirements.txt
#pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
#export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall/
python3 whisper/decode.py --exp-dir whisper/exp --max-duration 100

View File

@@ -1,8 +0,0 @@
#export CUDA_VISIBLE_DEVICES="2,3"
pip install -r seamlessm4t/requirements.txt
pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
torchrun --nproc-per-node 8 seamlessm4t/train.py --use-fp16 1 --max-duration 300 --base-lr 1e-5 --exp-dir seamlessm4t/exp_new_vocab --start-epoch 1

View File

@@ -1,9 +0,0 @@
pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
pip install -r whisper/requirements.txt
export PYTHONPATH=$PYTHONPATH:/workspace/icefall
#export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
#export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall
torchrun --nproc-per-node 8 whisper/train.py --use-fp16 1 --max-duration 20 --base-lr 1e-5 --exp-dir whisper/exp_medimum --start-epoch 1

View File

@@ -1 +0,0 @@
../tdnn_lstm_ctc/asr_datamodule.py

View File

@@ -1,415 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
# Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
#from conformer import Conformer
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
write_error_stats,
)
from seamless_communication.models.unity import (
UnitYModel,
load_unity_model,
load_unity_text_tokenizer,
)
from fairseq2.generation import (
SequenceGeneratorOptions,
SequenceToTextGenerator,
)
from seamless_communication.models.unity.model import UnitYX2TModel
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=-1,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="beam-search",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It maps the token ids to
tokens using the token symbol table directly.
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) attention-decoder. Extract n paths from the lattice,
the path with the highest score is the decoding result.
- (4) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="seamlessm4t/exp",
help="The experiment dir",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"subsampling_factor": 4,
"feature_dim": 80,
"nhead": 4,
"attention_dim": 512,
"num_encoder_layers": 12,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"search_beam": 20,
"output_beam": 7,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def decode_one_batch(
params: AttributeDict,
s2t_generator: SequenceToTextGenerator,
batch: dict,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if decoding method is 1best, the key is the string `no_rescore`.
If attention rescoring is used, the key is the string
`ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
value of `lm_scale` and `attention_scale`. An example key is
`ngram_lm_scale_0.7_attention_scale_0.5`
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
s2t_generator:
The sequence-to-text generator used for beam-search decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
dtype = torch.float16
device = torch.device("cuda", 3)
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
text_output = s2t_generator.generate_ex(feature, feature_len)
sentences = text_output.sentences
hyps = [sentence.bytes().decode("utf-8").split() for sentence in sentences]
key = "beam-search"
return {key: hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
s2t_generator: SequenceToTextGenerator,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
s2t_generator:
The sequence-to-text generator used for beam-search decoding.
Returns:
Return a dict whose key is the decoding setting (here always "beam-search").
Its value is a list of tuples. Each tuple contains the cut id, the
reference transcript, and the predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
s2t_generator=s2t_generator,
batch=batch,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 3)
logging.info(f"device: {device}")
dtype = torch.float16
model_name_or_card = "seamlessM4T_medium"
#model_name_or_card = "seamlessM4T_large"
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
del model.t2u_model
del model.text_encoder
del model.text_encoder_frontend
if params.epoch > 0:
if params.avg > 1:
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
text_max_len_a = 1
text_max_len_b = 200
target_lang = "cmn"
text_opts = SequenceGeneratorOptions(
beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
)
s2t_model = UnitYX2TModel(
encoder_frontend=model.speech_encoder_frontend,
encoder=model.speech_encoder,
decoder_frontend=model.text_decoder_frontend,
decoder=model.text_decoder,
final_proj=model.final_proj,
pad_idx=model.pad_idx,
)
s2t_generator = SequenceToTextGenerator(
s2t_model, text_tokenizer, target_lang, text_opts
)
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
s2t_generator=s2t_generator,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
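For context, with the `--epoch 5 --avg 2` invocation from the removed run script at the top of this diff, the checkpoint-averaging branch in `main()` above resolves to the following files. A minimal sketch (plain Python, illustrative only, assuming the default `--exp-dir seamlessm4t/exp`):

epoch, avg = 5, 2                      # values from the removed run script
start = epoch - avg                    # 3; the code asserts start >= 1
filename_start = f"seamlessm4t/exp/epoch-{start}.pt"   # excluded endpoint
filename_end = f"seamlessm4t/exp/epoch-{epoch}.pt"
# average_checkpoints_with_averaged_model then averages over the range
# epoch 3 (excluded) to epoch 5, matching the log message above.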

View File

@@ -1,432 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
# Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
#from conformer import Conformer
from tokenizer import CharTokenizer
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
write_error_stats,
)
from seamless_communication.models.unity import (
UnitYModel,
load_unity_model,
load_unity_text_tokenizer,
)
from fairseq2.generation import (
SequenceGeneratorOptions,
SequenceToTextGenerator,
)
from seamless_communication.models.unity.model import UnitYX2TModel
from fairseq2.nn.embedding import Embedding
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=-1,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--method",
type=str,
default="beam-search",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It maps the token ids to
tokens using the token symbol table directly.
- (1) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (3) attention-decoder. Extract n paths from the lattice,
the path with the highest score is the decoding result.
- (4) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="seamlessm4t/exp",
help="The experiment dir",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"subsampling_factor": 4,
"feature_dim": 80,
"nhead": 4,
"attention_dim": 512,
"num_encoder_layers": 12,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"search_beam": 20,
"output_beam": 7,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def decode_one_batch(
params: AttributeDict,
s2t_generator: SequenceToTextGenerator,
batch: dict,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if decoding method is 1best, the key is the string `no_rescore`.
If attention rescoring is used, the key is the string
`ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
value of `lm_scale` and `attention_scale`. An example key is
`ngram_lm_scale_0.7_attention_scale_0.5`
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`. `params.tokenizer` is the
CharTokenizer used to map predicted token ids back to text.
s2t_generator:
The sequence-to-text generator used for beam-search decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
dtype = torch.float16
device = torch.device("cuda", 3)
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
text_output = s2t_generator.generate_ex(feature, feature_len)
#sentences = text_output.sentences
#hyps = [sentence.bytes().decode("utf-8").split() for sentence in sentences]
token_ids = text_output.generator_output.results
hyps_ids = [sentence[0].seq.cpu().tolist() for sentence in token_ids]
hyps = [params.tokenizer.decode(hyps_id).split() for hyps_id in hyps_ids]
key = "beam-search"
return {key: hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
s2t_generator: SequenceToTextGenerator,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
s2t_generator:
The sequence-to-text generator used for beam-search decoding.
Returns:
Return a dict whose key is the decoding setting (here always "beam-search").
Its value is a list of tuples. Each tuple contains the cut id, the
reference transcript, and the predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
s2t_generator=s2t_generator,
batch=batch,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.tokenizer = CharTokenizer('./seamlessm4t/tokens.txt')
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 3)
logging.info(f"device: {device}")
dtype = torch.float16
model_name_or_card = "seamlessM4T_medium"
#model_name_or_card = "seamlessM4T_large"
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
del model.t2u_model
del model.text_encoder
del model.text_encoder_frontend
model.text_decoder_frontend.embed = nn.Embedding(num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024 ,padding_idx=0)
#model.text_decoder_frontend.embed = Embedding(num_embeddings=params.tokenizer.vocab_size, embedding_dim=1024 ,pad_idx=0, scaled=True)
model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size, bias=False)
#model.final_proj = nn.Linear(1024, params.tokenizer.vocab_size)
if params.epoch > 0:
if params.avg > 1:
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
model.half()
#for param in model.parameters():
# if param.dtype == torch.float16:
# pass
# else:
# param.data = param.data.to(torch.float16)
#print(param)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
text_max_len_a = 1
text_max_len_b = 200
target_lang = "cmn"
text_opts = SequenceGeneratorOptions(
beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
)
s2t_model = UnitYX2TModel(
encoder_frontend=model.speech_encoder_frontend,
encoder=model.speech_encoder,
decoder_frontend=model.text_decoder_frontend,
decoder=model.text_decoder,
final_proj=model.final_proj,
pad_idx=model.pad_idx,
)
s2t_generator = SequenceToTextGenerator(
s2t_model, text_tokenizer, target_lang, text_opts
)
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
s2t_generator=s2t_generator,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@@ -1 +0,0 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

View File

@@ -1,133 +0,0 @@
import torch
import torch.nn as nn
from fairseq2.nn.embedding import Embedding
from seamless_communication.models.inference import Translator
from seamless_communication.models.unity import (
UnitTokenizer,
UnitYModel,
load_unity_model,
load_unity_text_tokenizer,
load_unity_unit_tokenizer,
)
from fairseq2.generation import (
Seq2SeqGenerator,
SequenceGeneratorOptions,
SequenceGeneratorOutput,
SequenceToTextGenerator,
SequenceToTextOutput,
)
from seamless_communication.models.unity.model import UnitYModel, UnitYX2TModel
import torchaudio
import torchaudio.compliance.kaldi as ta_kaldi
audio_file="/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_en/wav/1089-134686-0001.wav"
src_lang="cmn"
audio_file="/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_zh/wav/long.wav"
src_lang="eng"
target_lang = "cmn"
audio_input = torchaudio.load(audio_file)[0]
feature = ta_kaldi.fbank(audio_input, num_mel_bins=80)
# feature shape is (T, F), convert it to (B, T, F), source_seq_lens tracks T
source_seqs = feature.unsqueeze(0)
source_seq_lens = torch.tensor([feature.shape[0]])
# Initialize a Translator object with a multitask model, vocoder on the GPU.
# translator = Translator("seamlessM4T_medium", vocoder_name_or_card="vocoder_36langs", device=torch.device("cuda:2"), dtype=torch.float16)
# transcribed_text, _, _ = translator.predict(audio_file, "asr", src_lang)
# print(transcribed_text)
model_name_or_card = "seamlessM4T_medium"
device = torch.device("cuda:3")
# cast source_seq_lens, source_seqs to device, dtype to torch.float16
source_seq_lens = source_seq_lens.to(device=device, dtype=torch.float16)
source_seqs = source_seqs.to(device=device, dtype=torch.float16)
dtype = torch.float16
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
model.eval()
model.text_decoder_frontend.embed = Embedding(num_embeddings=6257, embedding_dim=1024 ,pad_idx=0, scaled=True)
model.final_proj = nn.Linear(1024, 6257)
model.half()
print(model.text_decoder_frontend.embed, model.text_encoder_frontend.embed.weight.dtype, type(model.text_encoder_frontend.embed), type(model.text_encoder_frontend.embed.weight))
print(model.final_proj, model.final_proj.weight.dtype, type(model.final_proj), type(model.final_proj.weight))
#input()
exit(0)
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
#print(text_tokenizer.model.eos_idx, text_tokenizer.model.pad_idx)
#text_tokenizer_encoder = text_tokenizer.create_encoder(lang=target_lang, mode="target")
#text_tokenizer_decoder = text_tokenizer.create_decoder()
# print attritbut of text_tokenizer_encoder
#print(text_tokenizer.vocab_info)
#print(text_tokenizer_encoder("其中广州深圳甚至出现了多个日光盘"))
#print(text_tokenizer_decoder(torch.tensor([3,256200,137139,252603,250476,250590,1,84778,148897,249568,249352,249947,249050,250520,254508])))
# store all vocab in a file
# with open("vocab.txt", "w") as f:
# for i in range(256206):
# f.write(f"{i}: " + text_tokenizer_decoder(torch.tensor([i]))[0].bytes().decode("utf-8")+ "\n")
# f.close()
# exit(0)
# def decode(
# self,
# seqs: Tensor,
# seq_lens: Optional[Tensor],
# encoder_output: Tensor,
# encoder_padding_mask: Optional[Tensor],
# state_bag: Optional[IncrementalStateBag] = None,
# ) -> Tuple[Tensor, Optional[Tensor]]:
# seqs, padding_mask = self.text_decoder_frontend(seqs, seq_lens, state_bag)
# return self.text_decoder( # type: ignore[no-any-return]
# seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
# )
# def decoding(model, feature):
# seqs, padding_mask = model.speech_encoder_frontend(seqs, seq_lens)
# speech_encoder(seqs, padding_mask)
# decoder_output, decoder_padding_mask = self.decode(
# batch.target_seqs,
# batch.target_seq_lens,
# encoder_output,
# encoder_padding_mask,
# )
# text_logits = model.final_project(decoder_output, decoder_padding_mask)
text_max_len_a = 1
text_max_len_b = 200
text_opts = SequenceGeneratorOptions(
beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
)
s2t_model = UnitYX2TModel(
encoder_frontend=model.speech_encoder_frontend,
encoder=model.speech_encoder,
decoder_frontend=model.text_decoder_frontend,
decoder=model.text_decoder,
final_proj=model.final_proj,
pad_idx=model.pad_idx,
)
s2t_generator = SequenceToTextGenerator(
s2t_model, text_tokenizer, target_lang, text_opts
)
text_output = s2t_generator.generate_ex(source_seqs, source_seq_lens)
print(text_output.generator_output.results[0][0].seq.cpu().tolist())
# sentence = text_output.sentences[0]
# print(sentence, type(sentence))
# sentence = sentence.bytes().decode("utf-8")
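The token ids collected at the end of this scratch script are only meaningful relative to a tokenizer. With the char vocabulary used for fine-tuning elsewhere in this PR, they could be mapped back to text roughly as follows (a sketch, not from the original file; `ids` is a hypothetical output list and ./seamlessm4t/tokens.txt is the symbol table consumed by seamlessm4t/tokenizer.py, removed further down):

from tokenizer import CharTokenizer   # seamlessm4t/tokenizer.py, removed further down in this diff

tok = CharTokenizer("./seamlessm4t/tokens.txt")
ids = [250, 1154, 873, 1]     # hypothetical ids; 1 is the EOS id assumed in the removed decode code
print(tok.decode(ids[:-1]))   # drop the trailing EOS, then map each remaining id to its character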

File diff suppressed because it is too large.

View File

@@ -1,694 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union, cast
import torch
from torch import Tensor
from torch.nn.functional import log_softmax
from fairseq2.data import Collater, SequenceData, VocabularyInfo
from fairseq2.generation.beam_search import BeamSearch, StandardBeamSearch
from fairseq2.generation.logits_processor import LogitsProcessor
from fairseq2.models.encoder_decoder import Seq2SeqDecoder
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.typing import Device
@dataclass
class SequenceGeneratorOptions:
"""Holds the options to pass to a sequence generator."""
beam_size: int = 5
"""The beam size."""
min_seq_len: int = 1
"""The minimum length of generated sequences (including prefix sequence)."""
soft_max_seq_len: Optional[Tuple[int, int]] = (1, 200)
"""The terms ``a`` and ``b`` of ``ax + b`` where ``x`` is the source
sequence length. The generated sequences (including prefix sequence) will
have the maximum length of ``min(hard_max_seq_len, ax + b)``. See also
``hard_max_seq_len``."""
hard_max_seq_len: int = 1024
"""The hard limit on maximum length of generated sequences."""
len_penalty: float = 1.0
"""The length penalty, where values less than 1.0 favor shorter, values
greater than 1.0 favor longer sequences."""
unk_penalty: float = 0.0
"""The unknown symbol penalty, where values less than 0 produce more UNKs,
values greater than 0 produce fewer UNKs."""
normalize_scores: bool = True
"""If ``True``, normalizes scores by the length of generated sequences."""
search: Optional[BeamSearch] = None
"""The beam search algorithm to use."""
logits_processor: Optional[LogitsProcessor] = None
"""Logits processor called before applying beam search step."""
class Seq2SeqGenerator:
"""Represents a sequence-to-sequence generator."""
decoder: Seq2SeqDecoder
opts: SequenceGeneratorOptions
beam_size: int
eos_idx: int
pad_idx: Optional[int]
unk_idx: Optional[int]
prefix_seq: Union[int, Tensor]
prefix_seq_len: int
search: BeamSearch
logits_processor: Optional[LogitsProcessor]
collater: Collater
def __init__(
self,
decoder: Seq2SeqDecoder,
vocab_info: VocabularyInfo,
prefix_seq: Optional[Union[int, Tensor]],
opts: Optional[SequenceGeneratorOptions] = None,
) -> None:
"""
:param decoder:
The decoder to use.
:param vocab_info:
The vocabulary information to use.
:param prefix_seq:
The prefix sequence, typically one or more control symbols
indicating the beginning of a sequence. *Shape:* :math:`()` or
:math:`(S)`, where :math:`S` is the sequence length. If ``None``,
the EOS symbol will be used as prefix.
:param opts:
The generation options.
"""
self.decoder = decoder
self.opts = opts or SequenceGeneratorOptions()
# Set beam size.
if vocab_info.pad_idx is None:
self.beam_size = min(self.opts.beam_size, vocab_info.size)
else:
# -1 since we never select PAD.
self.beam_size = min(self.opts.beam_size, vocab_info.size - 1)
if vocab_info.eos_idx is None:
raise ValueError(
"`vocab_info` must have `eos_idx` set for sequence generation."
)
# Set vocab info.
self.eos_idx = 1
#self.eos_idx = vocab_info.eos_idx
self.unk_idx = 2
#self.unk_idx = vocab_info.unk_idx
self.pad_idx = 0
#self.pad_idx = vocab_info.pad_idx
# Set prefix sequence.
if 1:
#if prefix_seq is None:
# If `None`, we follow fairseq's convention, and use EOS as the
# prefix.
self.prefix_seq, self.prefix_seq_len = self.eos_idx, 1
else:
self.prefix_seq = prefix_seq
if isinstance(prefix_seq, Tensor):
num_dim = prefix_seq.dim()
if num_dim >= 2:
raise ValueError(
f"`prefix_seq` must be a scalar or a 1-dimensional tensor, but is {num_dim}-dimensional instead."
)
self.prefix_seq_len = 1 if num_dim == 0 else prefix_seq.size(0)
else:
self.prefix_seq_len = 1
# Set beam search.
self.search = self.opts.search or StandardBeamSearch()
self.logits_processor = self.opts.logits_processor
if vocab_info.pad_idx is None:
self.collater = Collater()
else:
self.collater = Collater(self.pad_idx, pad_to_multiple=2)
@torch.inference_mode()
def __call__(
self,
encoder_output: Tensor,
encoder_padding_mask: Optional[Tensor],
source_seq_len: Optional[int] = None,
) -> "SequenceGeneratorOutput":
opts = self.opts
num_searches = encoder_output.size(0)
beam_size = opts.beam_size
max_seq_len = self._determine_max_seq_len(source_seq_len)
device = encoder_output.device
encoder_output, encoder_padding_mask = self._fan_out_encoder_output(
encoder_output, encoder_padding_mask
)
# Each element contains the id of the search corresponding to a single
# source sequence and its hypotheses.
active_searches: List[Tuple[int, List[Hypothesis]]] = [
(search_idx, []) for search_idx in range(num_searches)
]
# Once a source sequence has `beam_size` hypotheses, its search is moved
# from `active_searches` to `finished_searches`.
finished_searches: List[List[Hypothesis]] = [[] for i in range(num_searches)]
num_remaining_searches = num_searches
# Initialize buffers.
# (N x B, S)
seqs = torch.zeros(
(num_searches * beam_size, max_seq_len), device=device, dtype=torch.int64
)
# (N x B, S)
scores = torch.zeros(
(num_searches * beam_size, max_seq_len), device=device, dtype=torch.float32
)
# A list that indicates beams that should be ignored in the next step.
ignored_beam_mask = torch.full(
(num_searches, beam_size), False, device=device, dtype=torch.bool
)
# An offset array for converting between batch-wide and search-local
# beam indices.
# (B)
search_offsets = torch.arange(num_searches, device=device) * beam_size
# (B) -> (B, 1)
search_offsets.unsqueeze_(-1)
cand_offsets = torch.arange(2 * beam_size, device=device)
state_bag = IncrementalStateBag()
# At this point, the state is fully initialized, kick off the search.
self._bootstrap_seqs_and_scores(
seqs, scores, encoder_output, encoder_padding_mask, state_bag
)
start_step = self.prefix_seq_len - 1
# Holds the indices of beams (a beam can occur more than once) that we
# should continue with in the next step.
beam_indices: Optional[Tensor] = None
# Holds the indices of searches that we should continue with in the next
# step. If not `None`, it means we finalized one or more searches in the
# last step.
search_indices: Optional[Tensor] = None
for step_nr in range(start_step, max_seq_len - 1):
if beam_indices is not None:
# If not `None`, it means in the last step we finalized one or
# more searches. We should ensure that we adjust `beam_indices`
# before reordering `decoder`'s incremental state.
if search_indices is not None:
num_searches = search_indices.numel()
# (N)
delta = search_indices - torch.arange(num_searches, device=device)
# (N) -> (N, 1)
delta.unsqueeze_(-1)
# Adjust indices to take into account removed searches.
beam_indices.view(num_searches, beam_size).add_(delta * beam_size)
state_bag.reorder(beam_indices)
decoder_output, decoder_padding_mask = self.decoder.decode(
seqs[:, step_nr : step_nr + 1],
None, # We never generate PAD.
encoder_output,
encoder_padding_mask,
state_bag,
)
state_bag.increment_step()
model_output = self.decoder.project(decoder_output, decoder_padding_mask)
# lprobs: (1, V)
# model_output: (N, 1, V)
lprobs = log_softmax(model_output.logits, dim=-1, dtype=torch.float32)
# Do not allow EOS before reaching the minimum sequence length.
if step_nr < self.opts.min_seq_len:
lprobs[:, :, self.eos_idx] = -torch.inf
# fmt: off
# If we have reached the maximum length, force the last step to be
# EOS.
if step_nr == max_seq_len - 2:
lprobs[:, :, : self.eos_idx] = -torch.inf
lprobs[:, :, self.eos_idx + 1 :] = -torch.inf
# fmt: on
# Never allow PAD.
if self.pad_idx is not None:
lprobs[:, :, self.pad_idx] = -torch.inf
# Apply UNK penalty.
if self.unk_idx is not None:
lprobs[:, :, self.unk_idx] -= self.opts.unk_penalty
# update scores in place using logits_processor
if self.logits_processor is not None:
self.logits_processor(
seqs.view(num_searches, beam_size, -1)[:, :, : step_nr + 1],
lprobs.view(num_searches, beam_size, -1),
)
# Determine candidates for the next step.
# (N, 2 x B)
cand_scores, cand_indices, cand_beam_indices = self.search.step(
step_nr,
step_nr == start_step,
lprobs.view(num_searches, beam_size, -1),
scores.view(num_searches, beam_size, -1)[:, :, : step_nr + 1],
)
# Convert search-local beam indices to batch-wide beam indices.
# (N, 2 x B) + (N) -> (N, 2 x B)
global_cand_beam_indices = cand_beam_indices + search_offsets
# Finalize beams that reached the minimum length and that end with
# an EOS.
# (N, 2 x B)
eos_mask = (cand_indices == self.eos_idx) & (cand_scores != -math.inf)
# Do not attempt to finalize beams that should be ignored.
eos_mask[:, :beam_size][ignored_beam_mask] = False
# Only consider EOS when it's among the top `beam_size` indices. Now
# we know what beam(s) to finalize.
# (N, B)
eos_beam_indices = torch.masked_select(
global_cand_beam_indices[:, :beam_size], mask=eos_mask[:, :beam_size]
)
if eos_beam_indices.numel() > 0:
# Select the scores of the finalized beams.
# (N, B)
eos_scores = torch.masked_select(
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
)
newly_finished_searches = self._finalize_hypothesis(
step_nr,
eos_beam_indices,
eos_scores,
seqs,
scores,
active_searches,
finished_searches,
)
num_remaining_searches -= len(newly_finished_searches)
if num_remaining_searches == 0:
break
else:
newly_finished_searches = None
# Remove finished searches (ones for which `beam_size` finalized
# beams have been generated) from the batch.
if newly_finished_searches:
new_num_searches = num_searches - len(newly_finished_searches)
# Construct `search_indices` which holds indices of searches
# to keep for the next step.
search_mask = torch.full((num_searches,), True, device=device)
search_mask[newly_finished_searches] = False
search_indices = torch.arange(num_searches, device=device)
search_indices = search_indices.masked_select(search_mask)
# fmt: off
# Filter out removed batches from state variables.
# (N, B) -> (N - F, B)
ignored_beam_mask = ignored_beam_mask[search_indices]
# (N, 2 x B) -> (N - F, 2 x B)
cand_scores = cand_scores [search_indices]
cand_indices = cand_indices [search_indices]
cand_beam_indices = cand_beam_indices[search_indices]
# (N) -> (N - F)
search_offsets.resize_(new_num_searches, 1)
# (N - F, 2 x B) + (N - F) -> (N - F, 2 x B)
global_cand_beam_indices = cand_beam_indices + search_offsets
# (N, 2 x B) -> (N - F, 2 x B)
eos_mask = eos_mask[search_indices]
# (N x B, S) -> (N, B, S)
seqs = seqs .view(num_searches, -1)
scores = scores.view(num_searches, -1)
# (N, B, S + 1) -> ((N - F) x B, S)
seqs = seqs [search_indices].view(new_num_searches * beam_size, -1)
scores = scores[search_indices].view(new_num_searches * beam_size, -1)
# (N x B, S_enc, M) -> (N, B, S_enc, M)
encoder_output = encoder_output.unflatten(0, (num_searches, -1))
# (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
encoder_output = encoder_output[search_indices].flatten(0, 1)
if encoder_padding_mask is not None:
# (N x B, S_enc, M) -> (N, B, S_enc, M)
padding_mask = encoder_padding_mask.unflatten(0, (num_searches, -1))
# (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
encoder_padding_mask = padding_mask[search_indices].flatten(0, 1)
# fmt: on
num_searches = new_num_searches
else:
search_indices = None
eos_mask[:, :beam_size][ignored_beam_mask] = True
# Set `beam_weights` so that values greater than or equal to 2 x
# `beam_size` indicate finished beams (i.e. end with EOS) and values
# less than 2 x `beam_size` indicate active beams.
# (N, 2 x B)
beam_weights = cand_offsets + (eos_mask * (2 * beam_size))
# Get the top `beam_size` active beams, which are the beams with the
# smallest weights in `active_beam_weights`.
# (N, B)
active_beam_weights, active_beams = torch.topk(
beam_weights, k=beam_size, dim=1, largest=False
)
# Update to ignore finalized beams in the next step.
# (N, B)
ignored_beam_mask = active_beam_weights >= 2 * beam_size
# We should always have at least one active beam in each search.
assert (~ignored_beam_mask).any(dim=1).all()
# Denotes which beams are continued for each new hypothesis (a beam
# can be selected more than once).
# (N, B)
beam_indices = torch.gather(
global_cand_beam_indices, dim=1, index=active_beams
)
# (N, B) -> (N x B)
beam_indices = beam_indices.view(-1)
# fmt: off
# Reorder beams in the `seq` and `score` buffers. The same beam can
# be selected more than once.
if step_nr > start_step:
seqs [:, : step_nr + 1] = torch.index_select(
seqs [:, : step_nr + 1], dim=0, index=beam_indices
)
scores[:, : step_nr + 1] = torch.index_select(
scores[:, : step_nr + 1], dim=0, index=beam_indices
)
# (N x B, S) -> (N, B, S)
seqs_view = seqs .view(num_searches, beam_size, -1)
scores_view = scores.view(num_searches, beam_size, -1)
seqs_view [:, :, step_nr + 1] = torch.gather(cand_indices, dim=1, index=active_beams)
scores_view[:, :, step_nr + 1] = torch.gather(cand_scores, dim=1, index=active_beams)
# fmt: on
# Ensure that hypotheses are sorted by their scores before returning.
for batch in finished_searches:
batch.sort(key=lambda b: b.score, reverse=True) # type: ignore[arg-type, return-value]
return SequenceGeneratorOutput(
results=finished_searches, device=device, collater=self.collater
)
def _determine_max_seq_len(self, source_seq_len: Optional[int]) -> int:
opts = self.opts
if source_seq_len is None or opts.soft_max_seq_len is None:
max_seq_len = opts.hard_max_seq_len
else:
at, bt = opts.soft_max_seq_len
max_seq_len = min(opts.hard_max_seq_len, int(at * source_seq_len + bt))
if opts.min_seq_len > max_seq_len:
raise ValueError(
f"The effective maximum sequence length must be greater than or equal to `min_seq_len` ({opts.min_seq_len}), but is {max_seq_len} instead. Adjust your soft and hard maximum sequence length limits."
)
if self.prefix_seq_len >= max_seq_len:
raise ValueError(
f"The effective maximum sequence length must be greater than `prefix_seq_len` ({self.prefix_seq_len}), but is {max_seq_len} instead."
)
return max_seq_len
def _fan_out_encoder_output(
self, encoder_output: Tensor, encoder_padding_mask: Optional[Tensor]
) -> Tuple[Tensor, Optional[Tensor]]:
num_searches = encoder_output.size(0) # i.e. batch size
# Fan out `encoder_output` to `num_searches` x `beam_size`.
# (N)
fan_out_indices = torch.arange(num_searches, device=encoder_output.device)
# (N) -> (N x B)
fan_out_indices = fan_out_indices.repeat_interleave(self.beam_size)
# (N, S_enc, M) -> (N x B, S_enc, M)
encoder_output = encoder_output.index_select(dim=0, index=fan_out_indices)
# (N, S_enc, M) -> (N x B, S_enc, M)
if encoder_padding_mask is not None:
encoder_padding_mask = encoder_padding_mask.index_select(
dim=0, index=fan_out_indices
)
return encoder_output, encoder_padding_mask
def _bootstrap_seqs_and_scores(
self,
seqs: Tensor,
scores: Tensor,
encoder_output: Tensor,
encoder_padding_mask: Optional[Tensor],
state_bag: IncrementalStateBag,
) -> None:
assert self.prefix_seq_len > 0
seqs[:, : self.prefix_seq_len] = self.prefix_seq
if self.prefix_seq_len == 1:
return
assert isinstance(self.prefix_seq, Tensor)
# We have to bootstrap the model with the already fanned-out encoder
# output to correctly initialize its incremental state. This causes some
# redundancy as we have to expand `decoder_input` to match the shape of
# `encoder_output`.
# (S_pfx) -> (N x B, S_pfx - 1)
decoder_input = self.prefix_seq[:-1].expand(encoder_output.size(0), -1)
# Bootstrap the model state with prefix sequence.
decoder_output, decoder_padding_mask = self.decoder.decode(
decoder_input,
None,
encoder_output,
encoder_padding_mask,
state_bag,
)
state_bag.increment_step(self.prefix_seq_len - 1)
model_output = self.decoder.project(decoder_output, decoder_padding_mask)
# lprobs: (S_pfx - 1, V)
# model_output: (N, S_pfx - 1, V) -> (S_pfx - 1, V)
lprobs = log_softmax(model_output.logits[0], dim=-1, dtype=torch.float32)
# Fetch scores of next steps.
# (S_pfx - 1, 1)
prefix_scores = torch.take_along_dim(
lprobs, indices=self.prefix_seq[1:].unsqueeze(1), dim=-1
)
# (S_pfx - 1, 1) -> (S_pfx - 1)
prefix_scores.squeeze_(1).cumsum_(dim=0)
# First step (e.g. EOS)'s score is always 0.
scores[:, 1 : self.prefix_seq_len] = prefix_scores
def _finalize_hypothesis(
self,
step_nr: int,
eos_beam_indices: Tensor,
eos_scores: Tensor,
seqs: Tensor,
scores: Tensor,
active_searches: List[Tuple[int, List["Hypothesis"]]],
finished_searches: List[List["Hypothesis"]],
) -> List[int]:
# fmt: off
finalized_seqs = seqs .index_select(dim=0, index=eos_beam_indices)
finalized_scores = scores.index_select(dim=0, index=eos_beam_indices)
finalized_seqs = finalized_seqs [:, : step_nr + 2]
finalized_scores = finalized_scores[:, : step_nr + 2]
# Finalize beams.
finalized_seqs [:, -1] = self.eos_idx
finalized_scores[:, -1] = eos_scores
# fmt: on
# Convert from cumulative to per-step scores.
finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]
# Skip first EOS since it is always 0 and skews normalization.
if self.opts.normalize_scores:
eos_scores /= (step_nr + 1) ** self.opts.len_penalty
# Holds the ids of finished searches.
newly_finished: List[int] = []
active_search_indices = (eos_beam_indices // self.beam_size).tolist()
for beam_idx, search_idx in enumerate(active_search_indices):
search_id, hypotheses = active_searches[search_idx]
# We might have more than one beam finalized in one step that would
# potentially exceed `beam_size` hypotheses.
if len(hypotheses) == self.beam_size:
continue
hypotheses.append(
Hypothesis(
seq=finalized_seqs[beam_idx],
score=eos_scores[beam_idx],
step_scores=finalized_scores[beam_idx],
)
)
if len(hypotheses) == self.beam_size:
# We have `beam_size` hypotheses for this particular search, so
# we finish it now.
newly_finished.append(search_idx)
finished_searches[search_id] = hypotheses
newly_finished.sort()
# Remove finished searches from the active list.
for idx in reversed(newly_finished):
del active_searches[idx]
return newly_finished
@dataclass
class SequenceGeneratorOutput:
"""Holds the output of a sequence generator."""
results: List[List["Hypothesis"]]
"""The list of hypothesis generated per search, ordered by score."""
device: Device
"""The device on which generated sequences reside."""
collater: Optional[Collater] = None
"""The collater to use in :meth:`collate`."""
def collate(
self, hypo_idx: int = 0, skip_batch: bool = False
) -> Tuple[Tensor, Optional[Tensor]]:
"""Collate the generated sequences at index ``hypo_idx`` in each search
result into a single tensor.
:param hypo_idx:
The index of hypothesis to extract from each search result.
:param skip_batch:
If ``True``, if a search result has no hypothesis at index `hypo_idx`,
it will be skipped instead of raising an error.
:returns:
- The collated sequences. *Shape:* :math:`(N,S)`, where :math:`N` is
the number of search results and :math:`S` is the sequence length.
- An array where each element represents the length of the sequence at
the same index in the first returned value. *Shape:* :math:`(N)`,
where :math:`N` is the number of search results.
"""
if self.collater is None:
raise RuntimeError("The output has no associated `Collater` instance.")
if not self.results and not skip_batch:
raise ValueError("The output must contain at least one search result.")
seqs = []
for search_idx, result in enumerate(self.results):
if hypo_idx >= len(result):
if not skip_batch:
raise ValueError(
f"Each search result must have at least {hypo_idx + 1} hypotheses, but search {search_idx} has only {len(result)}."
)
continue
seqs.append(result[hypo_idx].seq)
if not seqs:
# Return a zero-dimensional (not scalar!) tensor.
return torch.empty((0,), device=self.device, dtype=torch.int64), None
output = cast(SequenceData, self.collater(seqs))
return output["seqs"], output["seq_lens"] if output["is_ragged"] else None
@dataclass
class Hypothesis:
"""Represents a hypothesis produced by a sequence generator."""
seq: Tensor
"""The generated sequence."""
score: Tensor
"""The score of the hypothesis."""
step_scores: Tensor
"""The score of each individual sequence step."""

View File

@@ -1,6 +0,0 @@
#k2
kaldialign
lhotse
sentencepiece
tensorboard
fairseq2

View File

@@ -1,43 +0,0 @@
#import sentencepiece as spm
class CharTokenizer(object):
def __init__(self, tokenizer_file):
self.id2symbol = {}
self.symbol2id = {}
with open(tokenizer_file, 'r') as f:
for line in f:
line = line.strip()
if line:
symbol, id = line.split()
id = int(id)
self.id2symbol[id] = symbol
self.symbol2id[symbol] = id
self.vocab_size = len(self.id2symbol)
def encode(self, text):
# if symbol not in self.symbol2id, using <unk>'s id
return [self.symbol2id.get(symbol, 2) for symbol in text]
def decode(self, ids):
return ''.join([self.id2symbol[id] for id in ids])
if __name__ == '__main__':
# config_file = './config.yaml'
# config = read_yaml(config_file)
# converter = TokenIDConverter(config['token_list'])
# ids = converter.tokens2ids(['<s>', '你', '好', '吗', '</s>', 'microsoft', 'world'])
# print(ids)
# print(converter.ids2tokens(ids))
tokenizer = CharTokenizer('./tokens.txt')
ids = tokenizer.encode('今天 天气不错')
print(ids)
print(tokenizer.decode(ids+[1]))
# sp = spm.SentencePieceProcessor()
# sp.Load('../../../librispeech/ASR/k2fsa-zipformer-chinese-english-mixed/data/lang_char_bpe/bpe.model')
# texts = ['MICROSOFT WORLD']
# y = sp.encode(texts, out_type=int)
# x = sp.decode(y)
# print(y, x)
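The tokenizer above expects `tokens.txt` to be a plain-text symbol table with one `symbol id` pair per line. The actual file is not part of this diff; based on the ids hard-coded in the removed decode code (pad 0, eos 1, unk 2), a file of roughly the following shape is assumed (illustrative entries only):

<pad> 0
<sos/eos> 1
<unk> 2
今 3
天 4
气 5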

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@@ -30,7 +30,7 @@ from lhotse.dataset import (
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
@@ -176,13 +176,13 @@ class AishellAsrDataModule:
group.add_argument(
"--enable-musan",
type=str2bool,
default=False,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, rank = None, world_size = None
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
) -> DataLoader:
"""
Args:
@@ -192,13 +192,13 @@ class AishellAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@@ -276,12 +276,10 @@ class AishellAsrDataModule:
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
@@ -302,7 +300,7 @@ class AishellAsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet, rank = None, world_size = None) -> DataLoader:
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@@ -327,8 +325,6 @@ class AishellAsrDataModule:
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
rank=rank,
world_size=world_size,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(

View File

@@ -473,10 +473,11 @@
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
test_sets = ["test"]
test_dls = [test_dl]
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
#test_sets = ["test"]
#test_dls = [test_dl]
test_sets = ["valid"]
test_dls = [valid_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,

View File

@@ -27,7 +27,7 @@
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 1e-5,
"warmup_num_steps": 1000
"warmup_num_steps": 100
}
},
"gradient_accumulation_steps": 1,

View File

@@ -126,7 +126,7 @@ def get_parser():
parser.add_argument(
"--num-epochs",
type=int,
default=5,
default=10,
help="Number of epochs to train.",
)