为什么sum比inject(:+)快得多?


129

所以我在Ruby 2.4.0中运行一些基准测试,并意识到

(1...1000000000000000000000000000000).sum

立即计算,而

(1...1000000000000000000000000000000).inject(:+)

花了这么长时间,我才中止了手术。我的印象是这Range#sum是一个别名,Range#inject(:+)但似乎并非如此。那么sum,它是如何工作的,为什么比它那么快inject(:+)呢?

注意:的文档Enumerable#sum(由实现Range)没有说明任何有关延迟评估的内容或类似内容。

Answers:


227

简短答案

对于整数范围:

  • Enumerable#sum 退货 (range.max-range.min+1)*(range.max+range.min)/2
  • Enumerable#inject(:+) 遍历每个元素。

理论

1和之间的整数之和n称为三角数,等于n*(n+1)/2

之间的整数之nm是的三角形数m减去的三角形数n-1,它等于m*(m+1)/2-n*(n-1)/2并且可以写(m-n+1)*(m+n)/2

Ruby 2.4中的Enumerable#sum

该属性Enumerable#sum用于整数范围:

if (RTEST(rb_range_values(obj, &beg, &end, &excl))) {
    if (!memo.block_given && !memo.float_value &&
            (FIXNUM_P(beg) || RB_TYPE_P(beg, T_BIGNUM)) &&
            (FIXNUM_P(end) || RB_TYPE_P(end, T_BIGNUM))) { 
        return int_range_sum(beg, end, excl, memo.v);
    } 
}

int_range_sum 看起来像这样:

VALUE a;
a = rb_int_plus(rb_int_minus(end, beg), LONG2FIX(1));
a = rb_int_mul(a, rb_int_plus(end, beg));
a = rb_int_idiv(a, LONG2FIX(2));
return rb_int_plus(init, a);

等效于:

(range.max-range.min+1)*(range.max+range.min)/2

前述平等!

复杂

非常感谢@k_g和@ Hynek-Pichi-Vychodil!

(1...1000000000000000000000000000000).sum 需要三个加法,一个乘法,一个减法和一个除法。

它是恒定数量的运算,但是Enumerable#sum对于整数范围,乘法是O((log n)²),所以O((log n)²)也是如此。

注入

(1...1000000000000000000000000000000).inject(:+)

需要添加999999999999999999999999999998!

加法为O(log n),所以Enumerable#injectO(n log n)。

随着1E30作为输入,inject与有去无回。太阳不久就会爆炸!

测试

检查是否添加了Ruby Integers很容易:

module AdditionInspector
  def +(b)
    puts "Calculating #{self}+#{b}"
    super
  end
end

class Integer
  prepend AdditionInspector
end

puts (1..5).sum
#=> 15

puts (1..5).inject(:+)
# Calculating 1+2
# Calculating 3+3
# Calculating 6+4
# Calculating 10+5
#=> 15

确实,根据enum.c评论:

Enumerable#summethod可能不尊重方法的重新定义"+" 方法,例如Integer#+


17
这是一个非常好的优化,因为如果使用正确的公式,那么计算一个范围内的数字之和是微不足道的,而如果反复进行计算则会很麻烦。就像尝试将乘法实现为一系列加法运算一样。
tadman

因此,性能提升仅适用于n+1范围吗?我没有安装2.4或进行自我测试,但是其他可枚举对象是通过基本加法处理的,因为它们将inject(:+)减去从符号到proc的开销。
Engineersmnky

8
读者,请回想一下您的高中数学,它n, n+1, n+2, .., m构成一个总和等于的算术级数(m-n+1)*(m+n)/2。同样,几何级数之n, (α^1)n, (α^2)n, (α^3)n, ... , (α^m)n。可以从封闭形式的表达式中计算得出。
卡里·斯沃夫兰

4
\ begin {nitpick}可枚举的#sum为O((log n)^ 2),而inject为O(n log n),当您的数字不受限制时。\ end {nitpick}
k_g

6
@EliSadoff:这意味着真正的大数字。这意味着不适合架构字的数字,即不能通过CPU内核中的一条指令和一项操作来计算。大小N的数量可以用log_2 N位编码,因此加法是O(logN)运算,乘法是O((logN)^ 2),但可以是O((logN)^ 1.585)(Karasuba)甚至O(logN *日志(10即)*日志(LOG(LOGN))(FFT)。
希内克-Pichi- Vychodil
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.