[tf.function] In graph mode, preserve nested attributes for func-valued attrs.

This allows users to create a `tf.function` "f" that has attributes, using the private `function.defun_with_attributes()` method, set "f" as the attribute of another op (e.g. a MapDataset op), and preserve the attributes of the original "f".

PiperOrigin-RevId: 288381945
Change-Id: I9dcf780251e610a4ed21a405372695b9b7671b69
This commit is contained in:
Derek Murray 2020-01-06 14:55:29 -08:00 committed by TensorFlower Gardener
parent 6ed5c4b7f7
commit 0733a86c39
3 changed files with 43 additions and 5 deletions

View File

@ -32,9 +32,9 @@ from six.moves import map
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python import _pywrap_utils
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
@ -1969,6 +1969,14 @@ class ConcreteFunction(object):
outputs_list, expand_composites=True)
return ret
@property
def _as_name_attr_list(self):
"""Returns a `NameAttrList` representing this function."""
ret = attr_value_pb2.NameAttrList(name=self.name)
for name, value in self._attrs.items():
ret.attr[name].CopyFrom(value)
return ret
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)

View File

@ -25,9 +25,9 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
@ -217,12 +217,14 @@ def _MakeFunc(v, arg_name):
"""Ensure v is a func."""
if isinstance(v, attr_value_pb2.NameAttrList):
return v
fn_attr = attr_value_pb2.NameAttrList()
if isinstance(v, compat.bytes_or_text_types):
fn_attr.name = v
fn_attr = attr_value_pb2.NameAttrList(name=v)
elif hasattr(v, "add_to_graph"):
v.add_to_graph(ops.get_default_graph())
fn_attr.name = v.name
if hasattr(v, "_as_name_attr_list"):
fn_attr = v._as_name_attr_list # pylint: disable=protected-access
else:
fn_attr = attr_value_pb2.NameAttrList(name=v.name)
else:
raise TypeError("Don't know how to convert {} to a func for "
"argument {}".format(v, arg_name))

View File

@ -20,13 +20,16 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_library
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.util import compat
class OpDefLibraryTest(test_util.TensorFlowTestCase):
@ -407,6 +410,31 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
self.assertEqual(str(cm.exception),
"Don't know how to convert 3 to a func for argument f")
def testAttrFuncWithFuncWithAttrs(self):
with ops.Graph().as_default():
@eager_function.defun_with_attributes(
input_signature=(tensor_spec.TensorSpec(None, dtypes.float32),),
autograph=False,
attributes={"_dummy_attr": 15})
def fn(x):
return 2 + x
concrete_fn = fn.get_concrete_function()
op = op_def_library.apply_op("FuncAttr", f=concrete_fn, name="t")
self.assertProtoEquals("""
name: 't' op: 'FuncAttr'
attr {
key: 'f'
value {
func {
name: '%s'
attr { key: "_dummy_attr" value { i: 15 } }
}
}
}
""" % compat.as_str(concrete_fn.name), op.node_def)
def testAttrFuncList(self):
with ops.Graph().as_default():
@function.Defun(dtypes.float32, func_name="MyFn")