Simplify string representation of KerasTensor.
PiperOrigin-RevId: 325112498 Change-Id: Ibd4641faf6871ebd58477f9fe945cd2e0ac641f8
This commit is contained in:
parent
e4706fdf31
commit
83a1ac2d84
@ -1181,7 +1181,7 @@ def placeholder(shape=None,
|
||||
|
||||
>>> input_ph = tf.keras.backend.placeholder(shape=(2, 4, 5))
|
||||
>>> input_ph
|
||||
<KerasTensor: shape=(2, 4, 5) dtype=float32 (Symbolic value ...)>
|
||||
<KerasTensor: shape=(2, 4, 5) dtype=float32 (created by layer ...)>
|
||||
|
||||
"""
|
||||
if sparse and ragged:
|
||||
|
@ -229,12 +229,8 @@ class KerasTensor(object):
|
||||
|
||||
if hasattr(self, '_keras_history'):
|
||||
layer = self._keras_history.layer
|
||||
node_index = self._keras_history.node_index
|
||||
tensor_index = self._keras_history.tensor_index
|
||||
symbolic_description = (
|
||||
', description="Symbolic value %s from '
|
||||
'symbolic call %s of layer \'%s\'"' % (
|
||||
tensor_index, node_index, layer.name))
|
||||
', description="created by layer \'%s\'"' % (layer.name,))
|
||||
if self._inferred_value is not None:
|
||||
inferred_value_string = (
|
||||
', inferred_value=%s' % self._inferred_value)
|
||||
@ -254,11 +250,7 @@ class KerasTensor(object):
|
||||
|
||||
if hasattr(self, '_keras_history'):
|
||||
layer = self._keras_history.layer
|
||||
node_index = self._keras_history.node_index
|
||||
tensor_index = self._keras_history.tensor_index
|
||||
symbolic_description = (
|
||||
' (Symbolic value %s from symbolic call %s of layer \'%s\')' % (
|
||||
tensor_index, node_index, layer.name))
|
||||
symbolic_description = ' (created by layer \'%s\')' % (layer.name,)
|
||||
if self._inferred_value is not None:
|
||||
inferred_value_string = (
|
||||
' inferred_value=%s' % self._inferred_value)
|
||||
|
@ -68,21 +68,20 @@ class KerasTensorTest(test.TestCase):
|
||||
expected_str = (
|
||||
"KerasTensor(type_spec=TensorSpec(shape=(None, 3, 10), "
|
||||
"dtype=tf.float32, name=None), name='dense/BiasAdd:0', "
|
||||
"description=\"Symbolic value 0 from symbolic call 0 "
|
||||
"of layer 'dense'\")")
|
||||
"description=\"created by layer 'dense'\")")
|
||||
expected_repr = (
|
||||
"<KerasTensor: shape=(None, 3, 10) dtype=float32 (Symbolic value 0 "
|
||||
"from symbolic call 0 of layer 'dense')>")
|
||||
"<KerasTensor: shape=(None, 3, 10) dtype=float32 (created "
|
||||
"by layer 'dense')>")
|
||||
self.assertEqual(expected_str, str(kt))
|
||||
self.assertEqual(expected_repr, repr(kt))
|
||||
|
||||
kt = array_ops.reshape(kt, shape=(3, 5, 2))
|
||||
expected_str = (
|
||||
"KerasTensor(type_spec=TensorSpec(shape=(3, 5, 2), dtype=tf.float32, "
|
||||
"name=None), name='tf.reshape/Reshape:0', description=\"Symbolic "
|
||||
"value 0 from symbolic call 0 of layer 'tf.reshape'\")")
|
||||
expected_repr = ("<KerasTensor: shape=(3, 5, 2) dtype=float32 (Symbolic "
|
||||
"value 0 from symbolic call 0 of layer 'tf.reshape')>")
|
||||
"name=None), name='tf.reshape/Reshape:0', description=\"created "
|
||||
"by layer 'tf.reshape'\")")
|
||||
expected_repr = ("<KerasTensor: shape=(3, 5, 2) dtype=float32 (created "
|
||||
"by layer 'tf.reshape')>")
|
||||
self.assertEqual(expected_str, str(kt))
|
||||
self.assertEqual(expected_repr, repr(kt))
|
||||
|
||||
@ -90,12 +89,10 @@ class KerasTensorTest(test.TestCase):
|
||||
for i in range(3):
|
||||
expected_str = (
|
||||
"KerasTensor(type_spec=TensorSpec(shape=(5, 2), dtype=tf.float32, "
|
||||
"name=None), name='tf.unstack/unstack:%s', description=\"Symbolic "
|
||||
"value %s from symbolic call 0 of layer 'tf.unstack'\")"
|
||||
) % (i, i)
|
||||
"name=None), name='tf.unstack/unstack:%s', description=\"created "
|
||||
"by layer 'tf.unstack'\")" % (i,))
|
||||
expected_repr = ("<KerasTensor: shape=(5, 2) dtype=float32 "
|
||||
"(Symbolic value %s from symbolic call 0 "
|
||||
"of layer 'tf.unstack')>") % i
|
||||
"(created by layer 'tf.unstack')>")
|
||||
self.assertEqual(expected_str, str(kts[i]))
|
||||
self.assertEqual(expected_repr, repr(kts[i]))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user