_ConstantValue can now see through tf.identity ops

This commit is contained in:
ngc92 2020-03-24 20:00:00 +02:00
parent 4a441b2364
commit 7f389c2ac3
2 changed files with 8 additions and 0 deletions

View File

@ -791,6 +791,8 @@ def _ConstantValue(tensor, partial):
return np.not_equal(value1, value2)
elif tensor.op.type == "StopGradient":
return constant_value(tensor.op.inputs[0], partial)
elif tensor.op.type == "Identity":
return constant_value(tensor.op.inputs[0], partial)
else:
return None

View File

@ -979,6 +979,12 @@ class ConstantValueTest(test.TestCase):
c_val = tensor_util.constant_value(tf_val)
self.assertAllEqual(input_, c_val)
def testIdentity(self):
input_ = np.random.rand(4, 7)
tf_val = array_ops.identity(input_)
c_val = tensor_util.constant_value(tf_val)
self.assertAllEqual(input_, c_val)
def testLiteral(self):
x = "hi"
self.assertIs(x, tensor_util.constant_value(x))