Add a meaningful repr to KerasTensors, Update the KerasTensor docstring, and update some of the doctests to not fail w/ the new KerasTensor repr.

PiperOrigin-RevId: 321822118
Change-Id: Iad58dddac5362301b6a06c532e58378f56e3b9ac
This commit is contained in:
Tomer Kaftan 2020-07-17 11:42:51 -07:00 committed by TensorFlower Gardener
parent af89635f3f
commit 391ebea266
5 changed files with 133 additions and 13 deletions

View File

@ -322,6 +322,21 @@ tf_py_test(
], ],
) )
tf_py_test(
name = "keras_tensor_test",
size = "small",
srcs = ["keras_tensor_test.py"],
python_version = "PY3",
tags = [
"nomac", # TODO(mihaimaruseac): b/127695564
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test( tf_py_test(
name = "input_spec_test", name = "input_spec_test",
size = "small", size = "small",

View File

@ -51,16 +51,15 @@ def keras_tensors_enabled():
class KerasTensor(object): class KerasTensor(object):
"""A representation of a Keras in/output during Functional API construction. """A representation of a Keras in/output during Functional API construction.
`KerasTensor`s are an alternative representation for Keras `Inputs` `KerasTensor`s are tensor-like objects that represent the symbolic inputs
and for intermediate outputs of layers during Functional API construction of and outputs of Keras layers during Functional model construction. They are
models. They are a lightweight data structure comprised of only the compromised of the `tf.TypeSpec` of the Tensor that will be
`tf.TypeSpec` of the Tensor that will be consumed/produced in the consumed/produced in the corresponding position of the model.
corresponding position of the model.
They implement just small subset of `tf.Tensor`'s attributes and They implement `tf.Tensor`'s attributes and methods, and also overload
methods, and also overload the same operators as `tf.Tensor`. Passing a KerasTensor to a TF API that
the same operators as `tf.Tensor` and automatically turn them into supports dispatching will automatically turn that API call into a lambda
Keras layers in the model. layer in the Functional model.
`KerasTensor`s are still internal-only and are a work in progress, but they `KerasTensor`s are still internal-only and are a work in progress, but they
have several advantages over using a graph `tf.Tensor` to represent have several advantages over using a graph `tf.Tensor` to represent
@ -150,6 +149,27 @@ class KerasTensor(object):
else: else:
self._type_spec._shape = shape # pylint: disable=protected-access self._type_spec._shape = shape # pylint: disable=protected-access
def __repr__(self):
symbolic_description = ''
inferred_value_string = ''
if isinstance(self.type_spec, tensor_spec.TensorSpec):
type_spec_string = 'shape=%s dtype=%s' % (self.shape, self.dtype.name)
else:
type_spec_string = 'type_spec=%s' % self.type_spec
if hasattr(self, '_keras_history'):
layer = self._keras_history.layer
node_index = self._keras_history.node_index
tensor_index = self._keras_history.tensor_index
symbolic_description = (
' (Symbolic value %s from symbolic call %s of layer \'%s\')' % (
tensor_index, node_index, layer.name))
if self._inferred_shape_value is not None:
inferred_value_string = (
' inferred_value=\'%s\'' % self._inferred_shape_value)
return '<KerasTensor: %s%s%s>' % (
type_spec_string, inferred_value_string, symbolic_description)
@property @property
def dtype(self): def dtype(self):
"""Returns the `dtype` of elements in the tensor.""" """Returns the `dtype` of elements in the tensor."""

View File

@ -0,0 +1,85 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""InputSpec tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import layers
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class KerasTensorTest(test.TestCase):
def test_repr(self):
kt = keras_tensor.KerasTensor(
type_spec=tensor_spec.TensorSpec(shape=(1, 2, 3), dtype=dtypes.float32))
expected_repr = "<KerasTensor: shape=(1, 2, 3) dtype=float32>"
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
kt = keras_tensor.KerasTensor(
type_spec=tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32),
inferred_shape_value=[2, 3])
expected_repr = (
"<KerasTensor: shape=(2,) dtype=int32 inferred_value='[2, 3]'>")
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
kt = keras_tensor.KerasTensor(
type_spec=sparse_tensor.SparseTensorSpec(
shape=(1, 2, 3), dtype=dtypes.float32))
expected_repr = (
"<KerasTensor: type_spec=SparseTensorSpec("
"TensorShape([1, 2, 3]), tf.float32)>")
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
with testing_utils.use_keras_tensors_scope(True):
inp = layers.Input(shape=(3, 5))
kt = layers.Dense(10)(inp)
expected_repr = (
"<KerasTensor: shape=(None, 3, 10) dtype=float32 (Symbolic value 0 "
"from symbolic call 0 of layer 'dense')>")
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
kt = array_ops.reshape(kt, shape=(3, 5, 2))
expected_repr = ("<KerasTensor: shape=(3, 5, 2) dtype=float32 (Symbolic "
"value 0 from symbolic call 0 of layer 'tf.reshape')>")
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
kts = array_ops.unstack(kt)
for i in range(3):
expected_repr = ("<KerasTensor: shape=(5, 2) dtype=float32 "
"(Symbolic value %s from symbolic call 0 "
"of layer 'tf.unstack')>" % i)
self.assertEqual(expected_repr, str(kts[i]))
self.assertEqual(expected_repr, repr(kts[i]))
if __name__ == "__main__":
ops.enable_eager_execution()
tensor_shape.enable_v2_tensorshape()
test.main()

View File

@ -73,7 +73,7 @@ class EinsumDense(Layer):
>>> input_tensor = tf.keras.Input(shape=[32]) >>> input_tensor = tf.keras.Input(shape=[32])
>>> output_tensor = layer(input_tensor) >>> output_tensor = layer(input_tensor)
>>> output_tensor >>> output_tensor
<tf.Tensor '...' shape=(None, 64) dtype=...> <... shape=(None, 64) dtype=...>
**Applying a dense layer to a sequence** **Applying a dense layer to a sequence**
@ -89,7 +89,7 @@ class EinsumDense(Layer):
>>> input_tensor = tf.keras.Input(shape=[32, 128]) >>> input_tensor = tf.keras.Input(shape=[32, 128])
>>> output_tensor = layer(input_tensor) >>> output_tensor = layer(input_tensor)
>>> output_tensor >>> output_tensor
<tf.Tensor '...' shape=(None, 32, 64) dtype=...> <... shape=(None, 32, 64) dtype=...>
**Applying a dense layer to a sequence using ellipses** **Applying a dense layer to a sequence using ellipses**
@ -106,7 +106,7 @@ class EinsumDense(Layer):
>>> input_tensor = tf.keras.Input(shape=[32, 128]) >>> input_tensor = tf.keras.Input(shape=[32, 128])
>>> output_tensor = layer(input_tensor) >>> output_tensor = layer(input_tensor)
>>> output_tensor >>> output_tensor
<tf.Tensor '...' shape=(None, 32, 64) dtype=...> <... shape=(None, 32, 64) dtype=...>
""" """
def __init__(self, def __init__(self,

View File

@ -591,7 +591,7 @@ def shape_v2(input, out_type=dtypes.int32, name=None):
>>> a = tf.keras.layers.Input((None, 10)) >>> a = tf.keras.layers.Input((None, 10))
>>> tf.shape(a) >>> tf.shape(a)
<tf.Tensor ... shape=(3,) dtype=int32> <... shape=(3,) dtype=int32...>
In these cases, using `tf.Tensor.shape` will return more informative results. In these cases, using `tf.Tensor.shape` will return more informative results.