Ensure that functools.wraps is not called on functools.partial objects in rev_block.

PiperOrigin-RevId: 209524010
This commit is contained in:
A. Unique TensorFlower 2018-08-20 18:49:04 -07:00 committed by TensorFlower Gardener
parent d29759fa53
commit a2a561e7e8

View File

@ -151,9 +151,19 @@ def _rev_block_forward(x1,
return y1, y2
def _safe_wraps(fn):
if isinstance(fn, functools.partial):
# functools.partial objects cannot be wrapped as they are missing the
# necessary properties (__name__, __module__, __doc__).
def passthrough(f):
return f
return passthrough
return functools.wraps(fn)
def _scope_wrap(fn, scope):
@functools.wraps(fn)
@_safe_wraps(fn)
def wrap(*args, **kwargs):
with variable_scope.variable_scope(scope, use_resource=True):
return fn(*args, **kwargs)
@ -430,7 +440,7 @@ def rev_block(x1,
def enable_with_args(dec):
"""A decorator for decorators to enable their usage with or without args."""
@functools.wraps(dec)
@_safe_wraps(dec)
def new_dec(*args, **kwargs):
if len(args) == 1 and not kwargs and callable(args[0]):
# Used as decorator without args
@ -477,7 +487,7 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
tf.gradients).
"""
@functools.wraps(fn)
@_safe_wraps(fn)
def wrapped(*args):
return _recompute_grad(
fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)