Re-apply github tensorflow/pull/22264/commits/51d72a7d7f74784b68916819edd04e890b36f957
PiperOrigin-RevId: 325921879 Change-Id: I703edc9e0f381d64784027eb9457bc10f5e5aef8
This commit is contained in:
parent
36d55f1c56
commit
6fb229b3e7
@ -48,6 +48,7 @@ py_strict_library(
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python/saved_model:signature_def_utils",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
@ -69,6 +70,7 @@ py_strict_test(
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:metrics",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/saved_model:signature_constants",
|
||||
],
|
||||
|
@ -26,6 +26,7 @@ import six
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
|
||||
|
||||
@ -342,16 +343,16 @@ class _SupervisedOutput(ExportOutput):
|
||||
raise ValueError(
|
||||
'{} output value must be a Tensor; got {}.'.format(
|
||||
key, metric_val))
|
||||
if (not isinstance(metric_op, ops.Tensor) and
|
||||
not isinstance(metric_op, ops.Operation)):
|
||||
if not (tensor_util.is_tensor(metric_op) or
|
||||
isinstance(metric_op, ops.Operation)):
|
||||
raise ValueError(
|
||||
'{} update_op must be a Tensor or Operation; got {}.'.format(
|
||||
key, metric_op))
|
||||
|
||||
# We must wrap any ops in a Tensor before export, as the SignatureDef
|
||||
# proto expects tensors only. See b/109740581
|
||||
# We must wrap any ops (or variables) in a Tensor before export, as the
|
||||
# SignatureDef proto expects tensors only. See b/109740581
|
||||
metric_op_tensor = metric_op
|
||||
if isinstance(metric_op, ops.Operation):
|
||||
if not isinstance(metric_op, ops.Tensor):
|
||||
with ops.control_dependencies([metric_op]):
|
||||
metric_op_tensor = constant_op.constant([], name='metric_op_wrapper')
|
||||
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import metrics as metrics_module
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
|
||||
@ -373,10 +374,16 @@ class SupervisedOutputTest(test.TestCase):
|
||||
mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
|
||||
metrics = {
|
||||
'metrics_1': (mean, update_op),
|
||||
'metrics_2': (constant_op.constant([0]), control_flow_ops.no_op())
|
||||
'metrics_2': (constant_op.constant([0]), control_flow_ops.no_op()),
|
||||
# Keras metric's update_state() could return a Variable, rather than
|
||||
# an Operation or Tensor.
|
||||
'keras_1': (constant_op.constant([0.5]),
|
||||
variables.Variable(1.0, name='AssignAddVariableOp_3'))
|
||||
}
|
||||
|
||||
outputter = MockSupervisedOutput(loss, predictions, metrics)
|
||||
# If we get there, it means constructor succeeded; which is sufficient
|
||||
# for testing the constructor.
|
||||
|
||||
self.assertTrue(outputter.metrics['metrics_1/update_op'].name.startswith(
|
||||
'mean/update_op'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user