Merge pull request #35552 from charmasaur:identity_in_constant_value

PiperOrigin-RevId: 292937401
Change-Id: If15a66bfdf6243c084e4d4b214ab8378eaf6ed00
This commit is contained in:
TensorFlower Gardener 2020-02-03 09:38:53 -08:00
commit af844bd2df
3 changed files with 10 additions and 2 deletions

View File

@ -790,6 +790,8 @@ def _ConstantValue(tensor, partial):
return np.not_equal(value1, value2)
elif tensor.op.type == "StopGradient":
return constant_value(tensor.op.inputs[0], partial)
elif tensor.op.type == "Identity":
return constant_value(tensor.op.inputs[0], partial)
else:
return None

View File

@ -979,6 +979,12 @@ class ConstantValueTest(test.TestCase):
c_val = tensor_util.constant_value(tf_val)
self.assertAllEqual(input_, c_val)
def testIdentity(self):
input_ = np.random.rand(4, 7)
tf_val = array_ops.identity(input_)
c_val = tensor_util.constant_value(tf_val)
self.assertAllEqual(input_, c_val)
def testLiteral(self):
x = "hi"
self.assertIs(x, tensor_util.constant_value(x))

View File

@ -188,7 +188,7 @@ class ConfusionMatrixTest(test.TestCase):
def testLabelsTooLarge(self):
labels = np.asarray([1, 1, 0, 3, 5], dtype=np.int32)
predictions = np.asarray([2, 1, 0, 2, 2], dtype=np.int32)
with self.assertRaisesOpError("`labels`.*x < y"):
with self.assertRaisesOpError("`labels`[\s\S]*x < y"):
self._testConfMatrix(
labels=labels, predictions=predictions, num_classes=3, truth=None)
@ -203,7 +203,7 @@ class ConfusionMatrixTest(test.TestCase):
def testPredictionsTooLarge(self):
labels = np.asarray([1, 1, 0, 2, 2], dtype=np.int32)
predictions = np.asarray([2, 1, 0, 3, 5], dtype=np.int32)
with self.assertRaisesOpError("`predictions`.*x < y"):
with self.assertRaisesOpError("`predictions`[\s\S]*x < y"):
self._testConfMatrix(
labels=labels, predictions=predictions, num_classes=3, truth=None)