Ensure that functools.wraps is not called on functools.partial objects in rev_block.
PiperOrigin-RevId: 209524010
This commit is contained in:
parent
d29759fa53
commit
a2a561e7e8
@ -151,9 +151,19 @@ def _rev_block_forward(x1,
|
||||
return y1, y2
|
||||
|
||||
|
||||
def _safe_wraps(fn):
|
||||
if isinstance(fn, functools.partial):
|
||||
# functools.partial objects cannot be wrapped as they are missing the
|
||||
# necessary properties (__name__, __module__, __doc__).
|
||||
def passthrough(f):
|
||||
return f
|
||||
return passthrough
|
||||
return functools.wraps(fn)
|
||||
|
||||
|
||||
def _scope_wrap(fn, scope):
|
||||
|
||||
@functools.wraps(fn)
|
||||
@_safe_wraps(fn)
|
||||
def wrap(*args, **kwargs):
|
||||
with variable_scope.variable_scope(scope, use_resource=True):
|
||||
return fn(*args, **kwargs)
|
||||
@ -430,7 +440,7 @@ def rev_block(x1,
|
||||
def enable_with_args(dec):
|
||||
"""A decorator for decorators to enable their usage with or without args."""
|
||||
|
||||
@functools.wraps(dec)
|
||||
@_safe_wraps(dec)
|
||||
def new_dec(*args, **kwargs):
|
||||
if len(args) == 1 and not kwargs and callable(args[0]):
|
||||
# Used as decorator without args
|
||||
@ -477,7 +487,7 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
|
||||
tf.gradients).
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
@_safe_wraps(fn)
|
||||
def wrapped(*args):
|
||||
return _recompute_grad(
|
||||
fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
|
||||
|
Loading…
Reference in New Issue
Block a user