mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Use a stateless decoder.
This commit is contained in:
parent
2cf1b56cb3
commit
ec083e93d8
@ -78,7 +78,7 @@ class Decoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
y:
|
y:
|
||||||
A 2-D tensor of shape (N, U) with BOS prepended.
|
A 2-D tensor of shape (N, U) with blank prepended.
|
||||||
states:
|
states:
|
||||||
A tuple of two tensors containing the states information of
|
A tuple of two tensors containing the states information of
|
||||||
LSTM layers in this decoder.
|
LSTM layers in this decoder.
|
||||||
|
@ -1,219 +0,0 @@
|
|||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from model import Transducer
|
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
model:
|
|
||||||
An instance of `Transducer`.
|
|
||||||
encoder_out:
|
|
||||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
|
||||||
Returns:
|
|
||||||
Return the decoded result.
|
|
||||||
"""
|
|
||||||
assert encoder_out.ndim == 3
|
|
||||||
|
|
||||||
# support only batch_size == 1 for now
|
|
||||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
|
||||||
blank_id = model.decoder.blank_id
|
|
||||||
device = model.device
|
|
||||||
|
|
||||||
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
|
|
||||||
decoder_out, (h, c) = model.decoder(sos)
|
|
||||||
T = encoder_out.size(1)
|
|
||||||
t = 0
|
|
||||||
hyp = []
|
|
||||||
|
|
||||||
sym_per_frame = 0
|
|
||||||
sym_per_utt = 0
|
|
||||||
|
|
||||||
max_sym_per_utt = 1000
|
|
||||||
max_sym_per_frame = 3
|
|
||||||
|
|
||||||
while t < T and sym_per_utt < max_sym_per_utt:
|
|
||||||
# fmt: off
|
|
||||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
|
||||||
# fmt: on
|
|
||||||
logits = model.joiner(current_encoder_out, decoder_out)
|
|
||||||
# logits is (1, 1, 1, vocab_size)
|
|
||||||
|
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
|
||||||
# log_prob is (1, 1, 1, vocab_size)
|
|
||||||
# TODO: Use logits.argmax()
|
|
||||||
y = log_prob.argmax()
|
|
||||||
if y != blank_id:
|
|
||||||
hyp.append(y.item())
|
|
||||||
y = y.reshape(1, 1)
|
|
||||||
decoder_out, (h, c) = model.decoder(y, (h, c))
|
|
||||||
|
|
||||||
sym_per_utt += 1
|
|
||||||
sym_per_frame += 1
|
|
||||||
|
|
||||||
if y == blank_id or sym_per_frame > max_sym_per_frame:
|
|
||||||
sym_per_frame = 0
|
|
||||||
t += 1
|
|
||||||
|
|
||||||
return hyp
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Hypothesis:
|
|
||||||
ys: List[int] # the predicted sequences so far
|
|
||||||
log_prob: float # The log prob of ys
|
|
||||||
|
|
||||||
# Optional decoder state. We assume it is LSTM for now,
|
|
||||||
# so the state is a tuple (h, c)
|
|
||||||
decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
|
||||||
|
|
||||||
|
|
||||||
def beam_search(
|
|
||||||
model: Transducer,
|
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
beam: int = 5,
|
|
||||||
) -> List[int]:
|
|
||||||
"""
|
|
||||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
|
||||||
|
|
||||||
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model:
|
|
||||||
An instance of `Transducer`.
|
|
||||||
encoder_out:
|
|
||||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
|
||||||
beam:
|
|
||||||
Beam size.
|
|
||||||
Returns:
|
|
||||||
Return the decoded result.
|
|
||||||
"""
|
|
||||||
assert encoder_out.ndim == 3
|
|
||||||
|
|
||||||
# support only batch_size == 1 for now
|
|
||||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
|
||||||
blank_id = model.decoder.blank_id
|
|
||||||
device = model.device
|
|
||||||
|
|
||||||
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
|
|
||||||
decoder_out, (h, c) = model.decoder(sos)
|
|
||||||
T = encoder_out.size(1)
|
|
||||||
t = 0
|
|
||||||
B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
|
|
||||||
max_u = 20000 # terminate after this number of steps
|
|
||||||
u = 0
|
|
||||||
|
|
||||||
cache: Dict[
|
|
||||||
str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
|
||||||
] = {}
|
|
||||||
|
|
||||||
while t < T and u < max_u:
|
|
||||||
# fmt: off
|
|
||||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
|
||||||
# fmt: on
|
|
||||||
A = B
|
|
||||||
B = []
|
|
||||||
# for hyp in A:
|
|
||||||
# for h in A:
|
|
||||||
# if h.ys == hyp.ys[:-1]:
|
|
||||||
# # update the score of hyp
|
|
||||||
# decoder_input = torch.tensor(
|
|
||||||
# [h.ys[-1]], device=device
|
|
||||||
# ).reshape(1, 1)
|
|
||||||
# decoder_out, _ = model.decoder(
|
|
||||||
# decoder_input, h.decoder_state
|
|
||||||
# )
|
|
||||||
# logits = model.joiner(current_encoder_out, decoder_out)
|
|
||||||
# log_prob = logits.log_softmax(dim=-1)
|
|
||||||
# log_prob = log_prob.squeeze()
|
|
||||||
# hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
|
|
||||||
|
|
||||||
while u < max_u:
|
|
||||||
y_star = max(A, key=lambda hyp: hyp.log_prob)
|
|
||||||
A.remove(y_star)
|
|
||||||
|
|
||||||
# Note: y_star.ys is unhashable, i.e., cannot be used
|
|
||||||
# as a key into a dict
|
|
||||||
cached_key = "_".join(map(str, y_star.ys))
|
|
||||||
|
|
||||||
if cached_key not in cache:
|
|
||||||
decoder_input = torch.tensor(
|
|
||||||
[y_star.ys[-1]], device=device
|
|
||||||
).reshape(1, 1)
|
|
||||||
|
|
||||||
decoder_out, decoder_state = model.decoder(
|
|
||||||
decoder_input,
|
|
||||||
y_star.decoder_state,
|
|
||||||
)
|
|
||||||
cache[cached_key] = (decoder_out, decoder_state)
|
|
||||||
else:
|
|
||||||
decoder_out, decoder_state = cache[cached_key]
|
|
||||||
|
|
||||||
logits = model.joiner(current_encoder_out, decoder_out)
|
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
|
||||||
# log_prob is (1, 1, 1, vocab_size)
|
|
||||||
log_prob = log_prob.squeeze()
|
|
||||||
# Now log_prob is (vocab_size,)
|
|
||||||
|
|
||||||
# If we choose blank here, add the new hypothesis to B.
|
|
||||||
# Otherwise, add the new hypothesis to A
|
|
||||||
|
|
||||||
# First, choose blank
|
|
||||||
skip_log_prob = log_prob[blank_id]
|
|
||||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
|
|
||||||
|
|
||||||
# ys[:] returns a copy of ys
|
|
||||||
new_y_star = Hypothesis(
|
|
||||||
ys=y_star.ys[:],
|
|
||||||
log_prob=new_y_star_log_prob,
|
|
||||||
# Caution: Use y_star.decoder_state here
|
|
||||||
decoder_state=y_star.decoder_state,
|
|
||||||
)
|
|
||||||
B.append(new_y_star)
|
|
||||||
|
|
||||||
# Second, choose other labels
|
|
||||||
for i, v in enumerate(log_prob.tolist()):
|
|
||||||
if i == blank_id:
|
|
||||||
continue
|
|
||||||
new_ys = y_star.ys + [i]
|
|
||||||
new_log_prob = y_star.log_prob + v
|
|
||||||
new_hyp = Hypothesis(
|
|
||||||
ys=new_ys,
|
|
||||||
log_prob=new_log_prob,
|
|
||||||
decoder_state=decoder_state,
|
|
||||||
)
|
|
||||||
A.append(new_hyp)
|
|
||||||
u += 1
|
|
||||||
# check whether B contains more than "beam" elements more probable
|
|
||||||
# than the most probable in A
|
|
||||||
A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
|
|
||||||
B = sorted(
|
|
||||||
[hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
|
|
||||||
key=lambda hyp: hyp.log_prob,
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
if len(B) >= beam:
|
|
||||||
B = B[:beam]
|
|
||||||
break
|
|
||||||
t += 1
|
|
||||||
best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
|
|
||||||
ys = best_hyp.ys[1:] # [1:] to remove the blank
|
|
||||||
return ys
|
|
1
egs/librispeech/ASR/transducer_lstm/beam_search.py
Symbolic link
1
egs/librispeech/ASR/transducer_lstm/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../transducer_stateless/beam_search.py
|
@ -114,6 +114,14 @@ def get_parser():
|
|||||||
help="Used only when --decoding-method is beam_search",
|
help="Used only when --decoding-method is beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
|
"2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -124,14 +132,10 @@ def get_params() -> AttributeDict:
|
|||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"encoder_out_dim": 512,
|
"encoder_out_dim": 512,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"encoder_hidden_size": 1024,
|
"encoder_hidden_size": 2048,
|
||||||
"num_encoder_layers": 4,
|
"num_encoder_layers": 6,
|
||||||
"proj_size": 512,
|
"proj_size": 512,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
# decoder params
|
|
||||||
"decoder_embedding_dim": 1024,
|
|
||||||
"num_decoder_layers": 4,
|
|
||||||
"decoder_hidden_dim": 512,
|
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -153,11 +157,9 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict):
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.decoder_embedding_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
num_layers=params.num_decoder_layers,
|
context_size=params.context_size,
|
||||||
hidden_dim=params.decoder_hidden_dim,
|
|
||||||
output_dim=params.encoder_out_dim,
|
|
||||||
)
|
)
|
||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
@ -14,24 +14,30 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
# TODO(fangjun): Support switching between LSTM and GRU
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
|
"""This class modifies the stateless decoder from the following paper:
|
||||||
|
|
||||||
|
RNN-transducer with stateless prediction network
|
||||||
|
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||||
|
|
||||||
|
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||||
|
network. Different from the above paper, it adds an extra Conv1d
|
||||||
|
right after the embedding layer.
|
||||||
|
|
||||||
|
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
blank_id: int,
|
blank_id: int,
|
||||||
num_layers: int,
|
context_size: int,
|
||||||
hidden_dim: int,
|
|
||||||
output_dim: int,
|
|
||||||
embedding_dropout: float = 0.0,
|
|
||||||
rnn_dropout: float = 0.0,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -41,16 +47,9 @@ class Decoder(nn.Module):
|
|||||||
Dimension of the input embedding.
|
Dimension of the input embedding.
|
||||||
blank_id:
|
blank_id:
|
||||||
The ID of the blank symbol.
|
The ID of the blank symbol.
|
||||||
num_layers:
|
context_size:
|
||||||
Number of LSTM layers.
|
Number of previous words to use to predict the next word.
|
||||||
hidden_dim:
|
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||||
Hidden dimension of LSTM layers.
|
|
||||||
output_dim:
|
|
||||||
Output dimension of the decoder.
|
|
||||||
embedding_dropout:
|
|
||||||
Dropout rate for the embedding layer.
|
|
||||||
rnn_dropout:
|
|
||||||
Dropout for LSTM layers.
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding = nn.Embedding(
|
self.embedding = nn.Embedding(
|
||||||
@ -58,40 +57,42 @@ class Decoder(nn.Module):
|
|||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
padding_idx=blank_id,
|
padding_idx=blank_id,
|
||||||
)
|
)
|
||||||
self.embedding_dropout = nn.Dropout(embedding_dropout)
|
|
||||||
# TODO(fangjun): Use layer normalized LSTM
|
|
||||||
self.rnn = nn.LSTM(
|
|
||||||
input_size=embedding_dim,
|
|
||||||
hidden_size=hidden_dim,
|
|
||||||
num_layers=num_layers,
|
|
||||||
batch_first=True,
|
|
||||||
dropout=rnn_dropout,
|
|
||||||
)
|
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
self.output_linear = nn.Linear(hidden_dim, output_dim)
|
|
||||||
|
|
||||||
def forward(
|
assert context_size >= 1, context_size
|
||||||
self,
|
self.context_size = context_size
|
||||||
y: torch.Tensor,
|
if context_size > 1:
|
||||||
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
self.conv = nn.Conv1d(
|
||||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
in_channels=embedding_dim,
|
||||||
|
out_channels=embedding_dim,
|
||||||
|
kernel_size=context_size,
|
||||||
|
padding=0,
|
||||||
|
groups=embedding_dim,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
y:
|
y:
|
||||||
A 2-D tensor of shape (N, U) with BOS prepended.
|
A 2-D tensor of shape (N, U) with blank prepended.
|
||||||
states:
|
need_pad:
|
||||||
A tuple of two tensors containing the states information of
|
True to left pad the input. Should be True during training.
|
||||||
LSTM layers in this decoder.
|
False to not pad the input. Should be False during inference.
|
||||||
Returns:
|
Returns:
|
||||||
Return a tuple containing:
|
Return a tensor of shape (N, U, embedding_dim).
|
||||||
|
|
||||||
- rnn_output, a tensor of shape (N, U, C)
|
|
||||||
- (h, c), containing the state information for LSTM layers.
|
|
||||||
Both are of shape (num_layers, N, C)
|
|
||||||
"""
|
"""
|
||||||
embeding_out = self.embedding(y)
|
embeding_out = self.embedding(y)
|
||||||
embeding_out = self.embedding_dropout(embeding_out)
|
if self.context_size > 1:
|
||||||
rnn_out, (h, c) = self.rnn(embeding_out, states)
|
embeding_out = embeding_out.permute(0, 2, 1)
|
||||||
out = self.output_linear(rnn_out)
|
if need_pad is True:
|
||||||
|
embeding_out = F.pad(
|
||||||
return out, (h, c)
|
embeding_out, pad=(self.context_size - 1, 0)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# During inference time, there is no need to do extra padding
|
||||||
|
# as we only need one output
|
||||||
|
assert embeding_out.size(-1) == self.context_size
|
||||||
|
embeding_out = self.conv(embeding_out)
|
||||||
|
embeding_out = embeding_out.permute(0, 2, 1)
|
||||||
|
return embeding_out
|
||||||
|
@ -131,6 +131,14 @@ def get_parser():
|
|||||||
help="The lr_factor for Noam optimizer",
|
help="The lr_factor for Noam optimizer",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
|
"2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -172,9 +180,6 @@ def get_params() -> AttributeDict:
|
|||||||
|
|
||||||
- subsampling_factor: The subsampling factor for the model.
|
- subsampling_factor: The subsampling factor for the model.
|
||||||
|
|
||||||
- use_feat_batchnorm: Whether to do batch normalization for the
|
|
||||||
input features.
|
|
||||||
|
|
||||||
- attention_dim: Hidden dim for multi-head attention model.
|
- attention_dim: Hidden dim for multi-head attention model.
|
||||||
|
|
||||||
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
||||||
@ -195,14 +200,10 @@ def get_params() -> AttributeDict:
|
|||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"encoder_out_dim": 512,
|
"encoder_out_dim": 512,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"encoder_hidden_size": 1024,
|
"encoder_hidden_size": 2048,
|
||||||
"num_encoder_layers": 4,
|
"num_encoder_layers": 6,
|
||||||
"proj_size": 512,
|
"proj_size": 512,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
# decoder params
|
|
||||||
"decoder_embedding_dim": 1024,
|
|
||||||
"num_decoder_layers": 4,
|
|
||||||
"decoder_hidden_dim": 512,
|
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"warm_step": 80000, # For the 100h subset, use 8k
|
"warm_step": 80000, # For the 100h subset, use 8k
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
@ -227,12 +228,11 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict):
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.decoder_embedding_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
num_layers=params.num_decoder_layers,
|
context_size=params.context_size,
|
||||||
hidden_dim=params.decoder_hidden_dim,
|
|
||||||
output_dim=params.encoder_out_dim,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
@ -573,11 +573,11 @@ def run(rank, world_size, args):
|
|||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters() if p.requires_grad])
|
num_param = sum([p.numel() for p in model.parameters() if p.requires_grad])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
logging.info("Using DDP")
|
logging.info("Using DDP")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user