Source code for pahelix.datasets.inmemory_dataset

#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
In-memory dataset.
"""

import os
from os.path import join, exists

import numpy as np

from pgl.utils.data import Dataloader

from pahelix.utils.data_utils import save_data_list_to_npz, load_npz_to_data_list
from pahelix.utils.basic_utils import mp_pool_map


__all__ = ['InMemoryDataset']


class InMemoryDataset(object):
    """
    Description:
        The InMemoryDataset manages ``data_list``, a list of `data`, where each
        `data` is a dict of numpy ndarrays and all dicts share the same keys.

        It works like a list: call ``dataset[i]`` to get the i-th element of
        ``data_list`` and ``len(dataset)`` to get its length.

        The ``data_list`` can be cached in npz files by calling
        ``dataset.save_data(data_path)``; afterwards, call
        ``InMemoryDataset(npz_data_path=data_path)`` to reload it.

    Attributes:
        data_list(list): a list of dicts of numpy ndarrays.

    Example:
        .. code-block:: python

            data_list = [{'a': np.zeros([4, 5])}, {'a': np.zeros([7, 5])}]
            dataset = InMemoryDataset(data_list=data_list)
            print(len(dataset))
            dataset.save_data('./cached_npz')   # save data_list to ./cached_npz

            dataset2 = InMemoryDataset(npz_data_path='./cached_npz')   # reload the saved data_list
            print(len(dataset2))

    """
    def __init__(self, data_list=None, npz_data_path=None, npz_data_files=None):
        """
        Users can either pass the ``data_list`` directly, or pass ``npz_data_path``
        or ``npz_data_files`` from which the cached ``data_list`` will be loaded.

        Args:
            data_list(list): a list of dicts of numpy ndarrays.
            npz_data_path(str): the path to a directory of cached npz files.
            npz_data_files(list): a list of paths to cached npz files.
        """
        super(InMemoryDataset, self).__init__()
        self.data_list = data_list
        self.npz_data_path = npz_data_path
        self.npz_data_files = npz_data_files

        if npz_data_path is not None:
            self.data_list = self._load_npz_data_path(npz_data_path)
        if npz_data_files is not None:
            self.data_list = self._load_npz_data_files(npz_data_files)

    def _load_npz_data_path(self, data_path):
        data_list = []
        files = sorted(f for f in os.listdir(data_path) if f.endswith('.npz'))
        for f in files:
            data_list += load_npz_to_data_list(join(data_path, f))
        return data_list

    def _load_npz_data_files(self, data_files):
        data_list = []
        for f in data_files:
            data_list += load_npz_to_data_list(f)
        return data_list

    def _save_npz_data(self, data_list, data_path, max_num_per_file=10000):
        if not exists(data_path):
            os.makedirs(data_path)
        n = len(data_list)
        for i in range(int((n - 1) / max_num_per_file) + 1):
            filename = 'part-%06d.npz' % i
            sub_data_list = data_list[i * max_num_per_file: (i + 1) * max_num_per_file]
            save_data_list_to_npz(sub_data_list, join(data_path, filename))
    def save_data(self, data_path):
        """
        Save the ``data_list`` to the directory specified by ``data_path`` in npz
        format. Afterwards, call ``InMemoryDataset(npz_data_path=data_path)`` to
        reload the ``data_list``.

        Args:
            data_path(str): the path to the directory where the npz files are cached.
        """
        self._save_npz_data(self.data_list, data_path)
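    # Illustrative note (added example, not part of the original module): saving
    # writes ``data_list`` into files named part-000000.npz, part-000001.npz, ...
    # with at most 10000 elements per file, for instance:
    #
    #     dataset.save_data('./cached_npz')
    #     # ./cached_npz/part-000000.npz, ./cached_npz/part-000001.npz, ...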
    def __getitem__(self, key):
        if isinstance(key, slice):
            start, stop, step = key.indices(len(self))
            dataset = InMemoryDataset(
                    data_list=[self[i] for i in range(start, stop, step)])
            return dataset
        elif isinstance(key, (int, np.int64, np.int32)):
            return self.data_list[key]
        elif isinstance(key, list):
            dataset = InMemoryDataset(
                    data_list=[self[i] for i in key])
            return dataset
        else:
            raise TypeError('Invalid argument type: %s of %s' % (type(key), key))

    def __len__(self):
        return len(self.data_list)
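    # Illustrative sketch (added example, not part of the original module):
    # integer keys return a single data dict, while slices and lists of indices
    # return a new InMemoryDataset holding the selected elements:
    #
    #     dataset = InMemoryDataset(data_list=[{'a': np.zeros([3, 5])} for _ in range(10)])
    #     item = dataset[0]            # dict of numpy ndarrays
    #     subset = dataset[2:8:2]      # InMemoryDataset with 3 elements
    #     picked = dataset[[0, 3, 7]]  # InMemoryDataset with 3 elements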
    def transform(self, transform_fn, num_workers=4, drop_none=False):
        """
        Apply ``transform_fn`` to each element of ``data_list`` in place, using
        multiprocessing. If ``drop_none`` is True, elements for which
        ``transform_fn`` returns None are dropped.
        """
        data_list = mp_pool_map(self.data_list, transform_fn, num_workers)
        if drop_none:
            self.data_list = [data for data in data_list if data is not None]
        else:
            self.data_list = data_list
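    # Illustrative sketch (added example, not part of the original module):
    # ``transform_fn`` receives one data dict and returns the transformed dict,
    # or None to drop that element when ``drop_none=True``:
    #
    #     def scale_fn(data):
    #         return None if data['a'].shape[0] == 0 else {'a': data['a'] * 2.0}
    #
    #     dataset.transform(scale_fn, num_workers=4, drop_none=True)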
    def get_data_loader(self, batch_size, num_workers=4, shuffle=False, collate_fn=None):
        """
        Return a batch iterator that yields batches of data. A sub-list of
        `data` of size ``batch_size`` is first drawn from ``data_list``, then
        ``collate_fn`` is applied to the sub-list to build the batch that is
        yielded. The process is accelerated by multiprocessing.

        Args:
            batch_size(int): the batch size of each yielded batch.
            num_workers(int): the number of workers used to generate batch data.
                Required by multiprocessing.
            shuffle(bool): whether to shuffle the order of ``data_list``.
            collate_fn(function): converts a sub-list of ``data_list`` into the
                aggregated batch data.

        Yields:
            the batch data processed by ``collate_fn``.
        """
        return Dataloader(self,
                batch_size=batch_size,
                num_workers=num_workers,
                shuffle=shuffle,
                collate_fn=collate_fn)
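

# Illustrative usage sketch (added example, not part of the original module):
# builds a small dataset and iterates over batches produced by a collate_fn
# that stacks the per-sample arrays into one batch array.
if __name__ == '__main__':
    def collate_fn(data_sublist):
        # Stack the 'a' arrays of the sub-list into a single batch ndarray.
        return {'a': np.stack([d['a'] for d in data_sublist])}

    demo = InMemoryDataset(data_list=[{'a': np.zeros([4, 5])} for _ in range(10)])
    for batch in demo.get_data_loader(batch_size=4, num_workers=1, collate_fn=collate_fn):
        print(batch['a'].shape)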