# icefall/egs/libriheavy/LM/zipformer1/lm_datamodule.py

# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import inspect
import logging
import random
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from icefall.dist import get_world_size, get_rank
from icefall.utils import str2bool


class LmDataset(torch.utils.data.IterableDataset):
    def __init__(
        self,
        file_list_fn: Path,
        bytes_per_segment: int = 200,
        training: bool = True,
    ):
        """
        Initialize the LmDataset object.  This keeps no state: it just gives you a
        totally random segment each time.  The training files are viewed as plain
        sequences of bytes, from which we select chunks of a fixed size.  In training
        mode we loop infinitely, and let the training code decide when to stop based
        on the count of tokens.  In test mode we loop so that we see each byte about
        once.

        Args:
          file_list_fn: a file in which each line contains a number of bytes, then a
            space, then a filename, e.g. a line might contain the text
            "64324 foo/abc.txt".  (Everything after the first space is treated as the
            filename, so filenames may contain spaces.)
          bytes_per_segment: the number of bytes in each segment of data.
          training: if true, iterate forever over random segments; if false, iterate
            just far enough that each byte is seen about once.
        """
        self.training = training
        self.files = []
        self.num_bytes = []
        self.bytes_per_segment = bytes_per_segment
        self.ddp_rank = get_rank()

        with open(file_list_fn) as f:
            for line in f.readlines():
                line = line.strip()  # remove newline
                num_bytes = line.split()[0]  # a str
                fn = line[len(num_bytes) + 1:]  # this works even if fn has spaces in it
                self.files.append(fn)
                self.num_bytes.append(int(num_bytes))
        # For purposes of choosing the possible start-positions of a segment, we
        # need to pad on the left by bytes_per_segment - 1.  This is part of a
        # scheme to ensure that each byte in each training file is chosen with
        # equal probability, while also choosing different shifts of the data
        # with equal probability.  We end up padding with zeroes if we
        # are outside the file either on the left or the right.
        pad = self.bytes_per_segment - 1
        tot_positions = sum([x + pad for x in self.num_bytes])
        self.probs = np.array([(x + pad) / tot_positions for x in self.num_bytes])
        self.tot_positions = tot_positions
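        # Why this weighting gives (approximately) equal probability per byte: a file
        # with n bytes contributes n + pad candidate start positions, and any given
        # byte of that file lies inside about bytes_per_segment of those candidate
        # segments, so each byte is selected with probability roughly
        # bytes_per_segment / tot_positions, independent of which file it is in.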
        worker_info = torch.utils.data.get_worker_info()
        num_workers = 1 if worker_info is None else worker_info.num_workers
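        # Note: get_worker_info() returns None when called outside a DataLoader
        # worker process, e.g. when this dataset is constructed in the main
        # process; in that case num_workers is taken to be 1 here.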
        # num_workers is for data-loader worker processes; world_size is for ddp training.
        tot_workers = num_workers * get_world_size()
        self.num_segments = (
            float("inf")
            if training
            else 1 + tot_positions // (bytes_per_segment * tot_workers)
        )
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # my_id includes both the worker id (within this training job) and the rank
        # of the training job.
        my_id = (0 if worker_info is None else worker_info.id) + 1000 * self.ddp_rank

        # note: the seed depends on the current random state, which will be different
        # depending on the DDP worker id and also depending on which batch we restarted
        # training on.  This does not guarantee that you get repeatability if you
        # restart training, but it does ensure you don't see exactly repeated data.
        seed = (random.randint(0, 10000) if self.training else 0) + my_id

        # the next line is because, for some reason, when we ran with --world-size
        # greater than 1, this info message was not printed out otherwise.
        logging.getLogger().setLevel(logging.INFO)
        logging.info(f"my_id={my_id}, seed={seed}, num_segments={self.num_segments}")

        # use numpy's generator, not random's, because we need np.random.multinomial.
        rng = np.random.default_rng(seed=seed)
        n = 0
        while n < self.num_segments:  # if self.num_segments is infinity, just keep going.
            n += 1
            # np.random.multinomial / np.random.Generator.multinomial has an interface
            # where it gives counts of the different categories, instead of the chosen
            # category, so we use np.nonzero to get the chosen category (i.e. the
            # file index).  np.nonzero returns one array per dim, so the first
            # `file_idx, =` unpacks that array and the second one extracts the single
            # nonzero index.
            file_idx, = np.nonzero(rng.multinomial(1, self.probs))
            file_idx, = file_idx
            fn = self.files[file_idx]
            num_bytes = self.num_bytes[file_idx]

            # begin_pos, end_pos are the begin,end of the range from which we'll pick
            # randomly, for where the start of the segment might be.  We only
            # guarantee that a segment will contain at least one byte of real data;
            # this helps ensure that each byte is chosen with the exact same
            # probability, which is easier for analysis.
            begin_pos = -(self.bytes_per_segment - 1)
            end_pos = max(1, num_bytes - 1)
            begin, = rng.integers(low=begin_pos, high=end_pos, size=1)
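            # A negative `begin` means the notional segment starts before byte 0 of
            # the file; in that case the read below returns fewer than
            # bytes_per_segment bytes and the shortfall is filled with zero padding.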
            with open(fn, "rb") as f:
                if begin >= 0:
                    f.seek(begin)
                    b = f.read(self.bytes_per_segment)  # b is a bytes object
                else:
                    b = f.read(self.bytes_per_segment + begin)
                    # b = b'\0' * -begin + f.read(self.bytes_per_segment + begin)
            if len(b) < self.bytes_per_segment:
                b = b + b'\0' * (self.bytes_per_segment - len(b))
            # np.frombuffer gives a read-only uint8 view of `b`, so we copy it before
            # converting to a long tensor of byte values.
            yield torch.Tensor(np.frombuffer(b, dtype=np.uint8).copy()).to(torch.long)
    def num_tokens(self):
        # Returns the total number of tokens, including padding tokens, in
        # the dataset; this is for purposes of figuring out how many
        # epochs we have trained for.
        return self.tot_positions


def _test():
    l = LmDataset('files.txt')
    d = torch.utils.data.DataLoader(
        dataset=l, batch_size=5, num_workers=4, drop_last=True
    )
    for batch in d:
        logging.info(f"batch shape: {batch.shape}")


if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    _test()
# cd libriheavy/LM
# find /ceph-data3/xiaoyu/librilight_text/output_text_large_cleaned -name text.txt -exec stat --printf='%s ' {} \; -print > files.txt
# head -n 4 files.txt > valid.txt
# tail -n +5 files.txt > train.txt
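#
# A minimal usage sketch (paths are hypothetical, batch sizes arbitrary):
#
#   dataset = LmDataset("train.txt", bytes_per_segment=200, training=True)
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=16, num_workers=4, drop_last=True
#   )
#   batch = next(iter(loader))  # shape (16, 200), dtype torch.int64, byte values 0..255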