From 2aaa53e21f5d12c6de74a7d73525f9fc227b13bb Mon Sep 17 00:00:00 2001 From: Harry Slatyer Date: Fri, 3 Jan 2020 11:23:20 +1100 Subject: [PATCH 1/2] Support Identity in tensor_util.constant_value. This looks just the same as StopGradient, since for the purposes of forward-propagated values the two are identical. --- tensorflow/python/framework/tensor_util.py | 2 ++ tensorflow/python/framework/tensor_util_test.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 4fcee63f464..aeda811e2fc 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -790,6 +790,8 @@ def _ConstantValue(tensor, partial): return np.not_equal(value1, value2) elif tensor.op.type == "StopGradient": return constant_value(tensor.op.inputs[0], partial) + elif tensor.op.type == "Identity": + return constant_value(tensor.op.inputs[0], partial) else: return None diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index b2ab779386b..9aa88eda6ef 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -979,6 +979,12 @@ class ConstantValueTest(test.TestCase): c_val = tensor_util.constant_value(tf_val) self.assertAllEqual(input_, c_val) + def testIdentity(self): + input_ = np.random.rand(4, 7) + tf_val = array_ops.identity(input_) + c_val = tensor_util.constant_value(tf_val) + self.assertAllEqual(input_, c_val) + def testLiteral(self): x = "hi" self.assertIs(x, tensor_util.constant_value(x)) From 0e59af232b6c2ad33dd6ff35d85e68af51d22c32 Mon Sep 17 00:00:00 2001 From: Harry Slatyer Date: Wed, 8 Jan 2020 09:21:58 +1100 Subject: [PATCH 2/2] Handle error messages with line breaks in confusion_matrix_test Depending on the specifics of the condition (in particular whether it can be evaluated statically), the error message produced by an assertion can either be shown on one line or split across multiple lines. In the latter case, the use of a .* regex fails, because the . doesn't match the line breaks. To fix that we can just use [\s\S]* instead. --- tensorflow/python/kernel_tests/confusion_matrix_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py index c1178253a4b..810d2e60961 100644 --- a/tensorflow/python/kernel_tests/confusion_matrix_test.py +++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py @@ -188,7 +188,7 @@ class ConfusionMatrixTest(test.TestCase): def testLabelsTooLarge(self): labels = np.asarray([1, 1, 0, 3, 5], dtype=np.int32) predictions = np.asarray([2, 1, 0, 2, 2], dtype=np.int32) - with self.assertRaisesOpError("`labels`.*x < y"): + with self.assertRaisesOpError("`labels`[\s\S]*x < y"): self._testConfMatrix( labels=labels, predictions=predictions, num_classes=3, truth=None) @@ -203,7 +203,7 @@ class ConfusionMatrixTest(test.TestCase): def testPredictionsTooLarge(self): labels = np.asarray([1, 1, 0, 2, 2], dtype=np.int32) predictions = np.asarray([2, 1, 0, 3, 5], dtype=np.int32) - with self.assertRaisesOpError("`predictions`.*x < y"): + with self.assertRaisesOpError("`predictions`[\s\S]*x < y"): self._testConfMatrix( labels=labels, predictions=predictions, num_classes=3, truth=None)