Re-apply github tensorflow/pull/22264/commits/51d72a7d7f74784b68916819edd04e890b36f957

PiperOrigin-RevId: 325921879
Change-Id: I703edc9e0f381d64784027eb9457bc10f5e5aef8
This commit is contained in:
A. Unique TensorFlower 2020-08-10 17:24:25 -07:00 committed by TensorFlower Gardener
parent 36d55f1c56
commit 6fb229b3e7
3 changed files with 16 additions and 6 deletions

View File

@ -48,6 +48,7 @@ py_strict_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python/saved_model:signature_def_utils",
"@six_archive//:six",
],
@ -69,6 +70,7 @@ py_strict_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:metrics",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/saved_model:signature_constants",
],

View File

@ -26,6 +26,7 @@ import six
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.saved_model import signature_def_utils
@ -342,16 +343,16 @@ class _SupervisedOutput(ExportOutput):
raise ValueError(
'{} output value must be a Tensor; got {}.'.format(
key, metric_val))
if (not isinstance(metric_op, ops.Tensor) and
not isinstance(metric_op, ops.Operation)):
if not (tensor_util.is_tensor(metric_op) or
isinstance(metric_op, ops.Operation)):
raise ValueError(
'{} update_op must be a Tensor or Operation; got {}.'.format(
key, metric_op))
# We must wrap any ops in a Tensor before export, as the SignatureDef
# proto expects tensors only. See b/109740581
# We must wrap any ops (or variables) in a Tensor before export, as the
# SignatureDef proto expects tensors only. See b/109740581
metric_op_tensor = metric_op
if isinstance(metric_op, ops.Operation):
if not isinstance(metric_op, ops.Tensor):
with ops.control_dependencies([metric_op]):
metric_op_tensor = constant_op.constant([], name='metric_op_wrapper')

View File

@ -29,6 +29,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
@ -373,10 +374,16 @@ class SupervisedOutputTest(test.TestCase):
mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
metrics = {
'metrics_1': (mean, update_op),
'metrics_2': (constant_op.constant([0]), control_flow_ops.no_op())
'metrics_2': (constant_op.constant([0]), control_flow_ops.no_op()),
# Keras metric's update_state() could return a Variable, rather than
# an Operation or Tensor.
'keras_1': (constant_op.constant([0.5]),
variables.Variable(1.0, name='AssignAddVariableOp_3'))
}
outputter = MockSupervisedOutput(loss, predictions, metrics)
# If we get there, it means constructor succeeded; which is sufficient
# for testing the constructor.
self.assertTrue(outputter.metrics['metrics_1/update_op'].name.startswith(
'mean/update_op'))