Simplify string representation of KerasTensor.

PiperOrigin-RevId: 325112498
Change-Id: Ibd4641faf6871ebd58477f9fe945cd2e0ac641f8
This commit is contained in:
Francois Chollet 2020-08-05 15:18:38 -07:00 committed by TensorFlower Gardener
parent e4706fdf31
commit 83a1ac2d84
3 changed files with 13 additions and 24 deletions

View File

@ -1181,7 +1181,7 @@ def placeholder(shape=None,
>>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5))
>>> input_ph
<KerasTensor: shape=(2, 4, 5) dtype=float32 (Symbolic value ...)>
<KerasTensor: shape=(2, 4, 5) dtype=float32 (created by layer ...)>
"""
if sparse and ragged:

View File

@ -229,12 +229,8 @@ class KerasTensor(object):
if hasattr(self, '_keras_history'):
layer = self._keras_history.layer
node_index = self._keras_history.node_index
tensor_index = self._keras_history.tensor_index
symbolic_description = (
', description="Symbolic value %s from '
'symbolic call %s of layer \'%s\'"' % (
tensor_index, node_index, layer.name))
', description="created by layer \'%s\'"' % (layer.name,))
if self._inferred_value is not None:
inferred_value_string = (
', inferred_value=%s' % self._inferred_value)
@ -254,11 +250,7 @@ class KerasTensor(object):
if hasattr(self, '_keras_history'):
layer = self._keras_history.layer
node_index = self._keras_history.node_index
tensor_index = self._keras_history.tensor_index
symbolic_description = (
' (Symbolic value %s from symbolic call %s of layer \'%s\')' % (
tensor_index, node_index, layer.name))
symbolic_description = ' (created by layer \'%s\')' % (layer.name,)
if self._inferred_value is not None:
inferred_value_string = (
' inferred_value=%s' % self._inferred_value)

View File

@ -68,21 +68,20 @@ class KerasTensorTest(test.TestCase):
expected_str = (
"KerasTensor(type_spec=TensorSpec(shape=(None, 3, 10), "
"dtype=tf.float32, name=None), name='dense/BiasAdd:0', "
"description=\"Symbolic value 0 from symbolic call 0 "
"of layer 'dense'\")")
"description=\"created by layer 'dense'\")")
expected_repr = (
"<KerasTensor: shape=(None, 3, 10) dtype=float32 (Symbolic value 0 "
"from symbolic call 0 of layer 'dense')>")
"<KerasTensor: shape=(None, 3, 10) dtype=float32 (created "
"by layer 'dense')>")
self.assertEqual(expected_str, str(kt))
self.assertEqual(expected_repr, repr(kt))
kt = array_ops.reshape(kt, shape=(3, 5, 2))
expected_str = (
"KerasTensor(type_spec=TensorSpec(shape=(3, 5, 2), dtype=tf.float32, "
"name=None), name='tf.reshape/Reshape:0', description=\"Symbolic "
"value 0 from symbolic call 0 of layer 'tf.reshape'\")")
expected_repr = ("<KerasTensor: shape=(3, 5, 2) dtype=float32 (Symbolic "
"value 0 from symbolic call 0 of layer 'tf.reshape')>")
"name=None), name='tf.reshape/Reshape:0', description=\"created "
"by layer 'tf.reshape'\")")
expected_repr = ("<KerasTensor: shape=(3, 5, 2) dtype=float32 (created "
"by layer 'tf.reshape')>")
self.assertEqual(expected_str, str(kt))
self.assertEqual(expected_repr, repr(kt))
@ -90,12 +89,10 @@ class KerasTensorTest(test.TestCase):
for i in range(3):
expected_str = (
"KerasTensor(type_spec=TensorSpec(shape=(5, 2), dtype=tf.float32, "
"name=None), name='tf.unstack/unstack:%s', description=\"Symbolic "
"value %s from symbolic call 0 of layer 'tf.unstack'\")"
) % (i, i)
"name=None), name='tf.unstack/unstack:%s', description=\"created "
"by layer 'tf.unstack'\")" % (i,))
expected_repr = ("<KerasTensor: shape=(5, 2) dtype=float32 "
"(Symbolic value %s from symbolic call 0 "
"of layer 'tf.unstack')>") % i
"(created by layer 'tf.unstack')>")
self.assertEqual(expected_str, str(kts[i]))
self.assertEqual(expected_repr, repr(kts[i]))