mirror of https://github.com/k2-fsa/icefall.git
Use larger valid set; get --print-diagnostics=True to work

parent 105fb56db4
commit da80241179
@@ -40,7 +40,8 @@ class LmDataset(torch.utils.data.IterableDataset):
                  bytes_per_segment: int = 200,
                  world_size: int = 1,
                  rank: int = 0,
-                 training: bool = True
+                 training: bool = True,
+                 skip_to_batch_idx: int = 0,
     ):
         """
         Initialize LmDataset object. Args:
@@ -48,8 +49,10 @@ class LmDataset(torch.utils.data.IterableDataset):
                e.g. a line might contain the text "64324 foo/abc.txt".
                (filenames can not contain spaces).
           bytes_per_segment: the number of bytes in each segment of data.
+          skip_to_batch_idx: if provided, the first time we iterate we will skip this many batches.
         """
         self.training = training
+        self.skip_to_batch_idx = skip_to_batch_idx
         self.files = []
         self.num_bytes = []
         self.bytes_per_segment = bytes_per_segment
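The docstring's file-list format is worth pinning down. Below is a hypothetical parsing helper, not code from the repo, illustrating the "64324 foo/abc.txt" convention: byte count first, then a filename that may not contain spaces.

# Illustration only (not icefall code): parse one "<num_bytes> <filename>"
# line of the file list described in the docstring above.
def parse_file_list_line(line: str):
    num_bytes, filename = line.strip().split(" ", 1)
    return int(num_bytes), filename

assert parse_file_list_line("64324 foo/abc.txt") == (64324, "foo/abc.txt")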
@@ -88,6 +91,12 @@ class LmDataset(torch.utils.data.IterableDataset):
         logging.getLogger().setLevel(logging.INFO)
         logging.info(f"my_id={my_id}, seed={seed}, num_segments={self.num_segments}")
         rng = np.random.default_rng(seed=seed)
+
+        skip_to_batch_idx = self.skip_to_batch_idx
+        if skip_to_batch_idx != 0:
+            logging.info(f"skip-to-batch-idx={skip_to_batch_idx}")
+        self.skip_to_batch_idx = 0  # so only the 1st time we iterate, we respect this.
+
         for n in range(self.num_segments):
             # np.random.multinomial / np.random.Generator.multinomial has an interface
             # where it gives counts of different categories, instead of the chosen category,
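The comment closing this hunk is truncated by the diff context, but its point is recoverable: numpy's Generator.multinomial returns a vector of per-category counts rather than the index of the chosen category. A standalone sketch (assumed probabilities, not repo code) of the workaround the next hunk applies:

import numpy as np

rng = np.random.default_rng(seed=0)
probs = np.array([0.2, 0.5, 0.3])   # assumed file-sampling probabilities
counts = rng.multinomial(1, probs)  # counts per category, e.g. [0, 1, 0]
file_idx, = np.nonzero(counts)      # positions with a nonzero count
file_idx, = file_idx                # unwrap the length-1 array to a scalar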
@@ -97,6 +106,9 @@ class LmDataset(torch.utils.data.IterableDataset):
             file_idx, = np.nonzero(rng.multinomial(1, self.probs))
             file_idx, = file_idx
 
+            if n < skip_to_batch_idx:
+                continue
+
             fn = self.files[file_idx]
             num_bytes = self.num_bytes[file_idx]
 
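Taken together, the previous two hunks implement resume-by-skipping: the RNG draws still happen for skipped segments, keeping the sampled sequence aligned with the interrupted run, and the skip count is zeroed so only the first iteration respects it. A minimal sketch of the pattern, with hypothetical names:

import torch

class SkippingDataset(torch.utils.data.IterableDataset):
    # Illustrative only; the real LmDataset also samples files and yields byte segments.
    def __init__(self, num_segments: int, skip_to_batch_idx: int = 0):
        self.num_segments = num_segments
        self.skip_to_batch_idx = skip_to_batch_idx

    def __iter__(self):
        skip = self.skip_to_batch_idx
        self.skip_to_batch_idx = 0  # only the first pass respects the skip
        for n in range(self.num_segments):
            # (RNG draws would happen here, even for skipped n, to keep state aligned)
            if n < skip:
                continue
            yield n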
@@ -139,5 +151,5 @@ if __name__ == '__main__':
 
 # cd libriheavy/LM
 # find /ceph-data3/xiaoyu/librilight_text/output_text_large_cleaned -name text.txt -exec stat --printf='%s ' {} \; -print > files.txt
-# head -n 2 files.txt > valid.txt
-# tail -n +3 files.txt > train.txt
+# head -n 4 files.txt > valid.txt
+# tail -n +5 files.txt > train.txt
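Two things to note here: stat --printf='%s ' prints each file's size in bytes before its path, so every line of files.txt already has the "<num_bytes> <filename>" shape the LmDataset docstring describes; and the head/tail change is the "larger valid set" of the commit title, with validation now taking the first 4 entries of files.txt instead of 2 and training taking the rest.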
@@ -762,9 +762,6 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
 
         params.batch_idx_train += 1
 
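This deletion is the counterpart of the dataset change above: with LmDataset fast-forwarding via skip_to_batch_idx, the per-batch cur_batch_idx check inside train_one_epoch is redundant.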
@@ -991,18 +988,23 @@ def run(rank, world_size, args):
 
 
     train = LmDataset(params.train_file_list,
-                      bytes_per_segment=params.bytes_per_segment,)
+                      bytes_per_segment=params.bytes_per_segment,
+                      skip_to_batch_idx=getattr(params, 'cur_batch_idx', 0))
+
+    batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
 
     train_dl = torch.utils.data.DataLoader(
         dataset=train,
-        batch_size=params.batch_size,
+        batch_size=batch_size,
         num_workers=params.num_workers,
         drop_last=True)
 
 
     valid = LmDataset(params.valid_file_list,
                       bytes_per_segment=params.bytes_per_segment)
     valid_dl = torch.utils.data.DataLoader(
         dataset=valid,
-        batch_size=params.batch_size,
+        batch_size=batch_size,
         num_workers=params.num_workers,
         drop_last=False)
 
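Two separate fixes land in this hunk. The batch_size line is the --print-diagnostics fix from the commit title: with diagnostics enabled, the batch size is divided by 6, presumably to offset the extra memory that collecting diagnostics costs (the factor looks empirical). The getattr call is the resume wiring: params carries cur_batch_idx only when a checkpoint has been loaded, so the default of 0 makes fresh runs start at the first batch. A sketch of that behavior with a stand-in params object:

class Params:
    pass  # stand-in for the params object in train.py

params = Params()
print(getattr(params, 'cur_batch_idx', 0))  # fresh run: prints 0

params.cur_batch_idx = 300                  # as if restored from a checkpoint
print(getattr(params, 'cur_batch_idx', 0))  # resumed run: prints 300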
@@ -1017,7 +1019,7 @@ def run(rank, world_size, args):
         # to let it know how many tokens we have processed so far, and have a
         # soft-cutoff lr_tokens measured in tokens.
         # scheduler.step_epoch(epoch - 1)
-        fix_random_seed(params.seed + epoch - 1 + params.start_batch)
+        fix_random_seed(params.seed + epoch)
         # the above will affect random seeds in the dataloaders.
 
         if tb_writer is not None:
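The seed change appears to follow from the resume mechanism: skipping reproduces the interrupted run only if the dataset replays an identical RNG stream, so the seed can no longer be perturbed by params.start_batch. An illustrative check (not repo code) that a skipped replay matches the original suffix when the seed is unchanged:

import numpy as np

probs = [0.2, 0.5, 0.3]  # assumed file-sampling probabilities

def sampled_files(seed: int, n: int):
    rng = np.random.default_rng(seed=seed)
    return [int(np.nonzero(rng.multinomial(1, probs))[0][0]) for _ in range(n)]

original = sampled_files(seed=7, n=10)
resumed = sampled_files(seed=7, n=10)[3:]  # same seed, skip the first 3 draws
assert resumed == original[3:]             # resumes exactly where it left off
# with a different seed, the replayed suffix would (almost surely) diverge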