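Why multiplication is sometimes faster than bitwise operations in Python
I thought I was done writing filler posts, but it turns out filler is about all I can manage these days to keep things going. So, more filler it is!
One day, someone in a tech group chat asked why, in some cases, simple multiplication/division in Python is faster than bitwise operations.
In the spirit of checking the facts first, let's verify it: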
```
In [33]: %timeit 1073741825*2
7.47 ns ± 0.0843 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [34]: %timeit 1073741825<<1
7.43 ns ± 0.0451 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [35]: %timeit 1073741823<<1
7.48 ns ± 0.0621 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

In [37]: %timeit 1073741823*2
7.47 ns ± 0.0564 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)
```
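We notice an interesting pattern in these runs:
- For x = 2^30 - 1 (1073741823), multiplication is marginally faster than the shift (7.47 ns vs 7.48 ns)
- For x = 2^30 + 1 (1073741825), multiplication is slower than the shift (7.47 ns vs 7.43 ns)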
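If you want to rerun the comparison outside IPython, the standard-library timeit module is one option. Below is a minimal sketch; `bench` is a helper name made up for this post, and the `number`/`repeat` defaults are arbitrary. One caveat worth knowing: CPython's compiler folds constant expressions such as `1073741825*2` at compile time, so this version binds the operand to a variable in the setup string to make sure the multiplication or shift actually runs on every iteration. Timings will of course vary by machine and Python version.

```python
import timeit


def bench(stmt: str, x: int, number: int = 1_000_000, repeat: int = 5) -> float:
    """Best per-loop time of `stmt` in nanoseconds, with `x` bound in the setup.

    Binding the operand in setup, instead of writing a literal, keeps the
    compiler from folding the expression into a constant ahead of time.
    """
    timings = timeit.repeat(stmt, setup=f"x = {x}", number=number, repeat=repeat)
    return min(timings) / number * 1e9


if __name__ == "__main__":
    # On CPython the literal form is folded at compile time: the product
    # 2147483650 already sits in co_consts, so timing it measures no multiply.
    print(compile("1073741825*2", "<expr>", "eval").co_consts)
    for x in (1073741823, 1073741825):  # 2**30 - 1 and 2**30 + 1
        mul = bench("x * 2", x)
        shl = bench("x << 1", x)
        print(f"x = {x}: x * 2 -> {mul:.2f} ns, x << 1 -> {shl:.2f} ns")
```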
This is intriguing, so what is the root cause? It comes down to how Python (CPython) is implemented under the hood.
A quick look
1. The implementation of PyLongObject
In Python 2.x, integers came in two types, int and long. Python 3 merged them, so only one integer type remains: it is called int at the Python level and is implemented in C as the long object, PyLongObject.
First, let's look at how the long object is laid out underneath:
```c
struct _longobject {
    PyObject_VAR_HEAD
    digit ob_digit[1];
};
```
We don't need to worry about what PyObject_VAR_HEAD contains; for our purposes it is enough to know that it includes ob_size, which for integers records the number of digits and the sign. What we care about is ob_digit.
ob_digit is declared with a length of 1 but allocated with as many elements as the number needs (the classic "struct hack", which plays the same role as a C99 flexible array member), so integers of arbitrary length can be stored. Here is the documentation from the official source code:
Long integer representation. The absolute value of a number is equal to SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i) Negative numbers are represented with ob_size < 0; zero is represented by ob_size == 0. In a normalized number, ob_digit[abs(ob_size)-1] (the most significant digit) is never zero. Also, in all cases, for all valid i, 0 <= ob_digit[i] <= MASK. The allocation function takes care of allocating extra memory so that ob_digit[0] ... ob_digit[abs(ob_size)-1] are actually available. CAUTION: Generic code manipulating subtypes of PyVarObject has to aware that ints abuse ob_size's sign bit.
In short, Python stores an integer as a number in base 2^SHIFT. That may sound abstract, so here is an example. On my machine SHIFT is 30 (the typical value on 64-bit builds). Take the integer 1152921506754330628: written in base 2^30 it is 4*(2^30)^0 + 2*(2^30)^1 + 1*(2^30)^2, so ob_digit is a three-element array with the value [4, 2, 1], least significant digit first.
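The decomposition in this example is easy to check with a short Python sketch. `to_digits` and `from_digits` are hypothetical helpers written for illustration, not CPython APIs; the digit width is a parameter that defaults to the 30 bits assumed above, and `sys.int_info.bits_per_digit` reports the value the running interpreter actually uses.

```python
from __future__ import annotations

import sys


def to_digits(n: int, shift: int = 30) -> list[int]:
    """Split abs(n) into base-2**shift digits, least significant first.

    Mirrors the ob_digit layout described above. The sign (kept in ob_size
    by CPython) is ignored; zero yields an empty list, like ob_size == 0.
    """
    mask = (1 << shift) - 1
    n = abs(n)
    digits = []
    while n:
        digits.append(n & mask)  # the lowest `shift` bits form the next digit
        n >>= shift
    return digits


def from_digits(digits: list[int], shift: int = 30) -> int:
    """Rebuild the value as SUM(digits[i] * 2**(shift*i))."""
    return sum(d << (shift * i) for i, d in enumerate(digits))


if __name__ == "__main__":
    print(sys.int_info.bits_per_digit)  # SHIFT of the running interpreter
    print(to_digits(1152921506754330628))  # [4, 2, 1]
```

With 30-bit digits, 1073741823 (2^30 - 1) still fits in a single digit while 1073741825 (2^30 + 1) needs two, `[1, 1]`, which is exactly the boundary the benchmark operands straddle.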
OK, with these basics in place, let's go back and look at multiplication in Python.
2. Multiplication in Python
Multiplication in Python takes one of two paths. For multiplying large numbers, Python uses the Karatsuba algorithm[1], implemented as follows:
```c
static PyLongObject *
k_mul(PyLongObject *a, PyLongObject *b)
{
    Py_ssize_t asize = Py_ABS(Py_SIZE(a));
    Py_ssize_t bsize = Py_ABS(Py_SIZE(b));
    PyLongObject *ah = NULL;
    PyLongObject *al = NULL;
    PyLongObject *bh = NULL;
    PyLongObject *bl = NULL;
    PyLongObject *ret = NULL;
    PyLongObject *t1, *t2, *t3;
    Py_ssize_t shift;           /* the number of digits we split off */
    Py_ssize_t i;

    /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl
     * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl
     * Then the original product is
     *     ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl
     * By picking X to be a power of 2, "*X" is just shifting, and it's
     * been reduced to 3 multiplies on numbers half the size.
     */

    /* We want to split based on the larger number; fiddle so that b
     * is largest.
     */
    if (asize > bsize) {
        t1 = a;
        a = b;
        b = t1;

        i = asize;
        asize = bsize;
        bsize = i;
    }

    /* Use gradeschool math when either number is too small. */
    i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF;
    if (asize <= i) {
        if (asize == 0)
            return (PyLongObject *)PyLong_FromLong(0);
        else
            return x_mul(a, b);
    }

    /* If a is small compared to b, splitting on b gives a degenerate
     * case with ah==0, and Karatsuba may be (even much) less efficient
     * than "grade school" then. However, we can still win, by viewing
     * b as a string of "big digits", each of width a->ob_size. That
     * leads to a sequence of balanced calls to k_mul.
     */
    if (2 * asize <= bsize)
        return k_lopsided_mul(a, b);

    /* Split a & b into hi & lo pieces. */
    shift = bsize >> 1;
    if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
    assert(Py_SIZE(ah) > 0);            /* the split isn't degenerate */

    if (a == b) {
        bh = ah;
        bl = al;
        Py_INCREF(bh);
        Py_INCREF(bl);
    }
    else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;

    /* The plan:
     * 1. Allocate result space (asize + bsize digits: that's always
     *    enough).
     * 2. Compute ah*bh, and copy into result at 2*shift.
     * 3. Compute al*bl, and copy into result at 0. Note that this
     *    can't overlap with #2.
     * 4. Subtract al*bl from the result, starting at shift. This may
     *    underflow (borrow out of the high digit), but we don't care:
     *    we're effectively doing unsigned arithmetic mod
     *    BASE**(sizea + sizeb), and so long as the *final* result fits,
     *    borrows and carries out of the high digit can be ignored.
     * 5. Subtract ah*bh from the result, starting at shift.
     * 6. Compute (ah+al)*(bh+bl), and add it into the result starting
     *    at shift.
     */

    /* 1. Allocate result space. */
    ret = _PyLong_New(asize + bsize);
    if (ret == NULL) goto fail;
#ifdef Py_DEBUG
    /* Fill with trash, to catch reference to uninitialized digits. */
    memset(ret->ob_digit, 0xDF, Py_SIZE(ret) * sizeof(digit));
#endif

    /* 2. t1 <- ah*bh, and copy into high digits of result. */
    if ((t1 = k_mul(ah, bh)) == NULL) goto fail;
    assert(Py_SIZE(t1) >= 0);
    assert(2*shift + Py_SIZE(t1) <= Py_SIZE(ret));
    memcpy(ret->ob_digit + 2*shift, t1->ob_digit,
           Py_SIZE(t1) * sizeof(digit));

    /* Zero-out the digits higher than the ah*bh copy. */
    i = Py_SIZE(ret) - 2*shift - Py_SIZE(t1);
    if (i)
        memset(ret->ob_digit + 2*shift + Py_SIZE(t1), 0,
               i * sizeof(digit));

    /* 3. t2 <- al*bl, and copy into the low digits. */
    if ((t2 = k_mul(al, bl)) == NULL) {
        Py_DECREF(t1);
        goto fail;
    }
    assert(Py_SIZE(t2) >= 0);
    assert(Py_SIZE(t2) <= 2*shift); /* no overlap with high digits */
    memcpy(ret->ob_digit, t2->ob_digit, Py_SIZE(t2) * sizeof(digit));

    /* Zero out remaining digits. */
    i = 2*shift - Py_SIZE(t2);          /* number of uninitialized digits */
    if (i)
        memset(ret->ob_digit + Py_SIZE(t2), 0, i * sizeof(digit));

    /* 4 & 5. Subtract ah*bh (t1) and al*bl (t2). We do al*bl first
     * because it's fresher in cache.
     */
    i = Py_SIZE(ret) - shift;  /* # digits after shift */
    (void)v_isub(ret->ob_digit + shift, i, t2->ob_digit, Py_SIZE(t2));
    Py_DECREF(t2);

    (void)v_isub(ret->ob_digit + shift, i, t1->ob_digit, Py_SIZE(t1));
    Py_DECREF(t1);

    /* 6. t3 <- (ah+al)(bh+bl), and add into result. */
    if ((t1 = x_add(ah, al)) == NULL) goto fail;
    Py_DECREF(ah);
    Py_DECREF(al);
    ah = al = NULL;

    if (a == b) {
        t2 = t1;
        Py_INCREF(t2);
    }
    else if ((t2 = x_add(bh, bl)) == NULL) {
        Py_DECREF(t1);
        goto fail;
    }
    Py_DECREF(bh);
    Py_DECREF(bl);
    bh = bl = NULL;

    t3 = k_mul(t1, t2);
    Py_DECREF(t1);
    Py_DECREF(t2);
    if (t3 == NULL) goto fail;
    assert(Py_SIZE(t3) >= 0);

    /* Add t3. It's not obvious why we can't run out of room here.
     * See the (*) comment after this function.
     */
    (void)v_iadd(ret->ob_digit + shift, i, t3->ob_digit, Py_SIZE(t3));
    Py_DECREF(t3);

    return long_normalize(ret);

  fail:
    Py_XDECREF(ret);
    Py_XDECREF(ah);
    Py_XDECREF(al);
    Py_XDECREF(bh);
    Py_XDECREF(bl);
    return NULL;
}
```
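I won't walk through the Karatsuba algorithm[1] implementation here; interested readers can consult the reference at the end of the article for the details.
Schoolbook multiplication takes O(n^2) time (n being the number of digits), while Karatsuba takes O(n^(log2 3)) ≈ O(n^1.585). Karatsuba looks strictly better, so why doesn't Python use it for every multiplication?
The answer is simple: Karatsuba only beats schoolbook multiplication once n is large enough. Its recursion, extra additions and subtractions, and temporary objects (along with memory-access costs) mean that for small n it is actually slower than multiplying directly. That is why k_mul above falls back to x_mul when the smaller operand has no more than KARATSUBA_CUTOFF digits (KARATSUBA_SQUARE_CUTOFF when squaring).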
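To make the cutoff idea concrete, here is a toy Python version of the identity in k_mul's opening comment, working on non-negative Python ints. It is a sketch under stated assumptions, not CPython's code: it splits at a bit position instead of on 30-bit digit boundaries, and `cutoff_bits=64` is an arbitrary stand-in for KARATSUBA_CUTOFF, below which it uses the built-in `*` in place of schoolbook x_mul.

```python
def karatsuba(a: int, b: int, cutoff_bits: int = 64) -> int:
    """Multiply two non-negative ints using Karatsuba's three-product identity.

    Operands at or below `cutoff_bits` bits go straight to ordinary
    multiplication, standing in for CPython's schoolbook x_mul.
    """
    if a < 0 or b < 0:
        raise ValueError("this sketch only handles non-negative operands")
    if min(a.bit_length(), b.bit_length()) <= cutoff_bits:
        return a * b  # small case: plain multiplication is cheaper
    shift = max(a.bit_length(), b.bit_length()) // 2  # X = 2**shift
    mask = (1 << shift) - 1
    ah, al = a >> shift, a & mask  # a = ah*X + al
    bh, bl = b >> shift, b & mask  # b = bh*X + bl
    hh = karatsuba(ah, bh, cutoff_bits)  # ah*bh
    ll = karatsuba(al, bl, cutoff_bits)  # al*bl
    k = karatsuba(ah + al, bh + bl, cutoff_bits)  # (ah+al)*(bh+bl)
    # ah*bh*X*X + (k - ah*bh - al*bl)*X + al*bl, where "*X" is a left shift
    return (hh << (2 * shift)) + ((k - hh - ll) << shift) + ll


if __name__ == "__main__":
    x, y = 3**2000, 7**1500 + 1
    print(karatsuba(x, y) == x * y)  # True
```

Feeding it one small and one huge operand produces the degenerate split the k_mul comment warns about (ah == 0); CPython sidesteps that case with k_lopsided_mul, which this sketch does not model.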
So let's look at how Python implements multiplication:
```c
static PyObject *
long_mul(PyLongObject *a, PyLongObject *b)
{
    PyLongObject *z;

    CHECK_BINOP(a, b);

    /* fast path for single-digit multiplication */
    if (Py_ABS(Py_SIZE(a)) <= 1 && Py_ABS(Py_SIZE(b)) <= 1) {
        stwodigits v = (stwodigits)(MEDIUM_VALUE(a)) * MEDIUM_VALUE(b);
        return PyLong_FromLongLong((long long)v);
    }

    z = k_mul(a, b);
    /* Negate if exactly one of the inputs is negative. */
    if (((Py_SIZE(a) ^ Py_SIZE(b)) < 0) && z) {
        _PyLong_Negate(&z);
        if (z == NULL)
            return NULL;
    }
    return (PyObject *)z;
}
```
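Here we see that when both operands fit in a single digit (absolute value at most 2^30 - 1, i.e. Py_ABS(Py_SIZE(x)) <= 1), Python multiplies them directly as C integers and returns. Otherwise it calls k_mul, which, as shown above, still falls back to schoolbook x_mul when the smaller operand has at most KARATSUBA_CUTOFF digits and only uses Karatsuba for larger operands.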
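Applied to the benchmark operands, this condition sorts the two cases differently. The sketch below is a hypothetical Python model of the branch in the quoted long_mul (`digit_count` and `mul_path` are names made up for illustration), assuming 30-bit digits.

```python
def digit_count(n: int, shift: int = 30) -> int:
    """How many base-2**shift digits store abs(n); 0 for zero, like ob_size == 0."""
    return (abs(n).bit_length() + shift - 1) // shift


def mul_path(a: int, b: int, shift: int = 30) -> str:
    """Model which branch of the quoted long_mul handles a * b (not CPython code)."""
    if digit_count(a, shift) <= 1 and digit_count(b, shift) <= 1:
        # One C multiplication of two digits, result built by PyLong_FromLongLong.
        return "fast path"
    # Otherwise k_mul, which uses x_mul while the smaller operand is under the cutoff.
    return "k_mul"


if __name__ == "__main__":
    print(mul_path(1073741823, 2))  # fast path: 2**30 - 1 is one digit
    print(mul_path(1073741825, 2))  # k_mul: 2**30 + 1 needs two digits
```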
Now let's look at the implementation of a shift operation, using right shift as the example:
```c
static PyObject *
long_rshift(PyObject *a, PyObject *b)
{
    Py_ssize_t wordshift;
    digit remshift;

    CHECK_BINOP(a, b);

    if (Py_SIZE(b) < 0) {
        PyErr_SetString(PyExc_ValueError, "negative shift count");
        return NULL;
    }
    if (Py_SIZE(a) == 0) {
        return PyLong_FromLong(0);
    }
    if (divmod_shift(b, &wordshift, &remshift) < 0)
        return NULL;
    return long_rshift1((PyLongObject *)a, wordshift, remshift);
}

static PyObject *
long_rshift1(PyLongObject *a, Py_ssize_t wordshift, digit remshift)
{
    PyLongObject *z = NULL;
    Py_ssize_t newsize, hishift, i, j;
    digit lomask, himask;

    if (Py_SIZE(a) < 0) {
        /* Right shifting negative numbers is harder */
        PyLongObject *a1, *a2;
        a1 = (PyLongObject *) long_invert(a);
        if (a1 == NULL)
            return NULL;
        a2 = (PyLongObject *) long_rshift1(a1, wordshift, remshift);
        Py_DECREF(a1);
        if (a2 == NULL)
            return NULL;
        z = (PyLongObject *) long_invert(a2);
        Py_DECREF(a2);
    }
    else {
        newsize = Py_SIZE(a) - wordshift;
        if (newsize <= 0)
            return PyLong_FromLong(0);
        hishift = PyLong_SHIFT - remshift;
        lomask = ((digit)1 << hishift) - 1;
        himask = PyLong_MASK ^ lomask;
        z = _PyLong_New(newsize);
        if (z == NULL)
            return NULL;
        for (i = 0, j = wordshift; i < newsize; i++, j++) {
            z->ob_digit[i] = (a->ob_digit[j] >> remshift) & lomask;
            if (i+1 < newsize)
                z->ob_digit[i] |= (a->ob_digit[j+1] << hishift) & himask;
        }
        z = maybe_small_long(long_normalize(z));
    }
    return (PyObject *)z;
}
```
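Here we can see that when both operands are small, the shift path does more work than the multiplication fast path: after the negative-count and zero checks it calls divmod_shift to split the shift count, allocates a fresh object with _PyLong_New, runs a per-digit loop, and then normalizes the result, whereas long_mul simply multiplies two C integers and converts the product with PyLong_FromLongLong. That answers the question from the start of this post: "why is multiplication sometimes faster than bitwise operations?"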
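For comparison, the per-digit work in the non-negative branch of long_rshift1 can be modelled in Python on a list of 30-bit digits. This is an illustrative sketch: `rshift_digits` is a hypothetical name, the digit width is assumed to be 30, and negative numbers (which CPython handles via long_invert) are not covered. It shows the wordshift/remshift split and the two-digit combination the C loop performs for every output digit.

```python
from __future__ import annotations


def rshift_digits(digits: list[int], n: int, shift: int = 30) -> list[int]:
    """Right-shift a non-negative value held as base-2**shift digits (LSB first).

    Models the non-negative branch of long_rshift1: split the shift count into
    whole digits (wordshift) and leftover bits (remshift), then build each output
    digit from two neighbouring input digits.
    """
    mask = (1 << shift) - 1
    wordshift, remshift = divmod(n, shift)  # what divmod_shift computes
    newsize = len(digits) - wordshift
    if newsize <= 0:
        return []  # every digit shifted out: the result is zero
    hishift = shift - remshift
    lomask = (1 << hishift) - 1
    himask = mask ^ lomask
    out = []
    for i in range(newsize):
        j = i + wordshift
        d = (digits[j] >> remshift) & lomask  # high bits of digit j move down
        if i + 1 < newsize:
            d |= (digits[j + 1] << hishift) & himask  # low bits of digit j+1 move up
        out.append(d)
    while out and out[-1] == 0:  # like long_normalize: drop zero top digits
        out.pop()
    return out


if __name__ == "__main__":
    # 1152921506754330628 is [4, 2, 1] in base 2**30; shifting by 30 drops one digit.
    print(rshift_digits([4, 2, 1], 30))  # [2, 1]
```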
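Summary
That's about it for this post. This analysis turns up some interesting, if rather obscure, knowledge. What we observed is the result of specific design choices CPython makes for common, high-frequency operations. It is also a reminder that Python has its own built-in design philosophy for many operations, and experience from other languages may not carry over directly in day-to-day use.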