Python 与尾递归优化

什么是尾递归

有很多时候,使用递归的方式写代码要比迭代更直观一些,以下面的阶乘为例:

def factorial(n):  
    if n == 0:
        return 1
    return factorial(n - 1) * n 

但是这个函数调用,如果展开,会变成如下的形式:

factorial(4)  
factorial(3) * 4  
factorial(2) * 3 * 4  
factorial(1) * 2 * 3 * 4  
factorial(0) * 1 * 2 * 3 * 4  
1 * 1 * 2 * 3 * 4  
1 * 2 * 3 * 4  
2 * 3 * 4  
6 * 4  
24  

可以看出,在每次递归调用的时候,都会产生一个临时变量,导致进程内存占用量增大一些。这样执行一些递归层数比较深的代码时,除了无谓的内存浪费,还有可能导致著名的堆栈溢出错误。

但是如果把上面的函数写成如下形式:

def factorial(n, acc=1):  
    if n == 0:
        return acc
    return factorial(n - 1, n * acc)

我们再脑内展开一下:

factorial(4, 1)  
factorial(3, 4)  
factorial(2, 12)  
factorial(1, 24)  
factorial(0, 24)  
24  

很直观的就可以看出,这次的 factorial 函数在递归调用的时候不会产生一系列逐渐增多的中间变量了,而是将状态保存在 acc 这个变量中。

而这种形式的递归,就叫做尾递归

尾递归的定义顾名思义,函数调用中最后返回的结果是单纯的递归函数调用(或返回结果)就是尾递归

比如代码:

def foo():  
    return foo()

就是尾递归。但是 return 的结果除了递归的函数调用,还包含另外的计算,就不能算作尾递归了,比如:

def foo():  
    return foo() + 1  # return 1 + foo() 也一样

尾递归优化

看上去尾递归这种形式的代码,可以达到和迭代形式的代码同样的效率与空间复杂度。但是当前绝大部分编程语言,函数调用实际上是使用内存中的一个来模拟的。这样在执行尾递归函数时,依然可能会遇到堆栈溢出的问题。

使用栈来实现函数调用的编程语言,在进行一个函数调用时,可以提前分析出这次函数调用传递了多少个参数,以及产生了多少个中间变量,还有参数与变量占用内存的大小(通常包含这个函数调用过程中所有参数、中间变量的这个栈的元素叫做)。这样调用前就把栈顶的指针向前指这么大的内存偏移,这样函数参数、中间变量的内存位置就在调用前分配好了。函数调用完毕时,栈顶指针指回去,就可以瞬间清除掉这次函数调用占用的内存了。并且使用栈来实现函数调用,与绝大部分编程语言的语义相符,比如进行函数调用时,调用者分配的内存空间依然可以使用,变量依然有效。

但是对于递归函数来说,就会遇到一个问题:每次函数调用会让栈的容量增长一点,如果需要进行的递归调用层级很深,这样每进行一次递归调用,即使是不会生成中间变量的尾递归,依然随着函数调用栈的增长,整个进程的内存占用量增长。

但理论上讲,没有产生中间变量来保存状态的尾递归,完全可以复用同一个栈帧来实现所有的递归函数操作。这种形式的代码优化就叫做尾递归优化。

Python 与尾递归优化

对于编译到机器码执行的代码(不管是 AOT 还是 JIT),简单来讲,只要将 call ... ret 指令改为 jump ...,就可以复用同一个栈帧,当然还有很多额外工作需要做。对于解释执行的代码,解释器本身有很多机会可以动态修改栈帧来做尾递归优化。

但是 CPython 的实现并没有支持尾递归优化,并且默认限制了递归调用层数为 1000(通过 sys.getrecursionlimit 函数可以查看)。

不过这并不代表我们没有办法在 Python 中实现尾递归优化。实现尾递归优化的方式中,如果因为某种原因不能直接控制生成的机器代码,也不方便运行时修改栈帧的语言,还有一种方案叫做 Through trampolining

Through trampolining 的大概实现方式是,在递归函数调用时,先插入一个 trampolining(蹦床) 函数。在这个蹦床函数调用来来调用真正的递归函数,并且修改递归函数的函数体,不让它再次进行递归的函数调用,而是直接返回下次递归调用的参数,由蹦床函数来进行下一次递归调用。这样一层一层的递归调用就会变成由蹦床函数一次一次的迭代式函数调用。

并且这种 Through trampolining 的尾递归优化,未必由编程语言本身(编译器 / 运行时)提供,一些灵活性比较强的语言本身就能实现这个过程。比如这里有一段使用 CPython 的实现代码。这段代码全文如下(稍微修改了一下,以便能够在 Python3 下运行):

#!/usr/bin/env python3
# This program shows off a python decorator(
# which implements tail call optimization. It
# does this by throwing an exception if it is 
# it's own grandparent, and catching such 
# exceptions to recall the stack.

import sys

class TailRecurseException(BaseException):  
  def __init__(self, args, kwargs):
    self.args = args
    self.kwargs = kwargs

def tail_call_optimized(g):  
  """
  This function decorates a function with tail call
  optimization. It does this by throwing an exception
  if it is it's own grandparent, and catching such
  exceptions to fake the tail call optimization.

  This function fails if the decorated
  function recurses in a non-tail context.
  """
  def func(*args, **kwargs):
    f = sys._getframe()
    if f.f_back and f.f_back.f_back \
        and f.f_back.f_back.f_code == f.f_code:
      raise TailRecurseException(args, kwargs)
    else:
      while 1:
        try:
          return g(*args, **kwargs)
        except TailRecurseException as e:
          args = e.args
          kwargs = e.kwargs
  func.__doc__ = g.__doc__
  return func

@tail_call_optimized
def factorial(n, acc=1):  
  "calculate a factorial"
  if n == 0:
    return acc
  return factorial(n-1, n*acc)

print(factorial(10000))  
# prints a big, big number,
# but doesn't hit the recursion limit.

@tail_call_optimized
def fib(i, current = 0, next = 1):  
  if i == 0:
    return current
  else:
    return fib(i - 1, next, current + next)

print(fib(10000))  
# also prints a big number,
# but doesn't hit the recursion limit.

仅仅暴露了一个 tail_call_optimized 装饰器,就可以对符合条件的函数进行尾递归优化。

这段代码的实现原理和上面提到的 Through trampolining 相同。但是作为纯 Python 的代码,装饰器并不能修改被装饰的函数体,来让函数只返回下次递归需要的参数,不在进行递归调用。而它的神奇之处就在于,通过装饰器来在每次递归调用函数前,抛出一个异常,然后将这次调用的参数写进异常对象,再在蹦床函数里捕获这个异常,并将参数读出来,然后进行迭代调用。