# 实现可查询进度的并发任务执行框架

loading

# 一、需求的产生和分析

公司里有两个项目组,考试组有批量的离线文档要生成,题库组则经常有批量的题目进行排重和根据条件批量修改题目的内容。

架构组通过对实际的上线产品进行用户调查,发现这些功能在实际使用时,用户都反应速度很慢,而且提交任务后,不知道任务的进行情况,做没做?做到哪一步了?有哪些成功?哪些失败了?都一概不知道。

架构组和实际的开发人员沟通,他们都说,因为前端提交任务到 Web 后台以后,是一次要处理多个文档和题目,所以速度快不起来。提示用多线程进行改进,实际的开发人员表示多线程没有用过,不知道如何使用,也担心用不好。综合以上情况,架构组决定在公司的基础构件库中提供一个并发任务执行框架,以解决上述用户和业务开发人员的痛点:

  1. 对批量型任务提供统一的开发接口
  2. 在使用上尽可能的对业务开发人员友好
  3. 要求可以查询批量任务的执行进度

# 二、需要做什么

要实现这么一个批量任务并发执行的框架,我们来分析一下我们要做些什么?

  1. 批量任务,为提高性能

必然的我们要使用 java 里的多线程,为了在使用上尽可能的对业务开发人员友好和简单,需要屏蔽一些底层 java 并发编程中的细节,让他们不需要去了解并发容器,阻塞队列,异步任务,线程安全等等方面的知识,只要专心于自己的业务处理即可。

  1. 每个批量任务拥有自己的上下文环境

因为一个项目组里同时要处理的批量任务可能有多个,比如考试组,可能就会有不同的学校的批量的离线文档生成,而题库组则会不同的学科都会有老师同时进行工作,因此需要一个并发安全的容器保存每个任务的属性信息,

  1. 自动清除已完成和过期任务

因为要提供进度查询,系统需要在内存中维护每个任务的进度信息以供查询,但是这种查询又是有时间限制的,一个任务完成一段时间后,就不再提供进度查询了,则就需要我们自动清除已完成和过期任务,用定时轮询吗?

下面是业务示意图;

业务示意图

# 三、具体实现

# 1、用户业务方法的执行结果

一个方法执行的结果有几种可能?三种,成功:按预想的流程出了结果;失败:按按预想的流程没出结果;异常:没按预想的流程抛出了预料之外的错误。因此我们定义了一个枚举,表示这三种情况:

public enum TaskResultType {

    /**
     * 方法执行完成,业务结果也正确
     */
    SUCCESS,

    /**
     * 方法执行完成,业务结果错误
     */
    FAILURE,

    /**
     * 方法执行抛出了异常
     */
    EXCEPTION
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

对于方法的业务执行结果,返回值有很多种可能,基本类型,系统定义的对象类型,用户自定义的对象类型都是存在的,我们需要用泛型来说表示这个结果。同时方法执行失败了,我们还需要告诉用户或者业务开发人员,失败的原因,我们再定义了一个包装任务的结果类。

public class TaskResult<R> {

    /**
     * 方法执行结果
     */
    private final TaskResultType resultType;

    /**
     * 方法执行后的结果数据
     */
    private final R returnValue;

    /**
     * 如果方法失败,这里可以填充原因
     */
    private final String reason;

    public TaskResult(TaskResultType resultType, R returnValue) {
        this(resultType, returnValue, "SUCCESS");
    }

    public TaskResult(TaskResultType resultType, R returnValue, String reason) {
        this.resultType = resultType;
        this.returnValue = returnValue;
        this.reason = reason;
    }

    public TaskResultType getResultType() {
        return resultType;
    }

    public R getReturnValue() {
        return returnValue;
    }

    public String getReason() {
        return reason;
    }

    @Override
    public String toString() {
        return "TaskResult{" +
                "resultType=" + resultType +
                ", returnValue=" + returnValue +
                ", reason='" + reason + '\'' +
                '}';
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

# 2、如何执行用户的业务方法

我们是个框架,用户的业务各种各样,都要放到我们框架里执行,怎么办?当然是定义个接口,我们的框架就只执行这个方法,而使用我们框架的业务方都应该来实现这个接口,当然因为用户业务的数据多样性,意味着我们这个方法的参数也应该用泛型。

public interface ITaskProcessor<T, R> {

    /**
     * 执行任务的方法
     *
     * @param data
     * @return
     */
    TaskResult<R> taskExecute(T data);
}
1
2
3
4
5
6
7
8
9
10

# 3、用户如何提交工作和查询任务进度

用户在前端提交了工作(JOB)到后台,我们需要提供一种封装机制,让业务开发人员可以将任务的相关信息提交给这个封装机制,用户的需要查询进度的时候,也从这个封装机制中取得,同时我们的封装机制内部也要负责清除已完成任务。

在这个封装机制里我们定义了一个类 JobInfo,抽象了对用户工作的封装,一个工作可以包含多个子任务(TASK),这个 JobInfo 中就包括了这个工作的相关信息,比如工作名,用以区分框架中唯一的工作,也可以避免重复提交,也方便查询时快速定位工作,除了工作名以外,工作中任务的列表,工作中任务的处理器都在其中定义。

public class JobInfo<R> {

    /**
     * 工作名,用以区分框架中唯一的工作
     */
    private final String jobName;

    /**
     * 工作中任务的长度
     */
    private final int jobLength;

    /**
     * 处理工作中任务的处理器
     */
    private final ITaskProcessor<?, R> taskProcessor;

    /**
     * 任务的成功次数
     */
    private AtomicInteger successCount;

    /**
     * 工作中任务目前已经处理的次数
     */
    private AtomicInteger taskProcessCount;

    /**
     * 存放每个任务的处理结果,供查询用
     */
    private LinkedBlockingDeque<TaskResult<R>> taskResultQueue;

    /**
     * 保留的工作的结果信息供查询的时长
     */
    private final long expireTime;

    /**
     * 定时清除缓存类
     */
    private static CheckJobProcessor checkJob = CheckJobProcessor.getInstance();

    public JobInfo(String jobName, int jobLength, ITaskProcessor<?, R> taskProcessor, long expireTime) {
        this.jobName = jobName;
        this.jobLength = jobLength;
        this.taskProcessor = taskProcessor;
        this.expireTime = expireTime;
        successCount = new AtomicInteger(0);
        taskProcessCount = new AtomicInteger(0);
        taskResultQueue = new LinkedBlockingDeque<>(jobLength);
    }

    /**
     * 获取任务成功的处理次数
     *
     * @return
     */
    public int getSuccessCount() {
        return successCount.get();
    }

    /**
     * 提供工作中失败的次数
     *
     * @return
     */
    public int getFailCount() {
        return taskProcessCount.get() - successCount.get();
    }

    /**
     * 获取总共任务的处理次数
     *
     * @return
     */
    public int getTaskProcessCount() {
        return taskProcessCount.get();
    }

    public ITaskProcessor<?, R> getTaskProcessor() {
        return taskProcessor;
    }

    public int getJobLength() {
        return jobLength;
    }

    /**
     * 提供工作的整体进度信息
     *
     * @return
     */
    public String getTotalProcess() {
        return "Success[" + successCount.get() + "] / Current[" + taskProcessCount.get()
                + "], Total[" + jobLength + "]";
    }

    /**
     * 提供工作中每个任务的处理结果
     *
     * @return
     */
    public List<TaskResult<R>> getTaskResult() {
        List<TaskResult<R>> taskResultList = new LinkedList<>();
        TaskResult<R> taskResult;
        while ((taskResult = taskResultQueue.pollFirst()) != null) {
            taskResultList.add(taskResult);
        }
        return taskResultList;
    }

    /**
     * 每个任务处理完成后,记录任务的处理结果,因为从业务应用的角度来说,对查询任务进度数据的一致性要不高,
     * 我们保证最终一致性即可,无需对整个方法加锁
     *
     * @param taskResult
     */
    public void addTaskResult(TaskResult<R> taskResult) {
        if (TaskResultType.SUCCESS.equals(taskResult.getResultType())) {
            successCount.incrementAndGet();
        }
        taskProcessCount.incrementAndGet();
        taskResultQueue.addLast(taskResult);
        // 当所有的任务都执行完毕时,将整个工作放入定时缓存,到期后清除
        if (taskProcessCount.get() == jobLength) {
            checkJob.putJob(jobName, expireTime);
        }
    }

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130

在这个 JobInfo 中海油很多关于这个工作的方法,例如查询工作进度的方法 getTotalProcess(),查询每个任务处理结果的方法 getTaskResult(),记录每个任务处理结果的方法 addTaskResult()

负责清除已完成任务,我们则交给 CheckJobProcesser 类来完成,定时轮询的机制不够优雅,因此我们选用了 DelayQueue 延迟队列来实现这个功能。

public class CheckJobProcessor {

    /**
     * 存放任务的队列
     */
    private static DelayQueue<ItemVo<String>> queue = new DelayQueue<>();

    /**
     * 单例化
     */
    private static class ProcessorHolder {
        public static CheckJobProcessor processor = new CheckJobProcessor();
    }

    public static CheckJobProcessor getInstance() {
        return ProcessorHolder.processor;
    }

    /**
     * 处理队列中到期任务
     */
    private static class FetchJob implements Runnable {

        private static DelayQueue<ItemVo<String>> queue = CheckJobProcessor.queue;

        /**
         * 缓存的工作信息
         */
        private static Map<String, JobInfo<?>> jobInfoMap = PendingJobPool.getMap();

        @Override
        public void run() {
            while (true) {
                try {
                    // 能拿到任务,说明整个任务需要被清除了
                    ItemVo<String> item = queue.take();
                    String jobName = item.getData();
                    jobInfoMap.remove(jobName);
                    System.out.println("【" + jobName + "】任务过期了,从缓存中清除");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }

        }
    }

    /**
     * 任务完成后,放入队列,经过 expireTime 时间后,会从整个框架中移除
     *
     * @param jobName
     * @param expireTime
     */
    public void putJob(String jobName, long expireTime) {
        ItemVo<String> item = new ItemVo<>(expireTime, jobName);
        queue.offer(item);
        System.out.println("【" + jobName + "】任务已经放入过期检查缓存,过期时长:" + expireTime);
    }

    static {
        Thread thread = new Thread(new FetchJob());
        thread.setDaemon(true);
        thread.start();
        System.out.println("开启过期检查的守护线程……");
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

内部类 FetchJob 用户检查过期的任务进行清除,putJob() 方法将过期的任务放入队列 queue 进行移除。

使用静态代码块对线程进行设置守护线程并执行。

# 4、框架的主体类

主体类则是 PendingJobPool,这也是业务开发人员主要使用的类。这个类主要负责调度,例如工作(JOB)和任务(TASK)的提交,任务(TASK)的保存,任务(TASK)的并发执行,工作进度的查询接口和任务执行情况的查询等等。

public class PendingJobPool {

    /**
     * 框架运行时的线程数,与机器的CPU数相同
     */
    private static final int THREAD_COUNTS = Runtime.getRuntime().availableProcessors();

    /**
     * 队列,线程池使用,用以存放待处理的任务
     */
    private static BlockingQueue<Runnable> taskQueue = new ArrayBlockingQueue<>(5000);

    /**
     * 线程池,固定大小,有界队列
     */
    private static ExecutorService taskExecutor = new ThreadPoolExecutor(THREAD_COUNTS, THREAD_COUNTS,
            60, TimeUnit.SECONDS, taskQueue);

    /**
     * 工作信息的存放容器
     */
    private static ConcurrentHashMap<String, JobInfo<?>> jobInfoMap = new ConcurrentHashMap<>();

    public static Map<String, JobInfo<?>> getMap() {
        return jobInfoMap;
    }

    /**
     * 以单例模式启动
     */
    private PendingJobPool() {
    }

    private static class JobPoolHolder {
        public static PendingJobPool pool = new PendingJobPool();
    }

    public static PendingJobPool getInstance() {
        return JobPoolHolder.pool;
    }

    /**
     * 对工作中的任务进行包装,提交给线程池使用,并将处理任务的结果,写入缓存以供查询
     *
     * @param <T>
     * @param <R>
     */
    private static class PendingTask<T, R> implements Runnable {

        private JobInfo<R> jobInfo;
        private T processData;

        public PendingTask(JobInfo<R> jobInfo, T processData) {
            this.jobInfo = jobInfo;
            this.processData = processData;
        }

        @Override
        public void run() {
            ITaskProcessor<T, R> taskProcessor = (ITaskProcessor<T, R>) jobInfo.getTaskProcessor();
            TaskResult<R> result = null;
            try {
                result = taskProcessor.taskExecute(processData);
                if (result == null) {
                    result = new TaskResult<>(TaskResultType.EXCEPTION, null, "result is null");
                }
                if (result.getResultType() == null) {
                    if (result.getReason() == null) {
                        result = new TaskResult<>(TaskResultType.EXCEPTION, result.getReturnValue(),
                                "result is null");
                    } else {
                        result = new TaskResult<>(TaskResultType.EXCEPTION, result.getReturnValue(),
                                "result is null, reason: " + result.getReason());
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
                result = new TaskResult<>(TaskResultType.EXCEPTION, null, e.getMessage());
            } finally {
                jobInfo.addTaskResult(result);
            }

        }
    }

    /**
     * 调用者注册工作,如工作名,任务的处理器等等
     *
     * @param jobName
     * @param jobLength
     * @param taskProcessor
     * @param expireTime
     * @param <R>
     */
    public <R> void registerJob(String jobName, int jobLength, ITaskProcessor<?, R> taskProcessor, long expireTime) {
        JobInfo<R> jobInfo = new JobInfo<>(jobName, jobLength, taskProcessor, expireTime);
        if (jobInfoMap.putIfAbsent(jobName, jobInfo) != null) {
            throw new RuntimeException(jobName + "已经注册!");
        }
    }

    /**
     * 根据工作名称检索工作
     *
     * @param jobName
     * @param <R>
     * @return
     */
    private <R> JobInfo<R> getJob(String jobName) {
        JobInfo<R> jobInfo = (JobInfo<R>) jobInfoMap.get(jobName);
        if (null == jobInfo) {
            throw new RuntimeException(jobName + "是非法任务!");
        }
        return jobInfo;
    }

    /**
     * 调用者提交工作中的任务
     *
     * @param jobName
     * @param t
     * @param <T>
     * @param <R>
     */
    public <T, R> void putTask(String jobName, T t) {
        JobInfo<R> jobInfo = getJob(jobName);
        PendingTask<T, R> task = new PendingTask<>(jobInfo, t);
        taskExecutor.execute(task);
    }

    /**
     * 获得工作的整体处理进度
     *
     * @param jobName
     * @param <R>
     * @return
     */
    public <R> String getTaskProgress(String jobName) {
        JobInfo<R> jobInfo = getJob(jobName);
        return jobInfo.getTotalProcess();
    }

    /**
     * 获得每个任务的处理详情
     *
     * @param jobName
     * @param <R>
     * @return
     */
    public <R> List<TaskResult<R>> getTaskResult(String jobName) {
        JobInfo<R> jobInfo = getJob(jobName);
        return jobInfo.getTaskResult();
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154

通过单例暴露出实例,用户首先要实现自己业务的 ITaskProcessor 接口,然后使用的时候通过 registerJob() 进行任务的注册,然后调用 putTask() 方法进行执行任务。

流程图:

流程图:

# 5、测试

任务实现类:

public class MyTask implements ITaskProcessor<Integer, Integer> {

    @Override
    public TaskResult<Integer> taskExecute(Integer data) {
        Random r = new Random();
        int flag = r.nextInt(500);
        SleepTool.ms(flag);
        // 正常处理的情况
        if (flag <= 300) {
            Integer returnValue = data + flag;
            return new TaskResult<>(TaskResultType.SUCCESS, returnValue);
        }
        // 处理失败的情况
        else if (flag > 301 && flag <= 400) {
            return new TaskResult<>(TaskResultType.FAILURE, -1, "FAILURE");
        }
        // 发生异常的情况
        else {
            try {
                throw new RuntimeException("异常发生了!!");
            } catch (Exception e) {
                return new TaskResult<>(TaskResultType.EXCEPTION, -1, e.getMessage());
            }
        }
    }

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

一个实际任务类,将数值加上一个随机数,并休眠随机时间。通过随机数,让任务的执行结果分为了3类。

接下来模拟一个应用程序,提交工作和任务,并查询任务进度

public class AppTest {

    private final static String JOB_NAME = "计算数值";
    private final static int JOB_LENGTH = 1000;

    /**
     * 查询任务进度的线程
     */
    private static class QueryResult implements Runnable {

        private PendingJobPool pool;

        public QueryResult(PendingJobPool pool) {
            super();
            this.pool = pool;
        }

        @Override
        public void run() {
            int i = 0;
            while (i < 350) {
                List<TaskResult<String>> taskDetail = pool.getTaskResult(JOB_NAME);
                if (!taskDetail.isEmpty()) {
                    System.out.println(pool.getTaskProgress(JOB_NAME));
                    System.out.println(taskDetail);
                }
                SleepTool.ms(100);
                i++;
            }
        }

    }

    public static void main(String[] args) {
        MyTask myTask = new MyTask();
        PendingJobPool pool = PendingJobPool.getInstance();
        pool.registerJob(JOB_NAME, JOB_LENGTH, myTask, 5);
        Random r = new Random();
        for (int i = 0; i < JOB_LENGTH; i++) {
            pool.putTask(JOB_NAME, r.nextInt(1000));
        }
        Thread t = new Thread(new QueryResult(pool));
        t.start();
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

线程 QueryResult 是专门用来查询任务进度的线程,使用350内的一个循环近似模拟用户端在一段时间内不断的进行进度查询,知道任务超时过期。主线程则进行任务的真正执行。

执行结果:

开启过期检查的守护线程……
Success[2] / Current[2], Total[1000]
[TaskResult{resultType=SUCCESS, returnValue=789, reason='SUCCESS'}, TaskResult{resultType=SUCCESS, returnValue=531, reason='SUCCESS'}]
Success[5] / Current[5], Total[1000]
[TaskResult{resultType=SUCCESS, returnValue=630, reason='SUCCESS'}, TaskResult{resultType=SUCCESS, returnValue=700, reason='SUCCESS'}, TaskResult{resultType=SUCCESS, returnValue=534, reason='SUCCESS'}]
Success[8] / Current[8], Total[1000]
[TaskResult{resultType=SUCCESS, returnValue=880, reason='SUCCESS'}, TaskResult{resultType=SUCCESS, returnValue=482, reason='SUCCESS'}, TaskResult{resultType=SUCCESS, returnValue=404, reason='SUCCESS'}]
...
[TaskResult{resultType=SUCCESS, returnValue=595, reason='SUCCESS'}]
Success[611] / Current[996], Total[1000]
[TaskResult{resultType=EXCEPTION, returnValue=-1, reason='异常发生了!!'}, TaskResult{resultType=SUCCESS, returnValue=818, reason='SUCCESS'}, TaskResult{resultType=EXCEPTION, returnValue=-1, reason='异常发生了!!'}, TaskResult{resultType=EXCEPTION, returnValue=-1, reason='异常发生了!!'}]
【计算数值】任务已经放入过期检查缓存,过期时长:5
Success[611] / Current[1000], Total[1000]
[TaskResult{resultType=EXCEPTION, returnValue=-1, reason='异常发生了!!'}, TaskResult{resultType=EXCEPTION, returnValue=-1, reason='异常发生了!!'}, TaskResult{resultType=EXCEPTION, returnValue=-1, reason='异常发生了!!'}, TaskResult{resultType=EXCEPTION, returnValue=-1, reason='异常发生了!!'}]
【计算数值】任务过期了,从缓存中清除
Exception in thread "Thread-1" java.lang.RuntimeException: 计算数值是非法任务!
	at com.jerry.ch8a.PendingJobPool.getJob(PendingJobPool.java:130)
	at com.jerry.ch8a.PendingJobPool.getTaskResult(PendingJobPool.java:169)
	at com.jerry.ch8a.demo.AppTest$QueryResult.run(AppTest.java:38)
	at java.lang.Thread.run(Thread.java:748)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

可以看出一个完整的执行任务进度的查询,并且在任务过期后是无法再查询的。


上次更新: 2020-08-21 13:12:21(5 小时前)