Ensure that functools.wraps is not called on functools.partial objects in rev_block.
PiperOrigin-RevId: 209524010
This commit is contained in:
parent
d29759fa53
commit
a2a561e7e8
@ -151,9 +151,19 @@ def _rev_block_forward(x1,
|
|||||||
return y1, y2
|
return y1, y2
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_wraps(fn):
|
||||||
|
if isinstance(fn, functools.partial):
|
||||||
|
# functools.partial objects cannot be wrapped as they are missing the
|
||||||
|
# necessary properties (__name__, __module__, __doc__).
|
||||||
|
def passthrough(f):
|
||||||
|
return f
|
||||||
|
return passthrough
|
||||||
|
return functools.wraps(fn)
|
||||||
|
|
||||||
|
|
||||||
def _scope_wrap(fn, scope):
|
def _scope_wrap(fn, scope):
|
||||||
|
|
||||||
@functools.wraps(fn)
|
@_safe_wraps(fn)
|
||||||
def wrap(*args, **kwargs):
|
def wrap(*args, **kwargs):
|
||||||
with variable_scope.variable_scope(scope, use_resource=True):
|
with variable_scope.variable_scope(scope, use_resource=True):
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
@ -430,7 +440,7 @@ def rev_block(x1,
|
|||||||
def enable_with_args(dec):
|
def enable_with_args(dec):
|
||||||
"""A decorator for decorators to enable their usage with or without args."""
|
"""A decorator for decorators to enable their usage with or without args."""
|
||||||
|
|
||||||
@functools.wraps(dec)
|
@_safe_wraps(dec)
|
||||||
def new_dec(*args, **kwargs):
|
def new_dec(*args, **kwargs):
|
||||||
if len(args) == 1 and not kwargs and callable(args[0]):
|
if len(args) == 1 and not kwargs and callable(args[0]):
|
||||||
# Used as decorator without args
|
# Used as decorator without args
|
||||||
@ -477,7 +487,7 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
|
|||||||
tf.gradients).
|
tf.gradients).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@functools.wraps(fn)
|
@_safe_wraps(fn)
|
||||||
def wrapped(*args):
|
def wrapped(*args):
|
||||||
return _recompute_grad(
|
return _recompute_grad(
|
||||||
fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
|
fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user