# Source code for pahelix.utils.data_utils

#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for data."""

import numpy as np
import os
import random

def save_data_list_to_npz(data_list, npz_file):
    """
    Save a list of data dicts into a single compressed npz file.

    The keys are taken from ``data_list[0]``, and every other dict must provide
    them too. For each key, the per-sample values are merged into one array.
    Two bookkeeping entries are written next to it:

    * ``<key>.seq_len``: how many rows each sample contributed.
    * ``<key>.singular``: 1 if the first sample's value is a scalar (values
      are stacked with ``np.array``), 0 otherwise (values are concatenated
      along axis 0).

    ``load_npz_to_data_list`` reverses this layout.

    Args:
        data_list(list): a list of data; each data is a dict of numpy ndarray.
        npz_file(str): the npz file location.

    Note:
        ``data_list`` must be non-empty. Indexing its first element raises
        IndexError otherwise.
    """
    first_sample = data_list[0]
    merged = {}
    for key in first_sample.keys():
        column = [sample[key] for sample in data_list]
        if np.array(first_sample[key]).ndim == 0:
            # Scalars: exactly one row per sample.
            merged[key] = np.array(column)
            merged[key + '.seq_len'] = np.ones(len(data_list), dtype='int')
            merged[key + '.singular'] = 1
        else:
            # Sequences: remember each length so the concatenation can be split back.
            merged[key] = np.concatenate(column, 0)
            merged[key + '.seq_len'] = np.array([len(item) for item in column])
            merged[key + '.singular'] = 0
    np.savez_compressed(npz_file, **merged)
def load_npz_to_data_list(npz_file):
    """
    Reload the data list saved by ``save_data_list_to_npz``.

    Each stored key ``<name>`` is accompanied by ``<name>.seq_len`` (rows
    contributed by each sample) and ``<name>.singular`` (1 for per-sample
    scalars, 0 for per-sample sequences). These are used to split the merged
    array back into one entry per sample.

    Args:
        npz_file(str): the npz file location.

    Returns:
        a list of data, one dict per sample. Each dict maps every stored name
        to that sample's slice (sequence keys) or single element (scalar keys).
        An empty list is returned if the file holds no data keys.
    """
    def _split_data(values, seq_lens, singular):
        # Walk through ``values`` in consecutive chunks of ``seq_lens[i]`` rows;
        # scalar keys take the single element at each offset instead of a slice.
        res = []
        start = 0
        for length in seq_lens:
            if singular == 0:
                res.append(values[start: start + length])
            else:
                res.append(values[start])
            start += length
        return res

    meta_suffixes = ('.seq_len', '.singular')
    # NOTE(review): allow_pickle=True lets object arrays load, but unpickling can
    # execute arbitrary code -- only load npz files from trusted sources.
    # The ``with`` block closes the underlying zip file handle; every array is
    # read into memory inside it, so the results stay valid afterwards.
    with np.load(npz_file, allow_pickle=True) as merged_data:
        names = [name for name in merged_data.keys()
                 if not name.endswith(meta_suffixes)]
        data_dict = {
            name: _split_data(merged_data[name],
                              merged_data[name + '.seq_len'],
                              merged_data[name + '.singular'])
            for name in names
        }

    if not names:
        # Nothing to reassemble (e.g. the samples had no keys).
        return []
    n = len(data_dict[names[0]])
    return [{name: data_dict[name][i] for name in names} for i in range(n)]
def get_part_files(data_path, trainer_id, trainer_num, seed=None):
    """
    Split the files in data_path so that each trainer can train from different examples.

    The directory entries are shuffled and then dealt round-robin: entry ``i``
    of the shuffled list goes to the trainer with ``i % trainer_num == trainer_id``.

    Parts are only disjoint and complete if every trainer shuffles into the
    same order. To make that possible, the listing is sorted first, because
    ``os.listdir`` order is arbitrary. ``seed`` lets all trainers share a
    private RNG instead of relying on the module-level ``random`` state.

    Args:
        data_path(str): directory whose entries are split.
        trainer_id(int): index of this trainer; ids outside
            ``[0, trainer_num)`` receive no files.
        trainer_num(int): total number of trainers.
        seed(int, optional): seed for the shuffle. If None (default), the global
            ``random`` state is used. In that case every trainer must seed
            ``random`` identically to get disjoint parts.

    Returns:
        list of str: paths (``data_path`` joined with the entry name) assigned
        to this trainer.
    """
    filenames = sorted(os.listdir(data_path))
    if seed is None:
        random.shuffle(filenames)
    else:
        random.Random(seed).shuffle(filenames)
    return [os.path.join(data_path, filename)
            for i, filename in enumerate(filenames)
            if i % trainer_num == trainer_id]