Add build_op_info to support defun caller op for TF internal use cases.
PiperOrigin-RevId: 220578191
This commit is contained in:
parent
b3492ab89a
commit
f05d5acddd
tensorflow/python/saved_model
@ -22,5 +22,6 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.saved_model.utils_impl import build_tensor_info
|
||||
from tensorflow.python.saved_model.utils_impl import build_tensor_info_from_op
|
||||
from tensorflow.python.saved_model.utils_impl import get_tensor_from_tensor_info
|
||||
# pylint: enable=unused-import
|
||||
|
@ -20,10 +20,12 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.util import compat
|
||||
@ -42,7 +44,7 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
"library as tf.compat.v1.saved_model.utils.build_tensor_info or "
|
||||
"tf.compat.v1.saved_model.build_tensor_info.")
|
||||
def build_tensor_info(tensor):
|
||||
"""Utility function to build TensorInfo proto.
|
||||
"""Utility function to build TensorInfo proto from a Tensor.
|
||||
|
||||
Args:
|
||||
tensor: Tensor or SparseTensor whose name, dtype and shape are used to
|
||||
@ -64,6 +66,41 @@ def build_tensor_info(tensor):
|
||||
return tensor_info
|
||||
|
||||
|
||||
def build_tensor_info_from_op(op):
|
||||
"""Utility function to build TensorInfo proto from an Op.
|
||||
|
||||
Note that this function should be used with caution. It is strictly restricted
|
||||
to TensorFlow internal use-cases only. Please make sure you do need it before
|
||||
using it.
|
||||
|
||||
This utility function overloads the TensorInfo proto by setting the name to
|
||||
the Op's name, dtype to DT_INVALID and tensor_shape as None. One typical usage
|
||||
is for the Op of the call site for the defunned function:
|
||||
```python
|
||||
@function.defun
|
||||
def some_vairable_initialiation_fn(value_a, value_b):
|
||||
a = value_a
|
||||
b = value_b
|
||||
|
||||
value_a = constant_op.constant(1, name="a")
|
||||
value_b = constant_op.constant(2, name="b")
|
||||
op_info = utils.build_op_info(
|
||||
some_vairable_initialiation_fn(value_a, value_b))
|
||||
```
|
||||
|
||||
Args:
|
||||
op: An Op whose name is used to build the TensorInfo. The name that points
|
||||
to the Op could be fetched at run time in the Loader session.
|
||||
|
||||
Returns:
|
||||
A TensorInfo protocol buffer constructed based on the supplied argument.
|
||||
"""
|
||||
return meta_graph_pb2.TensorInfo(
|
||||
dtype=types_pb2.DT_INVALID,
|
||||
tensor_shape=tensor_shape.unknown_shape().as_proto(),
|
||||
name=op.name)
|
||||
|
||||
|
||||
@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
|
||||
"saved_model.utils.get_tensor_from_tensor_info"])
|
||||
@deprecation.deprecated(
|
||||
|
@ -19,16 +19,41 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import utils
|
||||
|
||||
|
||||
class UtilsTest(test.TestCase):
|
||||
|
||||
def testBuildTensorInfoOp(self):
|
||||
x = constant_op.constant(1, name="x")
|
||||
y = constant_op.constant(2, name="y")
|
||||
z = control_flow_ops.group([x, y], name="op_z")
|
||||
z_op_info = utils.build_tensor_info_from_op(z)
|
||||
self.assertEqual("op_z", z_op_info.name)
|
||||
self.assertEqual(types_pb2.DT_INVALID, z_op_info.dtype)
|
||||
self.assertEqual(0, len(z_op_info.tensor_shape.dim))
|
||||
|
||||
def testBuildTensorInfoDefunOp(self):
|
||||
@function.defun
|
||||
def my_init_fn(x, y):
|
||||
self.x_var = x
|
||||
self.y_var = y
|
||||
|
||||
x = constant_op.constant(1, name="x")
|
||||
y = constant_op.constant(2, name="y")
|
||||
init_op_info = utils.build_tensor_info_from_op(my_init_fn(x, y))
|
||||
self.assertEqual("PartitionedFunctionCall", init_op_info.name)
|
||||
self.assertEqual(types_pb2.DT_INVALID, init_op_info.dtype)
|
||||
self.assertEqual(0, len(init_op_info.tensor_shape.dim))
|
||||
|
||||
def testBuildTensorInfoDense(self):
|
||||
x = array_ops.placeholder(dtypes.float32, 1, name="x")
|
||||
x_tensor_info = utils.build_tensor_info(x)
|
||||
|
Loading…
Reference in New Issue
Block a user