Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
Add repeat_param to embedding_out during training and decoding
Parent: 45b68e1df6
Commit: c3e01e141f
@@ -542,7 +542,9 @@ def greedy_search(
         [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
     ).reshape(1, context_size)

-    decoder_out = model.decoder(decoder_input, need_pad=False)
+    k = torch.zeros(1, 1, device=device, dtype=torch.int64)
+
+    decoder_out = model.decoder(decoder_input, k, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)

     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -586,7 +588,20 @@ def greedy_search(
                 1, context_size
             )

-            decoder_out = model.decoder(decoder_input, need_pad=False)
+            c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(
+                1, context_size + 1
+            )
+
+            k[:, 0] = torch.sum(
+                (
+                    c[:, -context_size - 1 : -1]
+                    == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
+                ).int(),
+                dim=1,
+                keepdim=True,
+            )
+
+            decoder_out = model.decoder(decoder_input, k, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)

             sym_per_utt += 1
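Note: the two hunks above make greedy_search maintain a running repeat count k for the newest symbol: k starts at zero (the hypothesis is primed with blanks) and, after each emitted symbol, counts how many of the previous context_size symbols equal it. A minimal sketch of the same computation outside the decoding loop (the values of hyp and context_size are made up for illustration):

import torch

context_size = 2
blank_id = 0
device = torch.device("cpu")

# greedy_search primes hyp with context_size blanks, so k starts at zero.
hyp = [blank_id] * context_size + [7, 7, 7]  # pretend token 7 was emitted three times
k = torch.zeros(1, 1, device=device, dtype=torch.int64)

# c holds the last context_size + 1 symbols; the final entry is the newest one.
c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(1, context_size + 1)

k[:, 0] = torch.sum(
    (
        c[:, -context_size - 1 : -1]
        == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
    ).int(),
    dim=1,
    keepdim=True,
)
print(k)  # tensor([[2]]): both preceding symbols equal the newest symbol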
@@ -594,7 +609,7 @@ def greedy_search(
         else:
             sym_per_frame = 0
             t += 1
-    hyp = hyp[context_size:]  # remove blanks
+    hyp = hyp[context_size :]  # remove blanks

    if not return_timestamps:
        return hyp
@@ -20,36 +20,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless7/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+./pruned_transducer_stateless9/decode.py \
+    --epoch 30 \
+    --avg 8 \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method greedy_search

 (2) beam search (not recommended)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search (one best)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 20.0 \
@@ -57,10 +57,10 @@ Usage:
     --max-states 64

 (5) fast beam search (nbest)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest \
     --beam 20.0 \
@@ -70,10 +70,10 @@ Usage:
     --nbest-scale 0.5

 (6) fast beam search (nbest oracle WER)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_oracle \
     --beam 20.0 \
@@ -83,10 +83,10 @@ Usage:
     --nbest-scale 0.5

 (7) fast beam search (with LG)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_LG \
     --beam 20.0 \
@@ -223,7 +223,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless7/exp",
+        default="pruned_transducer_stateless9/exp",
         help="The experiment dir",
     )

@@ -73,14 +73,50 @@ class Decoder(nn.Module):
                 bias=False,
             )

-    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
+        self.repeat_param = nn.Parameter(torch.randn(decoder_dim))
+
+    def _add_repeat_param(
+        self,
+        embedding_out: torch.Tensor,
+        k: torch.Tensor,
+        is_training: bool = True,
+    ) -> torch.Tensor:
+        """
+        Add the repeat parameter to the embedding_out tensor.
+
+        Args:
+          embedding_out:
+            A tensor of shape (N, U, decoder_dim).
+          k:
+            A tensor of shape (N, U).
+            Should be (N, S + 1) during training.
+            Should be (N, 1) during inference.
+        Returns:
+          Return a tensor of shape (N, U, decoder_dim).
+        """
+        return embedding_out + torch.matmul(
+            (k / (1 + k)).unsqueeze(2),
+            self.repeat_param.unsqueeze(0),
+        )
+
+    def forward(
+        self,
+        y: torch.Tensor,
+        k: torch.Tensor,
+        need_pad: bool = True,
+    ) -> torch.Tensor:
         """
         Args:
           y:
             A 2-D tensor of shape (N, U).
+          k:
+            A 2-D tensor; for each position, the count of the current label among the previous context_size labels.
+            Should be (N, S + 1) during training.
+            Should be (N, 1) during inference.
           need_pad:
-            True to left pad the input. Should be True during training.
-            False to not pad the input. Should be False during inference.
+            Whether to left pad the input.
+            Should be True during training.
+            Should be False during inference.
         Returns:
           Return a tensor of shape (N, U, decoder_dim).
         """
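The added term is a rank-one update along a single learned direction: k / (1 + k) squashes the repeat count into [0, 1) (0 for no repeats, 1/2 for one, 2/3 for two, ...), and the matmul broadcasts that scalar over repeat_param. A quick shape check with toy sizes (all names below are local stand-ins, not part of the commit):

import torch

N, U, decoder_dim = 2, 5, 8  # toy sizes
repeat_param = torch.randn(decoder_dim)        # mirrors self.repeat_param
embedding_out = torch.randn(N, U, decoder_dim)
k = torch.randint(0, 3, (N, U))                # repeat counts, e.g. from compute_k

scale = (k / (1 + k)).unsqueeze(2)             # (N, U, 1), values in [0, 1)
# (N, U, 1) @ (1, decoder_dim) broadcasts to (N, U, decoder_dim)
update = torch.matmul(scale, repeat_param.unsqueeze(0))
out = embedding_out + update
assert out.shape == (N, U, decoder_dim)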
@@ -90,7 +126,7 @@ class Decoder(nn.Module):
         embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
                 embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
@@ -98,5 +134,10 @@ class Decoder(nn.Module):
             assert embedding_out.size(-1) == self.context_size
             embedding_out = self.conv(embedding_out)
             embedding_out = embedding_out.permute(0, 2, 1)
+
+        embedding_out = self._add_repeat_param(
+            embedding_out=embedding_out,
+            k=k,
+        )
         embedding_out = F.relu(embedding_out)
         return embedding_out
@@ -20,10 +20,11 @@ import random
 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import penalize_abs_values_gt

-from icefall.utils import add_sos
+from icefall.utils import add_sos, make_pad_mask


 class Transducer(nn.Module):
@@ -75,7 +76,9 @@ class Transducer(nn.Module):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        y: k2.RaggedTensor,
+        y: torch.Tensor,
+        y_lens: torch.Tensor,
+        k: torch.Tensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
@@ -88,8 +91,14 @@ class Transducer(nn.Module):
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
             before padding.
           y:
-            A ragged tensor with 2 axes [utt][label]. It contains labels of each
+            A 2-D tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
+          y_lens:
+            A 1-D tensor of shape (N,). It contains the number of labels in `y`
+            before padding.
+          k:
+            A 2-D tensor of shape (N, U); for each position, the count of the
+            current label among the previous context_size labels.
           prune_range:
             The prune range for rnnt loss, it means how many symbols(context)
             we are considering for each frame to compute the loss.
@@ -110,31 +119,24 @@ class Transducer(nn.Module):
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
-        assert y.num_axes == 2, y.num_axes
+        assert len(y.shape) == 2, len(y.shape)

-        assert x.size(0) == x_lens.size(0) == y.dim0
+        assert x.size(0) == x_lens.size(0) == y.size(0)

         encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)

-        # Now for the decoder, i.e., the prediction network
-        row_splits = y.shape.row_splits(1)
-        y_lens = row_splits[1:] - row_splits[:-1]
-
         blank_id = self.decoder.blank_id
-        sos_y = add_sos(y, sos_id=blank_id)
-
         # sos_y_padded: [B, S + 1], start with SOS.
-        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)

         # decoder_out: [B, S + 1, decoder_dim]
-        decoder_out = self.decoder(sos_y_padded)
+        decoder_out = self.decoder(sos_y_padded, k)

-        # Note: y does not start with SOS
+        # Note: y_padded does not start with SOS
         # y_padded : [B, S]
-        y_padded = y.pad(mode="constant", padding_value=0)
+        y_padded = y.to(torch.int64)

-        y_padded = y_padded.to(torch.int64)
         boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
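With these changes the model consumes plain padded tensors instead of k2.RaggedTensor: one F.pad call prepends the SOS (blank) column that add_sos plus ragged padding used to produce. A small sketch of that label path under the new scheme (toy values, not from the commit):

import torch
import torch.nn.functional as F

blank_id = 0
# A padded label batch of shape [B, S]; zeros pad past each utterance's length.
y = torch.tensor([[5, 6, 7],
                  [8, 9, 0]])
y_lens = torch.tensor([3, 2])

# sos_y_padded: [B, S + 1], starts with SOS (the blank id).
sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)
# tensor([[0, 5, 6, 7],
#         [0, 8, 9, 0]])

# y_padded feeds the rnnt loss; boundary carries the true lengths.
y_padded = y.to(torch.int64)
boundary = torch.zeros((y.size(0), 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = torch.tensor([10, 8])  # stand-in x_lens from the encoder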
@@ -22,22 +22,22 @@ Usage:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless7/train.py \
+./pruned_transducer_stateless9/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
+  --exp-dir pruned_transducer_stateless9/exp \
   --full-libri 1 \
   --max-duration 300

 # For mix precision training:

-./pruned_transducer_stateless7/train.py \
+./pruned_transducer_stateless9/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
+  --exp-dir pruned_transducer_stateless9/exp \
   --full-libri 1 \
   --max-duration 550

@@ -58,6 +58,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+import torch.nn.functional as F
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
@@ -69,6 +70,7 @@ from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils.rnn import pad_sequence
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer

@@ -235,7 +237,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless7/exp",
+        default="pruned_transducer_stateless9/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -625,6 +627,41 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


+def compute_k(
+    y: torch.Tensor,
+    context_size: int = 2,
+    blank_id: int = 0,
+) -> torch.Tensor:
+    """
+    Args:
+      y:
+        A 2-D tensor of shape (N, U).
+      context_size:
+        Number of previous words to use to predict the next word.
+        1 means bigram; 2 means trigram. n means (n+1)-gram.
+    Returns:
+      Return a tensor of shape (N, U + 1); an SOS column is prepended.
+    """
+    y = F.pad(y, (1, 0), mode="constant", value=blank_id)  # [B, S + 1], start with SOS.
+    k = torch.zeros_like(y)
+
+    for i in range(2, y.size(1) - 1):
+        k[:, i : i + 1] = torch.where(
+            y[:, i : i + 1] != 0,
+            torch.sum(
+                (
+                    y[:, i - context_size : i]
+                    == y[:, i : i + 1].expand_as(y[:, i - context_size : i])
+                ).int(),
+                dim=1,
+                keepdim=True,
+            ),
+            y[:, i : i + 1],
+        )
+
+    return k
+
+
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
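As a sanity check, here is what compute_k (as defined above, assumed to be in scope) returns on a toy batch; note that blank_id pads position 0 as the SOS column, that the loop starts at i = 2 and skips the final position, and that blank positions keep k = 0:

import torch

y = torch.tensor([[7, 7, 7, 5]])  # one utterance, labels only (no SOS yet)
k = compute_k(y, context_size=2, blank_id=0)
print(k)  # tensor([[0, 0, 1, 2, 0]])
# Index 0 is the SOS column; at index 2 one of [0, 7] equals 7;
# at index 3 both of [7, 7] equal 7; the last position is skipped by the loop.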
@@ -675,13 +712,26 @@ def compute_loss(

     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    y_lens = torch.tensor(list(map(len, y))).to(device)
+    y = list(map(torch.tensor, y))
+    y = pad_sequence(y, batch_first=True)  # [B, S]
+
+    k = compute_k(
+        y,
+        params.context_size,
+        model.module.decoder.blank_id
+        if isinstance(model, DDP)
+        else model.decoder.blank_id,
+    ).to(device)
+    y = y.to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
+            y_lens=y_lens,
+            k=k,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
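Putting it together, the new label-preparation path in compute_loss reduces to the following hedged sketch (a hard-coded list stands in for sp.encode(texts, out_type=int), and compute_k from above is assumed to be in scope):

import torch
from torch.nn.utils.rnn import pad_sequence

device = torch.device("cpu")

y = [[5, 6, 7, 7], [8, 9]]  # stand-in tokenizer output: variable-length id lists

y_lens = torch.tensor(list(map(len, y))).to(device)  # tensor([4, 2])
y = list(map(torch.tensor, y))
y = pad_sequence(y, batch_first=True)  # [B, S] = [2, 4], zero padded

k = compute_k(y, context_size=2, blank_id=0).to(device)  # [B, S + 1]
y = y.to(device)
# y, y_lens, and k then go to model(x=..., x_lens=..., y=y, y_lens=y_lens, k=k, ...)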