Add build_op_info to support defun caller op for TF internal use cases.

PiperOrigin-RevId: 220578191
This commit is contained in:
Yanhui Liang 2018-11-07 21:04:39 -08:00 committed by TensorFlower Gardener
parent b3492ab89a
commit f05d5acddd
3 changed files with 64 additions and 1 deletions
tensorflow/python/saved_model

View File

@ -22,5 +22,6 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.saved_model.utils_impl import build_tensor_info
from tensorflow.python.saved_model.utils_impl import build_tensor_info_from_op
from tensorflow.python.saved_model.utils_impl import get_tensor_from_tensor_info
# pylint: enable=unused-import

View File

@ -20,10 +20,12 @@ from __future__ import print_function
import os
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import constants
from tensorflow.python.util import compat
@ -42,7 +44,7 @@ from tensorflow.python.util.tf_export import tf_export
"library as tf.compat.v1.saved_model.utils.build_tensor_info or "
"tf.compat.v1.saved_model.build_tensor_info.")
def build_tensor_info(tensor):
"""Utility function to build TensorInfo proto.
"""Utility function to build TensorInfo proto from a Tensor.
Args:
tensor: Tensor or SparseTensor whose name, dtype and shape are used to
@ -64,6 +66,41 @@ def build_tensor_info(tensor):
return tensor_info
def build_tensor_info_from_op(op):
"""Utility function to build TensorInfo proto from an Op.
Note that this function should be used with caution. It is strictly restricted
to TensorFlow internal use-cases only. Please make sure you do need it before
using it.
This utility function overloads the TensorInfo proto by setting the name to
the Op's name, dtype to DT_INVALID and tensor_shape as None. One typical usage
is for the Op of the call site for the defunned function:
```python
@function.defun
def some_vairable_initialiation_fn(value_a, value_b):
a = value_a
b = value_b
value_a = constant_op.constant(1, name="a")
value_b = constant_op.constant(2, name="b")
op_info = utils.build_op_info(
some_vairable_initialiation_fn(value_a, value_b))
```
Args:
op: An Op whose name is used to build the TensorInfo. The name that points
to the Op could be fetched at run time in the Loader session.
Returns:
A TensorInfo protocol buffer constructed based on the supplied argument.
"""
return meta_graph_pb2.TensorInfo(
dtype=types_pb2.DT_INVALID,
tensor_shape=tensor_shape.unknown_shape().as_proto(),
name=op.name)
@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
"saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(

View File

@ -19,16 +19,41 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import utils
class UtilsTest(test.TestCase):
def testBuildTensorInfoOp(self):
x = constant_op.constant(1, name="x")
y = constant_op.constant(2, name="y")
z = control_flow_ops.group([x, y], name="op_z")
z_op_info = utils.build_tensor_info_from_op(z)
self.assertEqual("op_z", z_op_info.name)
self.assertEqual(types_pb2.DT_INVALID, z_op_info.dtype)
self.assertEqual(0, len(z_op_info.tensor_shape.dim))
def testBuildTensorInfoDefunOp(self):
@function.defun
def my_init_fn(x, y):
self.x_var = x
self.y_var = y
x = constant_op.constant(1, name="x")
y = constant_op.constant(2, name="y")
init_op_info = utils.build_tensor_info_from_op(my_init_fn(x, y))
self.assertEqual("PartitionedFunctionCall", init_op_info.name)
self.assertEqual(types_pb2.DT_INVALID, init_op_info.dtype)
self.assertEqual(0, len(init_op_info.tensor_shape.dim))
def testBuildTensorInfoDense(self):
x = array_ops.placeholder(dtypes.float32, 1, name="x")
x_tensor_info = utils.build_tensor_info(x)