Fix DDP issue; change configuration, reducing the subsampling factor; increase sequence length.

parent 45f5e9981d
commit f0264bed1b
@@ -119,6 +119,6 @@ class ChunkDecoder(nn.Module):
             # occasionally print out average logprob per position in the chunk.
             l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
             l = l.to('cpu').tolist()
-            logging.info(l"Logprobs per position in chunk: {l}")
+            logging.info(f"Logprobs per position in chunk: {l}")

         return logprobs
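The old line's l"..." prefix was a syntax error; the fix makes it a proper f-string. As a toy sketch of what gets logged: averaging over the batch and chunk dimensions leaves one mean logprob per position within a chunk (the shapes here are invented for illustration):

    import torch

    batch_size, num_chunks, chunk_size = 2, 3, 4  # assumed toy shapes
    logprobs = torch.randn(batch_size * num_chunks * chunk_size)

    # mean over dims (0, 1) collapses batch and chunk, leaving chunk_size values
    l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
    print(l.to('cpu').tolist())  # list of chunk_size floats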
@@ -37,7 +37,10 @@ from icefall.utils import str2bool
 class LmDataset(torch.utils.data.IterableDataset):
     def __init__(self,
                  file_list_fn: Path,
-                 bytes_per_segment: int = 200):
+                 bytes_per_segment: int = 200,
+                 world_size: int = 1,
+                 rank: int = 0,
+                 ):
         """
         Initialize LmDataset object.  Args:
           file_list_fn: a file in which each line contains: a number of bytes, then a space, then a filename.
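For orientation, a hedged sketch of constructing the dataset with the widened signature; the values echo defaults from get_params() further down, and whether callers must pass world_size/rank explicitly (versus relying on the get_rank()/get_world_size() helpers used elsewhere in this file) is not visible in this hunk:

    from pathlib import Path

    # hypothetical call site: one dataset per DDP process
    dataset = LmDataset(
        file_list_fn=Path("train.txt"),  # "train_file_list" default in get_params()
        bytes_per_segment=2048,          # new default in get_params()
        world_size=2,                    # assumed: total number of DDP processes
        rank=0,                          # assumed: this process's index
    )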
@@ -48,6 +51,7 @@ class LmDataset(torch.utils.data.IterableDataset):
         self.files = []
         self.num_bytes = []
         self.bytes_per_segment = bytes_per_segment
+        self.ddp_rank = get_rank()

         num_bytes = []
         with open(file_list_fn) as f:
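The cached self.ddp_rank is the core of the DDP fix: __iter__ runs inside DataLoader worker processes, where a late call to the rank helper may not see the initialized process group, so the rank is captured once in __init__ while still in the training process. A minimal sketch of the pattern, with get_rank stubbed since its icefall definition is not part of this diff:

    import torch
    import torch.distributed as dist

    def get_rank() -> int:
        # stand-in for icefall's helper: returns 0 when no process group is
        # visible, which is what a freshly spawned DataLoader worker may report
        return dist.get_rank() if dist.is_initialized() else 0

    class RankAwareDataset(torch.utils.data.IterableDataset):
        def __init__(self):
            # capture eagerly, while the process group is known to be initialized
            self.ddp_rank = get_rank()

        def __iter__(self):
            # use the cached value; calling get_rank() here could silently
            # return 0 in every worker, breaking per-rank data sharding
            yield self.ddp_rank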
@@ -64,18 +68,23 @@ class LmDataset(torch.utils.data.IterableDataset):
         worker_info = torch.utils.data.get_worker_info()
         num_workers = (1 if worker_info is None else worker_info.num_workers)

+        # world_size is for ddp training, num_workers for data-loader worker threads.
         tot_workers = num_workers * get_world_size()


         self.num_segments = tot_bytes // (bytes_per_segment * tot_workers)


     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
         # id includes both worker (within training job) and rank of training job
-        my_id = (0 if worker_info is None else worker_info.id) + 1000 * get_rank()
+        my_id = (0 if worker_info is None else worker_info.id) + 1000 * self.ddp_rank

         seed = random.randint(0, 10000) + my_id
-        logging.info(f"seed={seed}, num_segments={self.num_segments}")
+        # the next line is because, for some reason, when we ran with --world-size more than 1,
+        # this info message was not printed out.
+        logging.getLogger().setLevel(logging.INFO)
+        logging.info(f"my_id={my_id}, seed={seed}, num_segments={self.num_segments}")
         rng = np.random.default_rng(seed=seed)
         for n in range(self.num_segments):
             # np.random.multinomial / np.random.Generator.multinomial has an interface
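A small sketch of why the id scheme keeps RNG streams distinct: the data-loader worker id is offset by 1000 times the DDP rank, so every (rank, worker) pair maps to a unique my_id, and hence a distinct seed, as long as a job has fewer than 1000 workers (a fixed base_seed here replaces the random.randint() base, for determinism in the example):

    import numpy as np

    def make_rng(worker_id: int, ddp_rank: int, base_seed: int = 1234):
        my_id = worker_id + 1000 * ddp_rank  # unique while worker_id < 1000
        return np.random.default_rng(seed=base_seed + my_id)

    # 2 ranks x 2 workers -> 4 distinct streams
    for rank in range(2):
        for worker in range(2):
            print(rank, worker, make_rng(worker, rank).integers(0, 100, size=3))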
@@ -121,7 +121,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="2,4,5,6",
+        default="2,4,8",
         help="Number of zipformer encoder layers per stack, comma separated.",
     )

@@ -129,7 +129,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--downsampling-factor",
         type=str,
-        default="1,2,4,8",
+        default="1,2,4",
         help="Downsampling factor for each stack of encoder layers.",
     )

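Dropping the trailing stack is the "reducing the subsampling factor" half of the commit message: the deepest stack now downsamples by 4x rather than 8x. A quick sketch of per-stack sequence lengths under both defaults, assuming an input of 2048 bytes (the new bytes_per_segment below) that divides evenly:

    seq_len = 2048
    old = [seq_len // f for f in (1, 2, 4, 8)]  # [2048, 1024, 512, 256]
    new = [seq_len // f for f in (1, 2, 4)]     # [2048, 1024, 512]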
@@ -137,21 +137,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="512,768,1024,1536",
+        default="768,1024,1536",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )

     parser.add_argument(
         "--num-heads",
         type=str,
-        default="4,4,6,8",
+        default="4,4,8",
         help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
     )

     parser.add_argument(
         "--encoder-dim",
         type=str,
-        default="192,256,384,512",
+        default="256,384,512",
         help="Embedding dimension in encoder stacks: a single int or comma-separated list."
     )

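All of these options are type=str holding per-stack values; a hedged sketch of how such strings might be split into per-stack ints (icefall's actual parsing helper is not shown in this diff, and the single-int broadcast case follows the help text):

    def to_int_tuple(s: str):
        # "256,384,512" -> (256, 384, 512); a single int would apply to all stacks
        return tuple(int(x) for x in s.split(","))

    assert to_int_tuple("2,4,8") == (2, 4, 8)
    # per-stack lists must agree in length across options
    assert len(to_int_tuple("256,384,512")) == len(to_int_tuple("1,2,4"))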
@@ -186,7 +186,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-unmasked-dim",
         type=str,
-        default="192,192,256,256",
+        default="192,192,256",
         help="Unmasked dimensions in the encoders, relates to augmentation during training. "
         "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
     )
@@ -194,7 +194,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--cnn-module-kernel",
         type=str,
-        default="31,31,15,15",
+        default="31,31,15",
         help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
         "a single int or comma-separated list.",
     )
@@ -214,7 +214,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )


-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -481,8 +480,8 @@ def get_params() -> AttributeDict:
         "valid_interval": 3000,
         "warm_step": 2000,
         "env_info": get_env_info(),
-        "bytes_per_segment": 1024,
-        "batch_size": 64,
+        "bytes_per_segment": 2048,
+        "batch_size": 40,
         "train_file_list": "train.txt",
         "valid_file_list": "valid.txt",
         "num_workers": 4,
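This pairing is the "increase sequence length" half of the commit message: segments double from 1024 to 2048 bytes while the batch shrinks from 64 to 40, so each step still grows modestly in total bytes. The arithmetic:

    old_bytes = 64 * 1024  # 65536 bytes per batch
    new_bytes = 40 * 2048  # 81920 bytes per batch
    print(new_bytes / old_bytes)  # 1.25, i.e. 25% more bytes per step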