diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index ffe3a8cb845..647540fc612 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -790,6 +790,8 @@ def _ConstantValue(tensor, partial): return np.not_equal(value1, value2) elif tensor.op.type == "StopGradient": return constant_value(tensor.op.inputs[0], partial) + elif tensor.op.type == "Identity": + return constant_value(tensor.op.inputs[0], partial) else: return None diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index b2ab779386b..9aa88eda6ef 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -979,6 +979,12 @@ class ConstantValueTest(test.TestCase): c_val = tensor_util.constant_value(tf_val) self.assertAllEqual(input_, c_val) + def testIdentity(self): + input_ = np.random.rand(4, 7) + tf_val = array_ops.identity(input_) + c_val = tensor_util.constant_value(tf_val) + self.assertAllEqual(input_, c_val) + def testLiteral(self): x = "hi" self.assertIs(x, tensor_util.constant_value(x)) diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py index c1178253a4b..810d2e60961 100644 --- a/tensorflow/python/kernel_tests/confusion_matrix_test.py +++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py @@ -188,7 +188,7 @@ class ConfusionMatrixTest(test.TestCase): def testLabelsTooLarge(self): labels = np.asarray([1, 1, 0, 3, 5], dtype=np.int32) predictions = np.asarray([2, 1, 0, 2, 2], dtype=np.int32) - with self.assertRaisesOpError("`labels`.*x < y"): + with self.assertRaisesOpError("`labels`[\s\S]*x < y"): self._testConfMatrix( labels=labels, predictions=predictions, num_classes=3, truth=None) @@ -203,7 +203,7 @@ class ConfusionMatrixTest(test.TestCase): def testPredictionsTooLarge(self): labels = np.asarray([1, 1, 0, 2, 2], dtype=np.int32) predictions = np.asarray([2, 1, 0, 3, 5], dtype=np.int32) - with self.assertRaisesOpError("`predictions`.*x < y"): + with self.assertRaisesOpError("`predictions`[\s\S]*x < y"): self._testConfMatrix( labels=labels, predictions=predictions, num_classes=3, truth=None)