使用有限内存对巨型数据文件排序
本文迁移自笔者 CSDN 账号下文章
运行环境
- JDK:openJDK13
- 内存:200M (通过 VM 参数 -Xmx200M 指定)
- 目标数据文件:raw.data (1.72G)
基本思路(分治)
- 切分:从目标数据文件中读取数据,读取一定数量后对读取到的数据进行排序,并生成临时排序文件,重复此过程,将原始数据文件分割为若干个已排序的数据文件
- 合并:根据上一阶段得到的分组文件数量,如果内存不足以一次创建所有文件的指针,则需要进行多次合并。合并时,将若干个临时数据文件合并为更大的数据文件,使用归并排序思想,使用优先级队列辅助。重复此过程直到生成的临时数据文件个数足够少。
实现
定义数据对象
- 定义数据格式、数据对象与字符串的互转方法,使程序可以从输入流中分离到数据对象。下文代码中,使用
toString()和String参数的构造方法实现 - 定义排序规则:实现
Comparable<T>接口
生成乱序数据文件
在数据类中实现一个随机生成方法,在工具类中开几个线程,因为这里定义的数据对象略大,所以本实例中向文件中写入了 5000_0000 个数据对象
分割
- 规定系统可分配的工作线程数量,因为内存十分有限,如果线程太多,则线程本身就会造成 OOM ,这里规定了 12 个线程。因为过分限制内存,整个过程效率比较底下。
- 规定分割尺寸,这个值影响每个线程实际执行的任务量、生成的临时文件数量,即分割的粒度。这里指定为 80_0000
- 多线程执行任务,针对读到的指定数量的数据对象进行排序、写入临时数据文件,因为各线程只能共享一个输入流(使用的 BufferedReader 是线程安全的,无需同步控制),所以可能多线程的优势并不明显,不过排序任务、写文件可并行进行。
- 本实例中,分割生成了 754 个(从零开始编号)临时数据文件

合并
- 因为一旦在上一阶段生成了太多临时文件,同时创建太多输入流可能导致 OOM ,可以进行多次合并,本实例中,进行了两次合并。
- 封装一个类似链表的类,提供 next, hasNext 方法,该类用于从临时数据文件中读取数据对象,并预读下一个读取的对象,优先级队列可以根据它对数据文件进行排序,以保证合并后的数据顺序。
- 合并时,每个线程根据为自己分配的临时文件数创建一个指定容量的优先级队列,将各个文件打开,进行合并,每个线程生成一个新的临时文件
- 针对各线程生成的临时文件进行合并,得到最终排序后的文件。


上图即为合并时的临时文件,和得到的最终排序文件(sorted.data)。比原数据文件小了接近 50M ,因为在处理换行符时未统一(生成 raw.data 时使用的换行符是 System.lineSeparator() 在 windows 下为 "\r\n",生成 sorted.data 时使用的是 '\n'),导致每一行缺少一个 '\r' ,数据共 5000_0000 行,所以空间约为 50M
闲扯(可忽略)
近期阅读 openJDK13 源码,觉得十分枯燥,决定写点啥,于是有了上面的东西。
撸代码之前,觉得这个大文件排序不是很难,为了加大难度,决定不搜索,全凭源码中的注释,另外使用线程池,支持需要兼顾数据元对象的拓展,即允许替换排序对象类,使之可以实现针对多种对象的大文件排序
所以,为了达到上述的需求,尝试了很多方法,最后决定使用抽象工厂,而这一方式也带来了一些方便,比如在分割、排序时,使用针对流的链式调用即可实现将读到的字符串数组转为数据对象并进行排序、生成排序后的字符数组并写入临时数据文件。
开始写的时候行云流水,在处理掉明显的错误后,发现最终得到的文件比原数据文件少了几百M,考虑到上文提到的换行符问题,相差的大小还是不对,冒着卡死的风险将排序后的数据文件拖到 Notepad++ 后,发现竟然少了40 万行数据而且顺序还有问题。无奈从头开始,将源数据文件控制到 100 行、200行、1000行并调整线程数、分割粒度等参数,将所有问题处理掉才最终完成。
整个过程,还是过于高估了自己的能力。虽说没有使用搜索,但是耗时远超出了当时的预期,不过,写代码还是比读代码有趣。
最后,放码过来
本代码未进行足够测试,只通过了数据文件为 100~2000 行以及 5000_0000 行的测试,其他测试均未进行
因为使用断言,而且断言影响程序功能,如果需要运行,需要指定虚拟机参数: -ea


Main.java
package work.cxlm;
import work.cxlm.helper.BigFileSorter;
import work.cxlm.helper.Element;
import work.cxlm.helper.Generator;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
public class Main {
public static void main(String[] args) { // 通过调整注释以进行不同任务
String originFileName = "D:/works/Java/try55_bigsort/raw.data";
// generateDataFile(originFileName);
// ensureMemorize();
bigFileSort(80_0000, originFileName);
}
private static void bigFileSort(int divideSize, String originFilename) {
try {
BigFileSorter<Element> sorter = new BigFileSorter<>(originFilename, divideSize, Element::new);
sorter.startSort();
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
// 一定会 OOM 的方法,仅用于测试限定内存是否生效
// 经测试,200M 内存可以存放 1633812 个对象
private static void ensureMemorize() {
List<Element> eles = new LinkedList<>();
int counter = 0;
try {
for (; ; ) {
eles.add(Element.random());
counter++;
}
} catch (OutOfMemoryError error) {
error.printStackTrace();
eles = null;
System.gc();
System.out.println("指定的空间耗尽,共创建"+counter+"个对象");
}
}
private static void generateDataFile(String filename) {
try {
Generator.generate(filename, 1000);
} catch (IOException | InterruptedException e) {
e.printStackTrace();
}
}
}
BigFileSorter.java
package work.cxlm.helper;
import java.io.*;
import java.util.*;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
public class BigFileSorter<T extends SortableStringConvertAble> {
private static final int WORKER_COUNT = 12;
private File originFile;
private File targetFile;
private final int DIVIDE_SIZE;
private ConvertFactory<T> factory;
private static final String TEMP_FILENAME_FORMAT = "%s/.temp%04d";
private ThreadPoolExecutor pool = new ThreadPoolExecutor(WORKER_COUNT, WORKER_COUNT, 0, TimeUnit.SECONDS,
new ArrayBlockingQueue<>(WORKER_COUNT), runnable -> new Thread(runnable, "排序任务单元"));
public BigFileSorter(String originPath, int divideSize, ConvertFactory<T> factory) throws FileNotFoundException {
originFile = new File(originPath);
if (!originFile.exists()) {
throw new FileNotFoundException("文件[" + originPath + "]不存在");
}
targetFile = new File(originFile.getParent() + "/sorted.data");
DIVIDE_SIZE = divideSize;
this.factory = factory;
}
public void startSort() {
try {
int tempFileCount = divideAndSort();
mergeSection(tempFileCount);
} catch (Exception e) {
e.printStackTrace();
}
}
private void mergeSection(int tempFileCount) throws IOException, InterruptedException {
System.out.println("正在进行合并");
AtomicInteger fileCounter = new AtomicInteger();
int groupSize = (tempFileCount / WORKER_COUNT) + 1;
String finalFileFormat = "%s/.temp.%02d";
for (int worker = 0; worker < WORKER_COUNT; worker++) {
pool.submit(() -> {
int thisId = fileCounter.getAndIncrement();
int startFileNumber = thisId * groupSize;
int endFileNumber = Math.min(tempFileCount, startFileNumber + groupSize);
PriorityQueue<TempNodeList> nodeQueue = new PriorityQueue<>(groupSize, Comparator.comparing(a -> a.nowNode.val()));
for (int i = startFileNumber; i < endFileNumber; i++) {
File nowFile = new File(String.format(TEMP_FILENAME_FORMAT, originFile.getParent(), i));
try {
nodeQueue.offer(new TempNodeList(nowFile));
} catch (IOException e) {
e.printStackTrace();
}
}
File threadTarget = new File(String.format(finalFileFormat, originFile.getParent(), thisId));
try {
assert threadTarget.createNewFile();
queueToFile(threadTarget, nodeQueue);
} catch (IOException e) {
e.printStackTrace();
}
});
}
pool.shutdown();
System.out.println("等待合并线程结束");
while (!pool.awaitTermination(2, TimeUnit.SECONDS)) {
System.out.print(".");
}
PriorityQueue<TempNodeList> nodeQueue = new PriorityQueue<>(WORKER_COUNT, Comparator.comparing(a -> a.nowNode.val()));
System.out.println("进行目标文件生成");
for (int i = 0; i < WORKER_COUNT; i++) {
File nowFile = new File(String.format(finalFileFormat, originFile.getParent(), i));
assert nowFile.exists();
TempNodeList node = new TempNodeList(nowFile);
if (node.nowNode != null)
nodeQueue.offer(node);
}
assert targetFile.exists() || targetFile.createNewFile();
queueToFile(targetFile, nodeQueue);
System.out.println("完成合并,目标文件:" + targetFile.getName());
}
private void queueToFile(File file, Queue<TempNodeList> nodeQueue) throws IOException {
BufferedOutputStream os = new BufferedOutputStream(new FileOutputStream(file));
while (!nodeQueue.isEmpty()) {
TempNodeList nowNodes = nodeQueue.poll();
os.write(nowNodes.now().toString().getBytes());
os.write("\n".getBytes());
if (nowNodes.hasNext()) {
nowNodes.next();
nodeQueue.offer(nowNodes);
}
}
os.flush();
os.close();
}
class TempNodeList {
BufferedReader reader;
T nowNode;
String nextStr;
private boolean end;
TempNodeList(File file) throws IOException {
reader = new BufferedReader(new FileReader(file));
next();
next();
}
public void next() throws IOException {
if (nextStr == null || nextStr.isEmpty()) {
nowNode = null;
} else {
nowNode = factory.create(nextStr);
}
while ((nextStr = reader.readLine()) != null && nextStr.isEmpty())
; // 跳过空行
if (nextStr == null) {
end = true;
}
}
public T now() {
return nowNode;
}
public boolean hasNext() {
return !end;
}
}
private int divideAndSort() throws FileNotFoundException {
System.out.println("正在进行分割排序");
int threadMissionSize = DIVIDE_SIZE / WORKER_COUNT;
BufferedReader reader = new BufferedReader(new FileReader(originFile));
AtomicBoolean end = new AtomicBoolean(false);
AtomicInteger fileCounter = new AtomicInteger();
AtomicInteger debugCounter = new AtomicInteger();
for (int i = 0; i < WORKER_COUNT; i++) {
pool.submit(() -> {
while (!end.get()) {
List<String> threadMissionOriginStringArray = new ArrayList<>(threadMissionSize);
while (threadMissionOriginStringArray.size() < threadMissionSize && !end.get()) {
String readMission = null;
try { // 尝试读取一行
readMission = reader.readLine();
} catch (IOException e) {
e.printStackTrace();
}
if (readMission == null) { // 读到文件末尾,结束读取
end.set(true);
break;
} else if (!readMission.isEmpty()) { // 不是空行,添加到任务集合
threadMissionOriginStringArray.add(readMission);
}
} // -- 领取任务
// 分割、排序、生成临时文件
if (threadMissionOriginStringArray.isEmpty())
return; // 空集合,停止任务
String threadFileName = String.format(TEMP_FILENAME_FORMAT, originFile.getParent(), fileCounter.getAndIncrement());
File tempFile = new File(threadFileName);
try {
assert tempFile.createNewFile() : "创建临时文件";
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(tempFile));
// 将无序字符串列表转为有序字符串列表
threadMissionOriginStringArray.stream().map(factory::create).sorted(Comparator.comparing(SortableStringConvertAble::val)).map(T::toString).forEach(str -> {
try {
debugCounter.incrementAndGet(); // 此处检测得到正确数值
bos.write(str.getBytes()); // 写入临时文件以释放内存
bos.write("\n".getBytes()); // 写入新行
} catch (IOException e) {
e.printStackTrace();
}
});
bos.flush();
bos.close();
} catch (IOException e) {
System.out.println("写入临时文件出错");
e.printStackTrace();
}
} // -- thread's loop to read origin file
}); // -- submit
} // -- loop to submit work thread
System.out.println("等待分割排序线程结束");
while (pool.getCompletedTaskCount() != WORKER_COUNT) {
try {
Thread.sleep(2000);
System.out.print(".");
} catch (InterruptedException e) {
break;
}
}
System.out.println("分割结束,共处理数据[" + debugCounter.get() + "]项");
return fileCounter.get();
}
}
ConvertFactory.java
package work.cxlm.helper;
// 抽象工厂
public interface ConvertFactory<T extends SortableStringConvertAble> {
T create(String str);
}
[点击并拖拽以移动]
Element.java
package work.cxlm.helper;
import java.util.Random;
public class Element extends SortableStringConvertAble {
private static final String[] STRING_POOL = { // 用于生成的随机字符串
"埋まる", "白咲", "花", "索菲", "本居", "日向", "星野", "梦美", "乃愛", "篝",
"越谷", "小鞠", "小林", "宫内", "莲华", "小雪", "北方", "羽未", "白羽", "実り"
};
private static Random random = new Random(System.currentTimeMillis());
private Long value;
private String key;
public Element(String k, Long v) {
key = k;
value = v;
}
public Element(String raw) {
if (raw == null || raw.isEmpty()) return; // 构造空对象
String[] kv = raw.split(":");
if (kv.length != 2) throw new IllegalArgumentException("给定的值[" + raw + "]无法转化为 Element 对象");
key = kv[0];
value = Long.valueOf(kv[1]);
}
public static Element random() {
String randKey = STRING_POOL[random.nextInt(STRING_POOL.length)] + "·" + STRING_POOL[random.nextInt(STRING_POOL.length)];
Long randVal = random.nextLong();
return new Element(randKey, randVal);
}
@Override
public long val() {
return value;
}
@Override
public String toString() {
return key + ":" + value;
}
}
Generator.java
package work.cxlm.helper;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
public class Generator {
static final String SEP = System.lineSeparator();
public static void generate(String filepath, int count) throws IOException, InterruptedException {
File target = new File(filepath);
// 创建父级目录、目标文件
assert target.getParentFile().exists() || target.getParentFile().mkdirs();
assert target.exists() || target.createNewFile();
BufferedOutputStream os = new BufferedOutputStream(new FileOutputStream(target));
int poolSize = 10;
ThreadPoolExecutor pool = new ThreadPoolExecutor(poolSize, poolSize, 0, TimeUnit.SECONDS,
new ArrayBlockingQueue<>(poolSize), runnable -> new Thread(runnable, "创建数据元"));
AtomicInteger atomicCount = new AtomicInteger(count);
for (int i = 0; i < poolSize; i++) {
pool.submit(() -> {
while (atomicCount.decrementAndGet() >= 0) {
byte[] bytesToWrite = (Element.random().toString() + SEP).getBytes();
synchronized (os) {
try {
os.write(bytesToWrite);
} catch (IOException e) {
e.printStackTrace();
}
}
}
});
}
pool.shutdown();
pool.awaitTermination(5, TimeUnit.MINUTES);
os.flush();
os.close();
}
}
SortableStringConvertAble.java
package work.cxlm.helper;
import java.util.Objects;
public abstract class SortableStringConvertAble implements Comparable<SortableStringConvertAble> {
public SortableStringConvertAble() {
}
@Override
public abstract String toString();
abstract protected long val();
@Override
public int compareTo(SortableStringConvertAble o) {
Objects.requireNonNull(o, "数据元不能为 null");
return (int) (val() - o.val());
}
}