mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
removed the erroneous `continual` implementation
This commit is contained in:
parent
ab91112909
commit
d670f3b8dd
@ -118,13 +118,6 @@ def get_args():
|
|||||||
help="The temperature of AR Decoder top_k sampling.",
|
help="The temperature of AR Decoder top_k sampling.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--continual",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Do continual task.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--repetition-aware-sampling",
|
"--repetition-aware-sampling",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -262,29 +255,21 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# synthesis
|
# synthesis
|
||||||
if args.continual:
|
enroll_x_lens = None
|
||||||
assert text == ""
|
if text_prompts:
|
||||||
encoded_frames = model.continual(
|
_, enroll_x_lens = text_collater(
|
||||||
text_tokens.to(device),
|
[tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
|
||||||
text_tokens_lens.to(device),
|
|
||||||
audio_prompts,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
enroll_x_lens = None
|
|
||||||
if text_prompts:
|
|
||||||
_, enroll_x_lens = text_collater(
|
|
||||||
[tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
|
|
||||||
)
|
|
||||||
encoded_frames = model.inference(
|
|
||||||
text_tokens.to(device),
|
|
||||||
text_tokens_lens.to(device),
|
|
||||||
audio_prompts,
|
|
||||||
enroll_x_lens=enroll_x_lens,
|
|
||||||
top_k=args.top_k,
|
|
||||||
temperature=args.temperature,
|
|
||||||
top_p=args.top_p,
|
|
||||||
ras=args.repetition_aware_sampling,
|
|
||||||
)
|
)
|
||||||
|
encoded_frames = model.inference(
|
||||||
|
text_tokens.to(device),
|
||||||
|
text_tokens_lens.to(device),
|
||||||
|
audio_prompts,
|
||||||
|
enroll_x_lens=enroll_x_lens,
|
||||||
|
top_k=args.top_k,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_p=args.top_p,
|
||||||
|
ras=args.repetition_aware_sampling,
|
||||||
|
)
|
||||||
|
|
||||||
if audio_prompts != []:
|
if audio_prompts != []:
|
||||||
samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
|
samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
|
||||||
|
@ -1564,103 +1564,6 @@ class VALLE(nn.Module):
|
|||||||
assert len(codes) == self.num_quantizers
|
assert len(codes) == self.num_quantizers
|
||||||
return torch.stack(codes, dim=-1)
|
return torch.stack(codes, dim=-1)
|
||||||
|
|
||||||
def continual(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
x_lens: torch.Tensor,
|
|
||||||
y: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x:
|
|
||||||
A 2-D tensor of shape (1, S).
|
|
||||||
x_lens:
|
|
||||||
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
|
||||||
before padding.
|
|
||||||
y:
|
|
||||||
A 3-D tensor of shape (1, T, 8).
|
|
||||||
Returns:
|
|
||||||
Return the predicted audio code matrix.
|
|
||||||
"""
|
|
||||||
assert x.ndim == 2, x.shape
|
|
||||||
assert x_lens.ndim == 1, x_lens.shape
|
|
||||||
assert y.ndim == 3, y.shape
|
|
||||||
assert y.shape[0] == 1, y.shape
|
|
||||||
|
|
||||||
assert torch.all(x_lens > 0)
|
|
||||||
assert self.num_quantizers == 8
|
|
||||||
|
|
||||||
# NOTE: x has been padded in TextTokenCollater
|
|
||||||
text = x
|
|
||||||
x = self.ar_text_embedding(text)
|
|
||||||
x = self.ar_text_prenet(x)
|
|
||||||
x = self.ar_text_position(x)
|
|
||||||
|
|
||||||
text_len = x_lens.max()
|
|
||||||
|
|
||||||
prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
|
|
||||||
|
|
||||||
# AR Decoder
|
|
||||||
prompts = y[:, :prefix_len]
|
|
||||||
|
|
||||||
codes = [y[:, prefix_len:, 0]]
|
|
||||||
# Non-AR Decoders
|
|
||||||
x = self.nar_text_embedding(text)
|
|
||||||
x = self.nar_text_prenet(x)
|
|
||||||
x = self.nar_text_position(x)
|
|
||||||
|
|
||||||
y_emb = self.nar_audio_embeddings[0](y[..., 0])
|
|
||||||
|
|
||||||
if self.prefix_mode == 0:
|
|
||||||
for i, (predict_layer, embedding_layer) in enumerate(
|
|
||||||
zip(
|
|
||||||
self.nar_predict_layers,
|
|
||||||
self.nar_audio_embeddings[1:],
|
|
||||||
)
|
|
||||||
):
|
|
||||||
y_pos = self.nar_audio_position(y_emb)
|
|
||||||
y_pos = self.nar_audio_prenet(y_pos)
|
|
||||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
|
||||||
|
|
||||||
xy_dec, _ = self.nar_decoder(
|
|
||||||
(xy_pos, self.nar_stage_embeddings[i].weight)
|
|
||||||
)
|
|
||||||
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
|
||||||
|
|
||||||
samples = torch.argmax(logits, dim=-1)
|
|
||||||
codes.append(samples)
|
|
||||||
|
|
||||||
if i < 6:
|
|
||||||
y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
|
|
||||||
y_emb[:, prefix_len:] += embedding_layer(samples)
|
|
||||||
else:
|
|
||||||
for j in range(1, 8):
|
|
||||||
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
|
|
||||||
|
|
||||||
for i, (predict_layer, embedding_layer) in enumerate(
|
|
||||||
zip(
|
|
||||||
self.nar_predict_layers,
|
|
||||||
self.nar_audio_embeddings[1:],
|
|
||||||
)
|
|
||||||
):
|
|
||||||
y_pos = self.nar_audio_prenet(y_emb)
|
|
||||||
y_pos = self.nar_audio_position(y_pos)
|
|
||||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
|
||||||
|
|
||||||
xy_dec, _ = self.nar_decoder(
|
|
||||||
(xy_pos, self.nar_stage_embeddings[i].weight)
|
|
||||||
)
|
|
||||||
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
|
||||||
|
|
||||||
samples = torch.argmax(logits, dim=-1)
|
|
||||||
codes.append(samples)
|
|
||||||
|
|
||||||
if i < 6:
|
|
||||||
y_emb[:, prefix_len:] += embedding_layer(samples)
|
|
||||||
|
|
||||||
assert len(codes) == 8
|
|
||||||
return torch.stack(codes, dim=-1)
|
|
||||||
|
|
||||||
def visualize(
|
def visualize(
|
||||||
self,
|
self,
|
||||||
predicts: Tuple[torch.Tensor],
|
predicts: Tuple[torch.Tensor],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user