mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
removed the erroneous "continual" implementation (#1865)
This commit is contained in:
parent
8ab0352e60
commit
79074ef0d4
@ -118,13 +118,6 @@ def get_args():
|
||||
help="The temperature of AR Decoder top_k sampling.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--continual",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Do continual task.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--repetition-aware-sampling",
|
||||
type=str2bool,
|
||||
@ -262,29 +255,21 @@ def main():
|
||||
)
|
||||
|
||||
# synthesis
|
||||
if args.continual:
|
||||
assert text == ""
|
||||
encoded_frames = model.continual(
|
||||
text_tokens.to(device),
|
||||
text_tokens_lens.to(device),
|
||||
audio_prompts,
|
||||
)
|
||||
else:
|
||||
enroll_x_lens = None
|
||||
if text_prompts:
|
||||
_, enroll_x_lens = text_collater(
|
||||
[tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
|
||||
)
|
||||
encoded_frames = model.inference(
|
||||
text_tokens.to(device),
|
||||
text_tokens_lens.to(device),
|
||||
audio_prompts,
|
||||
enroll_x_lens=enroll_x_lens,
|
||||
top_k=args.top_k,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
ras=args.repetition_aware_sampling,
|
||||
enroll_x_lens = None
|
||||
if text_prompts:
|
||||
_, enroll_x_lens = text_collater(
|
||||
[tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
|
||||
)
|
||||
encoded_frames = model.inference(
|
||||
text_tokens.to(device),
|
||||
text_tokens_lens.to(device),
|
||||
audio_prompts,
|
||||
enroll_x_lens=enroll_x_lens,
|
||||
top_k=args.top_k,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
ras=args.repetition_aware_sampling,
|
||||
)
|
||||
|
||||
if audio_prompts != []:
|
||||
samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
|
||||
|
@ -1564,103 +1564,6 @@ class VALLE(nn.Module):
|
||||
assert len(codes) == self.num_quantizers
|
||||
return torch.stack(codes, dim=-1)
|
||||
|
||||
def continual(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 2-D tensor of shape (1, S).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
||||
before padding.
|
||||
y:
|
||||
A 3-D tensor of shape (1, T, 8).
|
||||
Returns:
|
||||
Return the predicted audio code matrix.
|
||||
"""
|
||||
assert x.ndim == 2, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.ndim == 3, y.shape
|
||||
assert y.shape[0] == 1, y.shape
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
assert self.num_quantizers == 8
|
||||
|
||||
# NOTE: x has been padded in TextTokenCollater
|
||||
text = x
|
||||
x = self.ar_text_embedding(text)
|
||||
x = self.ar_text_prenet(x)
|
||||
x = self.ar_text_position(x)
|
||||
|
||||
text_len = x_lens.max()
|
||||
|
||||
prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
|
||||
|
||||
# AR Decoder
|
||||
prompts = y[:, :prefix_len]
|
||||
|
||||
codes = [y[:, prefix_len:, 0]]
|
||||
# Non-AR Decoders
|
||||
x = self.nar_text_embedding(text)
|
||||
x = self.nar_text_prenet(x)
|
||||
x = self.nar_text_position(x)
|
||||
|
||||
y_emb = self.nar_audio_embeddings[0](y[..., 0])
|
||||
|
||||
if self.prefix_mode == 0:
|
||||
for i, (predict_layer, embedding_layer) in enumerate(
|
||||
zip(
|
||||
self.nar_predict_layers,
|
||||
self.nar_audio_embeddings[1:],
|
||||
)
|
||||
):
|
||||
y_pos = self.nar_audio_position(y_emb)
|
||||
y_pos = self.nar_audio_prenet(y_pos)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
xy_dec, _ = self.nar_decoder(
|
||||
(xy_pos, self.nar_stage_embeddings[i].weight)
|
||||
)
|
||||
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
||||
|
||||
samples = torch.argmax(logits, dim=-1)
|
||||
codes.append(samples)
|
||||
|
||||
if i < 6:
|
||||
y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
|
||||
y_emb[:, prefix_len:] += embedding_layer(samples)
|
||||
else:
|
||||
for j in range(1, 8):
|
||||
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
|
||||
|
||||
for i, (predict_layer, embedding_layer) in enumerate(
|
||||
zip(
|
||||
self.nar_predict_layers,
|
||||
self.nar_audio_embeddings[1:],
|
||||
)
|
||||
):
|
||||
y_pos = self.nar_audio_prenet(y_emb)
|
||||
y_pos = self.nar_audio_position(y_pos)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
xy_dec, _ = self.nar_decoder(
|
||||
(xy_pos, self.nar_stage_embeddings[i].weight)
|
||||
)
|
||||
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
||||
|
||||
samples = torch.argmax(logits, dim=-1)
|
||||
codes.append(samples)
|
||||
|
||||
if i < 6:
|
||||
y_emb[:, prefix_len:] += embedding_layer(samples)
|
||||
|
||||
assert len(codes) == 8
|
||||
return torch.stack(codes, dim=-1)
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
predicts: Tuple[torch.Tensor],
|
||||
|
Loading…
x
Reference in New Issue
Block a user