mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 07:04:18 +00:00
118 lines
3.7 KiB
Python
118 lines
3.7 KiB
Python
import torch
|
|
from seamless_communication.models.inference import Translator
|
|
from seamless_communication.models.unity import (
|
|
UnitTokenizer,
|
|
UnitYModel,
|
|
load_unity_model,
|
|
load_unity_text_tokenizer,
|
|
load_unity_unit_tokenizer,
|
|
)
|
|
from fairseq2.generation import (
|
|
Seq2SeqGenerator,
|
|
SequenceGeneratorOptions,
|
|
SequenceGeneratorOutput,
|
|
SequenceToTextGenerator,
|
|
SequenceToTextOutput,
|
|
)
|
|
from seamless_communication.models.unity.model import UnitYModel, UnitYX2TModel
|
|
|
|
import torchaudio
|
|
import torchaudio.compliance.kaldi as ta_kaldi
|
|
# ---------------------------------------------------------------------------
# Input selection. Only one clip is active; the alternative is kept as a
# comment instead of being assigned and then silently overwritten.
# ---------------------------------------------------------------------------
# audio_file = "/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_en/wav/1089-134686-0001.wav"
# src_lang = "cmn"  # NOTE(review): clip is from mini_en; "eng" looks intended — confirm
audio_file = "/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_zh/wav/long.wav"
src_lang = "eng"  # NOTE(review): clip is from mini_zh; "cmn" looks intended — confirm
target_lang = "cmn"

# ---------------------------------------------------------------------------
# Feature extraction: 80-bin Kaldi-compatible log-mel filterbanks.
# NOTE(review): the waveform scale and per-utterance standardization below
# are meant to mirror the fbank converter configured by seamless_communication's
# Translator (waveform_scale=2**15, standardize=True), which this script is
# compared against — confirm against the installed version.
# ---------------------------------------------------------------------------
SAMPLE_RATE = 16000
NUM_MEL_BINS = 80
WAVEFORM_SCALE = 2**15

audio_input, input_sample_rate = torchaudio.load(audio_file)
if input_sample_rate != SAMPLE_RATE:
    # fbank interprets samples at `sample_frequency`; resample instead of
    # computing features on the wrong time/frequency scale.
    audio_input = torchaudio.functional.resample(
        audio_input, input_sample_rate, SAMPLE_RATE
    )

# torchaudio.load yields samples in [-1, 1]; rescale to int16 range before
# computing the filterbank (see NOTE above).
feature = ta_kaldi.fbank(
    audio_input * WAVEFORM_SCALE,
    num_mel_bins=NUM_MEL_BINS,
    sample_frequency=SAMPLE_RATE,
)
# Per-utterance normalization: zero mean / unit std for each mel bin over time.
feature_std, feature_mean = torch.std_mean(feature, dim=0)
feature = (feature - feature_mean) / feature_std

# feature shape is (T, F); add a batch axis -> (1, T, F).
# source_seq_lens holds T as an integer (int64) tensor.
source_seqs = feature.unsqueeze(0)
source_seq_lens = torch.tensor([feature.shape[0]])
import sys

# Reference path (disabled): the high-level Translator API, useful for
# cross-checking the manual pipeline in this script.
# translator = Translator("seamlessM4T_medium", vocoder_name_or_card="vocoder_36langs", device=torch.device("cuda:2"), dtype=torch.float16)
# transcribed_text, _, _ = translator.predict(audio_file, "asr", src_lang)
# print(transcribed_text)

# ---------------------------------------------------------------------------
# Model / device configuration. `dtype` is defined before anything uses it so
# the input features and the model weights are guaranteed to agree.
# ---------------------------------------------------------------------------
model_name_or_card = "seamlessM4T_medium"
device = torch.device("cuda:3")
dtype = torch.float16

# Move the batch to the device. Only the features take the model's floating
# dtype; the sequence lengths must stay integral (torch.int64). float16 can
# represent integers exactly only up to 2048, so casting a longer clip's frame
# count to float16 would silently round it and corrupt the padding mask.
source_seqs = source_seqs.to(device=device, dtype=dtype)
source_seq_lens = source_seq_lens.to(device=device)

model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
model.eval()

text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
print(text_tokenizer.model.eos_idx, text_tokenizer.model.pad_idx)
text_tokenizer_encoder = text_tokenizer.create_encoder(lang=target_lang, mode="target")
text_tokenizer_decoder = text_tokenizer.create_decoder()

# Tokenizer sanity checks: encode a literal string and decode two token ids.
print(text_tokenizer_encoder("<eos>"))
print(text_tokenizer_decoder(torch.tensor([3, 45])))

# Debug switch: when True, stop after the tokenizer checks above and skip the
# generation pipeline below. Previously this was an unconditional `exit(0)`
# (the `site` builtin), which made everything after it unreachable without
# saying so. Set to False to run speech-to-text generation.
STOP_AFTER_TOKENIZER_CHECK = True
if STOP_AFTER_TOKENIZER_CHECK:
    sys.exit(0)
# Sketch (not executed) of a manual encode/decode path, kept for reference;
# unverified against the installed fairseq2 / seamless_communication version.
#
# def decode(model, seqs, seq_lens, encoder_output, encoder_padding_mask, state_bag=None):
#     seqs, padding_mask = model.text_decoder_frontend(seqs, seq_lens, state_bag)
#     return model.text_decoder(
#         seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
#     )
#
# seqs, padding_mask = model.speech_encoder_frontend(source_seqs, source_seq_lens)
# encoder_output, encoder_padding_mask = model.speech_encoder(seqs, padding_mask)
# decoder_output, decoder_padding_mask = decode(
#     model, target_seqs, target_seq_lens, encoder_output, encoder_padding_mask
# )
# text_logits = model.final_proj(decoder_output, decoder_padding_mask)

# ---------------------------------------------------------------------------
# Beam-search speech-to-text generation.
# soft_max_seq_len is passed as an (a, b) pair; presumably the soft length
# limit is a * source_len + b — confirm against SequenceGeneratorOptions.
# ---------------------------------------------------------------------------
text_max_len_a = 1
text_max_len_b = 200
text_opts = SequenceGeneratorOptions(
    beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
)

# Speech-to-text sub-model assembled from the loaded UnitY model's speech
# encoder and text decoder components.
s2t_model = UnitYX2TModel(
    encoder_frontend=model.speech_encoder_frontend,
    encoder=model.speech_encoder,
    decoder_frontend=model.text_decoder_frontend,
    decoder=model.text_decoder,
    final_proj=model.final_proj,
    pad_idx=model.pad_idx,
)
s2t_generator = SequenceToTextGenerator(
    s2t_model, text_tokenizer, target_lang, text_opts
)

text_output = s2t_generator.generate_ex(source_seqs, source_seq_lens)
sentence = text_output.sentences[0]
print(sentence, type(sentence))

# The sentence may be a plain str or a byte-string wrapper exposing .bytes()
# (the type is printed above because it is not fixed). Normalize to str
# without assuming .bytes() exists, and actually emit the decoded text —
# previously it was computed and then discarded.
if not isinstance(sentence, str):
    sentence = sentence.bytes().decode("utf-8")
print(sentence)