From d670f3b8dd9ae1d1fb283aaa875b69aa79f5a7ed Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Thu, 16 Jan 2025 11:49:44 +0800
Subject: [PATCH] removed the erroneous ``continual'' implementation

---
 egs/wenetspeech4tts/TTS/valle/infer.py | 43 ++++--------
 egs/wenetspeech4tts/TTS/valle/valle.py | 97 --------------------------
 2 files changed, 14 insertions(+), 126 deletions(-)

diff --git a/egs/wenetspeech4tts/TTS/valle/infer.py b/egs/wenetspeech4tts/TTS/valle/infer.py
index 44a251c56..d98abb731 100644
--- a/egs/wenetspeech4tts/TTS/valle/infer.py
+++ b/egs/wenetspeech4tts/TTS/valle/infer.py
@@ -118,13 +118,6 @@ def get_args():
         help="The temperature of AR Decoder top_k sampling.",
     )
 
-    parser.add_argument(
-        "--continual",
-        type=str2bool,
-        default=False,
-        help="Do continual task.",
-    )
-
     parser.add_argument(
         "--repetition-aware-sampling",
         type=str2bool,
@@ -262,29 +255,21 @@ def main():
     )
 
     # synthesis
-    if args.continual:
-        assert text == ""
-        encoded_frames = model.continual(
-            text_tokens.to(device),
-            text_tokens_lens.to(device),
-            audio_prompts,
-        )
-    else:
-        enroll_x_lens = None
-        if text_prompts:
-            _, enroll_x_lens = text_collater(
-                [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
-            )
-        encoded_frames = model.inference(
-            text_tokens.to(device),
-            text_tokens_lens.to(device),
-            audio_prompts,
-            enroll_x_lens=enroll_x_lens,
-            top_k=args.top_k,
-            temperature=args.temperature,
-            top_p=args.top_p,
-            ras=args.repetition_aware_sampling,
+    enroll_x_lens = None
+    if text_prompts:
+        _, enroll_x_lens = text_collater(
+            [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
         )
+    encoded_frames = model.inference(
+        text_tokens.to(device),
+        text_tokens_lens.to(device),
+        audio_prompts,
+        enroll_x_lens=enroll_x_lens,
+        top_k=args.top_k,
+        temperature=args.temperature,
+        top_p=args.top_p,
+        ras=args.repetition_aware_sampling,
+    )
 
     if audio_prompts != []:
         samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py
index 772317428..8f9b8fc3d 100644
--- a/egs/wenetspeech4tts/TTS/valle/valle.py
+++ b/egs/wenetspeech4tts/TTS/valle/valle.py
@@ -1564,103 +1564,6 @@ class VALLE(nn.Module):
         assert len(codes) == self.num_quantizers
         return torch.stack(codes, dim=-1)
 
-    def continual(
-        self,
-        x: torch.Tensor,
-        x_lens: torch.Tensor,
-        y: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Args:
-          x:
-            A 2-D tensor of shape (1, S).
-          x_lens:
-            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
-            before padding.
-          y:
-            A 3-D tensor of shape (1, T, 8).
-        Returns:
-          Return the predicted audio code matrix.
-        """
-        assert x.ndim == 2, x.shape
-        assert x_lens.ndim == 1, x_lens.shape
-        assert y.ndim == 3, y.shape
-        assert y.shape[0] == 1, y.shape
-
-        assert torch.all(x_lens > 0)
-        assert self.num_quantizers == 8
-
-        # NOTE: x has been padded in TextTokenCollater
-        text = x
-        x = self.ar_text_embedding(text)
-        x = self.ar_text_prenet(x)
-        x = self.ar_text_position(x)
-
-        text_len = x_lens.max()
-
-        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
-
-        # AR Decoder
-        prompts = y[:, :prefix_len]
-
-        codes = [y[:, prefix_len:, 0]]
-        # Non-AR Decoders
-        x = self.nar_text_embedding(text)
-        x = self.nar_text_prenet(x)
-        x = self.nar_text_position(x)
-
-        y_emb = self.nar_audio_embeddings[0](y[..., 0])
-
-        if self.prefix_mode == 0:
-            for i, (predict_layer, embedding_layer) in enumerate(
-                zip(
-                    self.nar_predict_layers,
-                    self.nar_audio_embeddings[1:],
-                )
-            ):
-                y_pos = self.nar_audio_position(y_emb)
-                y_pos = self.nar_audio_prenet(y_pos)
-                xy_pos = torch.concat([x, y_pos], dim=1)
-
-                xy_dec, _ = self.nar_decoder(
-                    (xy_pos, self.nar_stage_embeddings[i].weight)
-                )
-                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
-
-                samples = torch.argmax(logits, dim=-1)
-                codes.append(samples)
-
-                if i < 6:
-                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
-                    y_emb[:, prefix_len:] += embedding_layer(samples)
-        else:
-            for j in range(1, 8):
-                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
-
-            for i, (predict_layer, embedding_layer) in enumerate(
-                zip(
-                    self.nar_predict_layers,
-                    self.nar_audio_embeddings[1:],
-                )
-            ):
-                y_pos = self.nar_audio_prenet(y_emb)
-                y_pos = self.nar_audio_position(y_pos)
-                xy_pos = torch.concat([x, y_pos], dim=1)
-
-                xy_dec, _ = self.nar_decoder(
-                    (xy_pos, self.nar_stage_embeddings[i].weight)
-                )
-                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
-
-                samples = torch.argmax(logits, dim=-1)
-                codes.append(samples)
-
-                if i < 6:
-                    y_emb[:, prefix_len:] += embedding_layer(samples)
-
-        assert len(codes) == 8
-        return torch.stack(codes, dim=-1)
-
     def visualize(
         self,
         predicts: Tuple[torch.Tensor],