[tf.function] In graph mode, preserve nested attributes for func-valued attrs.
This allows users to create a `tf.function` "f" that has attributes, using the private `function.defun_with_attributes()` method, set "f" as the attribute of another op (e.g. a MapDataset op), and preserve the attributes of the original "f". PiperOrigin-RevId: 288381945 Change-Id: I9dcf780251e610a4ed21a405372695b9b7671b69
This commit is contained in:
parent
6ed5c4b7f7
commit
0733a86c39
@ -32,9 +32,9 @@ from six.moves import map
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import backprop_util
|
||||
from tensorflow.python.eager import context
|
||||
@ -1969,6 +1969,14 @@ class ConcreteFunction(object):
|
||||
outputs_list, expand_composites=True)
|
||||
return ret
|
||||
|
||||
@property
|
||||
def _as_name_attr_list(self):
|
||||
"""Returns a `NameAttrList` representing this function."""
|
||||
ret = attr_value_pb2.NameAttrList(name=self.name)
|
||||
for name, value in self._attrs.items():
|
||||
ret.attr[name].CopyFrom(value)
|
||||
return ret
|
||||
|
||||
|
||||
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
|
||||
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
|
||||
|
@ -25,9 +25,9 @@ from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import tensor_pb2
|
||||
from tensorflow.core.framework import tensor_shape_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import op_callbacks
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -217,12 +217,14 @@ def _MakeFunc(v, arg_name):
|
||||
"""Ensure v is a func."""
|
||||
if isinstance(v, attr_value_pb2.NameAttrList):
|
||||
return v
|
||||
fn_attr = attr_value_pb2.NameAttrList()
|
||||
if isinstance(v, compat.bytes_or_text_types):
|
||||
fn_attr.name = v
|
||||
fn_attr = attr_value_pb2.NameAttrList(name=v)
|
||||
elif hasattr(v, "add_to_graph"):
|
||||
v.add_to_graph(ops.get_default_graph())
|
||||
fn_attr.name = v.name
|
||||
if hasattr(v, "_as_name_attr_list"):
|
||||
fn_attr = v._as_name_attr_list # pylint: disable=protected-access
|
||||
else:
|
||||
fn_attr = attr_value_pb2.NameAttrList(name=v.name)
|
||||
else:
|
||||
raise TypeError("Don't know how to convert {} to a func for "
|
||||
"argument {}".format(v, arg_name))
|
||||
|
@ -20,13 +20,16 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import tensor_shape_pb2
|
||||
from tensorflow.python.eager import function as eager_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import op_def_library
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class OpDefLibraryTest(test_util.TensorFlowTestCase):
|
||||
@ -407,6 +410,31 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(str(cm.exception),
|
||||
"Don't know how to convert 3 to a func for argument f")
|
||||
|
||||
def testAttrFuncWithFuncWithAttrs(self):
|
||||
with ops.Graph().as_default():
|
||||
@eager_function.defun_with_attributes(
|
||||
input_signature=(tensor_spec.TensorSpec(None, dtypes.float32),),
|
||||
autograph=False,
|
||||
attributes={"_dummy_attr": 15})
|
||||
def fn(x):
|
||||
return 2 + x
|
||||
|
||||
concrete_fn = fn.get_concrete_function()
|
||||
|
||||
op = op_def_library.apply_op("FuncAttr", f=concrete_fn, name="t")
|
||||
self.assertProtoEquals("""
|
||||
name: 't' op: 'FuncAttr'
|
||||
attr {
|
||||
key: 'f'
|
||||
value {
|
||||
func {
|
||||
name: '%s'
|
||||
attr { key: "_dummy_attr" value { i: 15 } }
|
||||
}
|
||||
}
|
||||
}
|
||||
""" % compat.as_str(concrete_fn.name), op.node_def)
|
||||
|
||||
def testAttrFuncList(self):
|
||||
with ops.Graph().as_default():
|
||||
@function.Defun(dtypes.float32, func_name="MyFn")
|
||||
|
Loading…
Reference in New Issue
Block a user