1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
| import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinTask; import java.util.concurrent.RecursiveTask; import java.util.function.Function; import java.util.stream.LongStream;
/**
 * Sums a {@code long[]} with the fork/join framework. The index range is split in half until a
 * chunk has at most {@link #THRESHOLD} elements; such a chunk is summed with a plain loop.
 *
 * <p>Also contains a small benchmark ({@link #main}) comparing this task with a parallel
 * {@link LongStream} reduction.
 */
public class ForkJoinSumCalculator extends RecursiveTask<Long> {

    /** Chunk size at or below which the sum is computed sequentially instead of splitting further. */
    public static final long THRESHOLD = 10_000;

    /** Number of timed runs per benchmark; the fastest one is reported. */
    private static final int BENCHMARK_RUNS = 10;

    private final long[] numbers;
    private final int start;
    private final int end;

    /**
     * Creates a task that sums every element of {@code numbers}.
     *
     * @param numbers the values to add; the array is not copied, so it must not be modified while
     *     the task runs
     */
    public ForkJoinSumCalculator(long[] numbers) {
        this(numbers, 0, numbers.length);
    }

    /** Creates a subtask that sums {@code numbers[start, end)}. */
    private ForkJoinSumCalculator(long[] numbers, int start, int end) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        int length = end - start;
        if (length <= THRESHOLD) {
            return computeSequentially();
        }
        int mid = start + length / 2;
        ForkJoinSumCalculator leftTask = new ForkJoinSumCalculator(numbers, start, mid);
        // Fork the left half asynchronously and compute the right half in the current thread,
        // so this worker does useful work instead of blocking straight away on join().
        leftTask.fork();
        ForkJoinSumCalculator rightTask = new ForkJoinSumCalculator(numbers, mid, end);
        long rightResult = rightTask.compute();
        long leftResult = leftTask.join();
        return leftResult + rightResult;
    }

    /** Sums {@code numbers[start, end)} with a plain loop. */
    private long computeSequentially() {
        long sum = 0;
        for (int i = start; i < end; i++) {
            sum += numbers[i];
        }
        return sum;
    }

    /**
     * Returns {@code 1 + 2 + ... + n}, computed by a {@link ForkJoinSumCalculator}.
     *
     * <p>The method name keeps its existing spelling for compatibility with callers.
     *
     * @param n inclusive upper bound; values below 1 yield 0. An array of {@code n} longs is
     *     materialized first.
     * @return the sum of the integers from 1 to {@code n}
     */
    public static long forJoinSum(long n) {
        long[] numbers = LongStream.rangeClosed(1, n).toArray();
        ForkJoinTask<Long> task = new ForkJoinSumCalculator(numbers);
        // Use the shared common pool rather than a fresh ForkJoinPool per call: a fresh pool is
        // never shut down here (its worker threads linger), and its start-up cost would be
        // charged to every timed benchmark run.
        return ForkJoinPool.commonPool().invoke(task);
    }

    /**
     * Returns {@code 1 + 2 + ... + n}, computed by a parallel {@link LongStream} reduction.
     *
     * @param n inclusive upper bound; values below 1 yield 0
     * @return the sum, wrapping around on {@code long} overflow for very large {@code n}
     */
    public static long parallel(long n) {
        return LongStream.rangeClosed(1, n).parallel().reduce(0, Long::sum);
    }

    /** Benchmarks both summing strategies for n = 1,000,000 and prints the fastest time of each. */
    public static void main(String[] args) {
        measureSumPerf(ForkJoinSumCalculator::forJoinSum, 1_000_000L);
        measureSumPerf(ForkJoinSumCalculator::parallel, 1_000_000L);
    }

    /**
     * Runs {@code adder.apply(n)} {@value #BENCHMARK_RUNS} times, then prints and returns the
     * fastest wall-clock time observed.
     *
     * @param adder the summing function under test
     * @param n the argument passed to {@code adder}
     * @return the fastest run time in whole milliseconds
     */
    public static long measureSumPerf(Function<Long, Long> adder, Long n) {
        long fastest = Long.MAX_VALUE;
        for (int run = 0; run < BENCHMARK_RUNS; run++) {
            long startNanos = System.nanoTime();
            adder.apply(n);
            long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000;
            fastest = Math.min(fastest, elapsedMillis);
        }
        System.out.println(fastest);
        return fastest;
    }
}
|