mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
update finetuning codes
This commit is contained in:
parent
f99f4d7c92
commit
363c3f1f82
4
egs/aishell/ASR/run.sh
Normal file
4
egs/aishell/ASR/run.sh
Normal file
@ -0,0 +1,4 @@
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="2,3"
|
||||
export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall
|
||||
torchrun --nproc-per-node 2 seamlessm4t/train2.py --use-fp16 1 --max-duration 20
|
1
egs/aishell/ASR/seamlessm4t/asr_datamodule.py
Symbolic link
1
egs/aishell/ASR/seamlessm4t/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../tdnn_lstm_ctc/asr_datamodule.py
|
1
egs/aishell/ASR/seamlessm4t/label_smoothing.py
Symbolic link
1
egs/aishell/ASR/seamlessm4t/label_smoothing.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/conformer_ctc/label_smoothing.py
|
117
egs/aishell/ASR/seamlessm4t/model.py
Normal file
117
egs/aishell/ASR/seamlessm4t/model.py
Normal file
@ -0,0 +1,117 @@
|
||||
import torch
|
||||
from seamless_communication.models.inference import Translator
|
||||
from seamless_communication.models.unity import (
|
||||
UnitTokenizer,
|
||||
UnitYModel,
|
||||
load_unity_model,
|
||||
load_unity_text_tokenizer,
|
||||
load_unity_unit_tokenizer,
|
||||
)
|
||||
from fairseq2.generation import (
|
||||
Seq2SeqGenerator,
|
||||
SequenceGeneratorOptions,
|
||||
SequenceGeneratorOutput,
|
||||
SequenceToTextGenerator,
|
||||
SequenceToTextOutput,
|
||||
)
|
||||
from seamless_communication.models.unity.model import UnitYModel, UnitYX2TModel
|
||||
|
||||
import torchaudio
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
audio_file="/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_en/wav/1089-134686-0001.wav"
|
||||
src_lang="cmn"
|
||||
|
||||
audio_file="/mnt/samsung-t7/yuekai/asr/Triton-ASR-Client/datasets/mini_zh/wav/long.wav"
|
||||
src_lang="eng"
|
||||
target_lang = "cmn"
|
||||
|
||||
audio_input = torchaudio.load(audio_file)[0]
|
||||
feature = ta_kaldi.fbank(audio_input, num_mel_bins=80)
|
||||
# feature shape is (T, F), convert it to (B, T, F), source_seq_lens tracks T
|
||||
source_seqs = feature.unsqueeze(0)
|
||||
source_seq_lens = torch.tensor([feature.shape[0]])
|
||||
|
||||
# Initialize a Translator object with a multitask model, vocoder on the GPU.
|
||||
|
||||
|
||||
# translator = Translator("seamlessM4T_medium", vocoder_name_or_card="vocoder_36langs", device=torch.device("cuda:2"), dtype=torch.float16)
|
||||
|
||||
# transcribed_text, _, _ = translator.predict(audio_file, "asr", src_lang)
|
||||
|
||||
# print(transcribed_text)
|
||||
|
||||
|
||||
model_name_or_card = "seamlessM4T_medium"
|
||||
device = torch.device("cuda:3")
|
||||
|
||||
# cast source_seq_lens, source_seqs to device, dtype to torch.float16
|
||||
source_seq_lens = source_seq_lens.to(device=device, dtype=torch.float16)
|
||||
source_seqs = source_seqs.to(device=device, dtype=torch.float16)
|
||||
|
||||
|
||||
|
||||
dtype = torch.float16
|
||||
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
|
||||
model.eval()
|
||||
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
|
||||
print(text_tokenizer.model.eos_idx, text_tokenizer.model.pad_idx)
|
||||
text_tokenizer_encoder = text_tokenizer.create_encoder(lang=target_lang, mode="target")
|
||||
text_tokenizer_decoder = text_tokenizer.create_decoder()
|
||||
# print attritbut of text_tokenizer_encoder
|
||||
|
||||
print(text_tokenizer_encoder("<eos>"))
|
||||
print(text_tokenizer_decoder(torch.tensor([3,45])))
|
||||
exit(0)
|
||||
|
||||
|
||||
|
||||
# def decode(
|
||||
# self,
|
||||
# seqs: Tensor,
|
||||
# seq_lens: Optional[Tensor],
|
||||
# encoder_output: Tensor,
|
||||
# encoder_padding_mask: Optional[Tensor],
|
||||
# state_bag: Optional[IncrementalStateBag] = None,
|
||||
# ) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
# seqs, padding_mask = self.text_decoder_frontend(seqs, seq_lens, state_bag)
|
||||
|
||||
# return self.text_decoder( # type: ignore[no-any-return]
|
||||
# seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
|
||||
# )
|
||||
|
||||
# def decoding(model, feature):
|
||||
# seqs, padding_mask = model.speech_encoder_frontend(seqs, seq_lens)
|
||||
# speech_encoder(seqs, padding_mask)
|
||||
|
||||
# decoder_output, decoder_padding_mask = self.decode(
|
||||
# batch.target_seqs,
|
||||
# batch.target_seq_lens,
|
||||
# encoder_output,
|
||||
# encoder_padding_mask,
|
||||
# )
|
||||
|
||||
# text_logits = model.final_project(decoder_output, decoder_padding_mask)
|
||||
|
||||
text_max_len_a = 1
|
||||
text_max_len_b = 200
|
||||
|
||||
text_opts = SequenceGeneratorOptions(
|
||||
beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
|
||||
)
|
||||
|
||||
s2t_model = UnitYX2TModel(
|
||||
encoder_frontend=model.speech_encoder_frontend,
|
||||
encoder=model.speech_encoder,
|
||||
decoder_frontend=model.text_decoder_frontend,
|
||||
decoder=model.text_decoder,
|
||||
final_proj=model.final_proj,
|
||||
pad_idx=model.pad_idx,
|
||||
)
|
||||
s2t_generator = SequenceToTextGenerator(
|
||||
s2t_model, text_tokenizer, target_lang, text_opts
|
||||
)
|
||||
|
||||
text_output = s2t_generator.generate_ex(source_seqs, source_seq_lens)
|
||||
sentence = text_output.sentences[0]
|
||||
print(sentence, type(sentence))
|
||||
sentence = sentence.bytes().decode("utf-8")
|
1173
egs/aishell/ASR/seamlessm4t/optim.py
Normal file
1173
egs/aishell/ASR/seamlessm4t/optim.py
Normal file
File diff suppressed because it is too large
Load Diff
1254
egs/aishell/ASR/seamlessm4t/train2.py
Normal file
1254
egs/aishell/ASR/seamlessm4t/train2.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -22,7 +22,7 @@ from torch import distributed as dist
|
||||
|
||||
|
||||
def setup_dist(
|
||||
rank, world_size, master_port=None, use_ddp_launch=False, master_addr=None
|
||||
rank=None, world_size=None, master_port=None, use_ddp_launch=False, master_addr=None
|
||||
):
|
||||
"""
|
||||
rank and world_size are used only if use_ddp_launch is False.
|
||||
|
Loading…
x
Reference in New Issue
Block a user