Python 標(biāo)準(zhǔn)庫中非常有用的裝飾器
眾所周知,Python 語言靈活、簡潔,對程序員友好,但在性能上有點不太令人滿意,這一點通過一個遞歸的求斐波那契額函數(shù)就可以說明:
- def fib(n):
- if n <= 1:
- return n
- return fib(n - 1) + fib(n - 2)
在我的 MBP 上計算 fib(40) 花費了 33 秒:
- import time
- def main():
- start = time.time()
- result = fib(40)
- end = time.time()
- cost = end - start
- print(f"{result = } {cost = :.4f}")
- if __name__ == '__main__':
- main()
但是,假如使用標(biāo)準(zhǔn)庫中的這個裝飾器,那結(jié)果完全不一樣
- from functools import lru_cache
- @lru_cache
- def fib(n):
- if n <= 1:
- return n
- return fib(n - 1) + fib(n - 2)
這次的結(jié)果是 0 秒,你沒看錯,我保留了 4 位小數(shù),后面的忽略了。
提升了多少倍?我已經(jīng)計算不出來了。
為什么 lru_cache 裝飾器這么牛逼,它到底做了什么事情?今天就來聊一聊這個最有用的裝飾器。
如果看過計算機操作系統(tǒng)的話,你對 LRU 一定不會陌生,這就是著名的最近最久未使用緩存淘汰算法。
而 lru_cache 就是這個算法的具體實現(xiàn)。(這個算法可是面試經(jīng)??嫉呐?,有的面試官要求現(xiàn)場手寫代碼)
現(xiàn)在,我們來看一個 lru_cache 的源代碼,其中的英文注釋,我已經(jīng)為你翻譯為中文:
- def lru_cache(maxsize=128, typed=False):
- """LRU 緩存裝飾器
- 如果 *maxsize* 是 None, 將不會淘汰緩存,緩存大小也不做限制
- 如果 *typed* 是 True, 不同類型的參數(shù)將獨立做緩存,比如 f(3.0) and f(3) 將認(rèn)為是不同的函數(shù)調(diào)用而緩存在兩個緩存節(jié)點上。
- 函數(shù)的參數(shù)必須可以被 hash
- 查看緩存信息使用的是命名元組 (hits, misses, maxsize, currsize)
- 查看緩存信息:user_func.cache_info(). 清理緩存信息:user_func.cache_clear().
- LRU 算法: http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
- """
- # lru_cache 的內(nèi)部實現(xiàn)是線程安全的
- if isinstance(maxsize, int):
- # 負(fù)數(shù)轉(zhuǎn)換為 0
- if maxsize < 0:
- maxsize = 0
- elif callable(maxsize) and isinstance(typed, bool):
- #如果被裝飾的函數(shù)(user_function)直接通過 maxsize 參數(shù)傳入
- user_function, maxsize = maxsize, 128
- wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
- return update_wrapper(wrapper, user_function)
- elif maxsize is not None:
- raise TypeError(
- 'Expected first argument to be an integer, a callable, or None')
- def decorating_function(user_function):
- wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
- return update_wrapper(wrapper, user_function)
- return decorating_function
這里面有兩個參數(shù),一個是 maxsize,表示緩存的大小,當(dāng)傳入負(fù)數(shù)時,自動設(shè)置為 0,如果不傳入 maxsize,或者設(shè)置為 None,表示緩存沒有大小限制,此時沒有緩存淘汰。還有一個是 type,當(dāng) type 傳入 True 時,不同的參數(shù)類型會當(dāng)作不同的 key 存到緩存當(dāng)中。
接下來,lru_cache 的核心在這個函數(shù)上 _lru_cache_wrapper,建議有感情的閱讀、背誦并默寫。我們來看下它的源代碼
- def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
- # 所有 lru cache 實例共享的常量:
- sentinel = object() # 用來表示緩存未命中的唯一對象
- make_key = _make_key # build a key from the function arguments
- PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
- cache = {}
- hits = misses = 0
- full = False
- cache_get = cache.get # 綁定函數(shù)來獲取緩存中 key 的值
- cache_len = cache.__len__ # 綁定函數(shù)獲取緩存大小
- lock = RLock() # 因為鏈表上的更新是線程不安全的
- root = [] # 循環(huán)雙向鏈表的根節(jié)點
- root[:] = [root, root, None, None] # 初始化根節(jié)點的前后指針都指向它自己
- if maxsize == 0:
- def wrapper(*args, **kwds):
- # 沒有緩存,僅更新統(tǒng)計信息
- nonlocal misses
- misses += 1
- result = user_function(*args, **kwds)
- return result
- elif maxsize is None:
- def wrapper(*args, **kwds):
- # 僅僅排序,不考慮排序和緩存大小限制
- nonlocal hits, misses
- key = make_key(args, kwds, typed)
- result = cache_get(key, sentinel)
- if result is not sentinel:
- hits += 1
- return result
- misses += 1
- result = user_function(*args, **kwds)
- cache[key] = result
- return result
- else:
- def wrapper(*args, **kwds):
- # 大小有限制,并跟蹤最近使用的緩存
- nonlocal root, hits, misses, full
- key = make_key(args, kwds, typed)
- with lock:
- link = cache_get(key)
- if link is not None:
- # 緩存命中,將命中的緩存移動到循環(huán)雙向鏈表的頭部
- link_prev, link_next, _key, result = link
- link_prev[NEXT] = link_next
- link_next[PREV] = link_prev
- last = root[PREV]
- last[NEXT] = root[PREV] = link
- link[PREV] = last
- link[NEXT] = root
- hits += 1
- return result
- misses += 1
- result = user_function(*args, **kwds)
- with lock:
- if key in cache:
- # 走到這里說明 key 已經(jīng)放在了緩存,且鎖已經(jīng)釋放了,鏈表已經(jīng)更新了,這里什么也不需要做了,最后只需要返回計算的結(jié)果就可以了。
- pass
- elif full:
- # 如果緩存滿了, 使用最老的根節(jié)點來存儲新節(jié)點就可以了,鏈表上不需要刪除(是不是很聰明)
- oldroot = root
- oldroot[KEY] = key
- oldroot[RESULT] = result
- root = oldroot[NEXT]
- oldkey = root[KEY]
- oldresult = root[RESULT]
- root[KEY] = root[RESULT] = None
- # 最后,我們需要從緩存中清除這個 key,因為它已經(jīng)無效了。
- del cache[oldkey]
- # 新值放入緩存
- cache[key] = oldroot
- else:
- # 如果沒有滿,將新的結(jié)果放入循環(huán)雙向鏈表的頭部
- last = root[PREV]
- link = [last, root, key, result]
- last[NEXT] = root[PREV] = cache[key] = link
- # 使用 cache_len 綁定方法而不是 len() 函數(shù),后者可能會被包裝在 lru_cache 本身中
- full = (cache_len() >= maxsize)
- return result
- def cache_info():
- """報告緩存統(tǒng)計信息"""
- with lock:
- return _CacheInfo(hits, misses, maxsize, cache_len())
- def cache_clear():
- """清理緩存信息"""
- nonlocal hits, misses, full
- with lock:
- cache.clear()
- root[:] = [root, root, None, None]
- hits = misses = 0
- full = False
- wrapper.cache_info = cache_info
- wrapper.cache_clear = cache_clear
- return wrapper
如果我寫的注釋你都看明白了,那也不用看我下面的廢話了,如果還有點不太明白,我啰嗦幾句,也許你就明白了。
第一、所謂緩存,用的仍然是內(nèi)存,為了快速存取,用的就是一個 hash 表,也就是 Python 的字典,都是在內(nèi)存里的操作。
- cache = {}
第二、如果 maxsize == 0,就相當(dāng)于沒有使用緩存,每調(diào)用一次,未命中數(shù)就 + 1,代碼邏輯是這樣的:
- def wrapper(*args, **kwds):
- nonlocal misses
- misses += 1 # 未命中數(shù)
- result = user_function(*args, **kwds)
- return result
第三、如果 maxsize == None,相當(dāng)于緩存無限制,也就不需要考慮淘汰,這個實現(xiàn)非常簡單,我們直接在函數(shù)中用一個字典就可以實現(xiàn),比如說:
- cache = {}
- def fib(n):
- if n in cache:
- return cache[n]
- if n <= 1:
- return n
- result = fib(n - 1) + fib(n - 2)
- cache[n] = result
- return result
運行時間:
理解了這一點,在裝飾器中,這段邏輯就不難看懂:
- def wrapper(*args, **kwds):
- nonlocal hits, misses
- key = make_key(args, kwds, typed)
- result = cache_get(key, sentinel)
- if result is not sentinel:
- hits += 1
- return result
- misses += 1
- result = user_function(*args, **kwds)
- cache[key] = result
- return result
第四、真正的緩存淘汰算法。
為了實現(xiàn)緩存(鍵值對)的淘汰,我們需要對緩存按時間進行排序,這就需要用到鏈表,鏈表的頭部是最新插入的,尾部是最老插入的,當(dāng)緩存數(shù)量已經(jīng)達到最大值時,我們刪除最久未使用的鏈尾節(jié)點,為了不刪除鏈尾,我們可以使用循環(huán)鏈表,當(dāng)緩存滿了,直接更新鏈尾節(jié)點賦值為新節(jié)點,并把它做為新的鏈頭就可以了。
當(dāng)緩存命中時,我們需要把這個節(jié)點移動到鏈表的頭部,保證鏈表的頭部是最近經(jīng)常使用的,為了移動方便,我們需要雙向鏈表。
雙向循環(huán)鏈表在 Python 中實現(xiàn),可以簡單的這么寫:
- PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
- root = [] # root of the circular doubly linked list
- root[:] = [root, root, None, None] # initialize by pointing to self
可能有些朋友看不懂最后那行代碼:root[:] = [root, root, None, None],畫個圖你就理解了:
這些箭頭指向的都是節(jié)點的內(nèi)存地址,隨著節(jié)點的增多,就是這個樣子的:
對比這個圖,再看源代碼,就很容易看懂了。尤其是這塊的代碼邏輯,是面試??嫉闹攸c,如果你能手寫出這樣線程安全的 LRU 緩存淘汰算法,那無疑是非常優(yōu)秀的。
其他 LRU 算法的實現(xiàn)
其他關(guān)于 LRU 算法的實現(xiàn),我自己寫了兩個,可以看這里:
LRU 緩存淘汰算法-雙鏈表+hash 表[1]
LRU 緩存淘汰算法-Python 有序字典[2]
最后的話
裝飾器 lru_cache 的作用就是把函數(shù)的計算機結(jié)果保存下來,下次用的時候可以直接從 hash 表中取出,避免重復(fù)計算從而提升效率,簡單點的,直接在函數(shù)中使用個字典就搞定了,復(fù)雜點的,請看 lru_cache 的代碼實現(xiàn)。另一方面,遞歸函數(shù)慢的一個主要原因就是重復(fù)計算。
Python 標(biāo)準(zhǔn)庫的源碼,是學(xué)習(xí)編程最有營養(yǎng)的原料,當(dāng)你有好奇心時,不妨去窺探一下源碼,相信你有定會有新的收獲。今天的分享就到這里,如果有收獲的話,請點贊、在看、轉(zhuǎn)發(fā)、關(guān)注,感謝你的支持。
參考資料
[1]
LRU 緩存淘汰算法-雙鏈表+hash 表: https://github.com/somenzz/geekbang/blob/master/algorthms/lru_use_link_table.py
[2]
LRU 緩存淘汰算法-Python 有序字典: https://github.com/somenzz/geekbang/blob/master/algorthms/lru_use_ordered_dict.py