Add a meaningful repr to KerasTensors, Update the KerasTensor docstring, and update some of the doctests to not fail w/ the new KerasTensor repr.

PiperOrigin-RevId: 321822118
Change-Id: Iad58dddac5362301b6a06c532e58378f56e3b9ac
This commit is contained in:
Tomer Kaftan 2020-07-17 11:42:51 -07:00 committed by TensorFlower Gardener
parent af89635f3f
commit 391ebea266
5 changed files with 133 additions and 13 deletions

View File

@ -322,6 +322,21 @@ tf_py_test(
],
)
tf_py_test(
name = "keras_tensor_test",
size = "small",
srcs = ["keras_tensor_test.py"],
python_version = "PY3",
tags = [
"nomac", # TODO(mihaimaruseac): b/127695564
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "input_spec_test",
size = "small",

View File

@ -51,16 +51,15 @@ def keras_tensors_enabled():
class KerasTensor(object):
"""A representation of a Keras in/output during Functional API construction.
`KerasTensor`s are an alternative representation for Keras `Inputs`
and for intermediate outputs of layers during Functional API construction of
models. They are a lightweight data structure comprised of only the
`tf.TypeSpec` of the Tensor that will be consumed/produced in the
corresponding position of the model.
`KerasTensor`s are tensor-like objects that represent the symbolic inputs
and outputs of Keras layers during Functional model construction. They are
compromised of the `tf.TypeSpec` of the Tensor that will be
consumed/produced in the corresponding position of the model.
They implement just small subset of `tf.Tensor`'s attributes and
methods, and also overload
the same operators as `tf.Tensor` and automatically turn them into
Keras layers in the model.
They implement `tf.Tensor`'s attributes and methods, and also overload
the same operators as `tf.Tensor`. Passing a KerasTensor to a TF API that
supports dispatching will automatically turn that API call into a lambda
layer in the Functional model.
`KerasTensor`s are still internal-only and are a work in progress, but they
have several advantages over using a graph `tf.Tensor` to represent
@ -150,6 +149,27 @@ class KerasTensor(object):
else:
self._type_spec._shape = shape # pylint: disable=protected-access
def __repr__(self):
symbolic_description = ''
inferred_value_string = ''
if isinstance(self.type_spec, tensor_spec.TensorSpec):
type_spec_string = 'shape=%s dtype=%s' % (self.shape, self.dtype.name)
else:
type_spec_string = 'type_spec=%s' % self.type_spec
if hasattr(self, '_keras_history'):
layer = self._keras_history.layer
node_index = self._keras_history.node_index
tensor_index = self._keras_history.tensor_index
symbolic_description = (
' (Symbolic value %s from symbolic call %s of layer \'%s\')' % (
tensor_index, node_index, layer.name))
if self._inferred_shape_value is not None:
inferred_value_string = (
' inferred_value=\'%s\'' % self._inferred_shape_value)
return '<KerasTensor: %s%s%s>' % (
type_spec_string, inferred_value_string, symbolic_description)
@property
def dtype(self):
"""Returns the `dtype` of elements in the tensor."""

View File

@ -0,0 +1,85 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""InputSpec tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import layers
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class KerasTensorTest(test.TestCase):
def test_repr(self):
kt = keras_tensor.KerasTensor(
type_spec=tensor_spec.TensorSpec(shape=(1, 2, 3), dtype=dtypes.float32))
expected_repr = "<KerasTensor: shape=(1, 2, 3) dtype=float32>"
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
kt = keras_tensor.KerasTensor(
type_spec=tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32),
inferred_shape_value=[2, 3])
expected_repr = (
"<KerasTensor: shape=(2,) dtype=int32 inferred_value='[2, 3]'>")
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
kt = keras_tensor.KerasTensor(
type_spec=sparse_tensor.SparseTensorSpec(
shape=(1, 2, 3), dtype=dtypes.float32))
expected_repr = (
"<KerasTensor: type_spec=SparseTensorSpec("
"TensorShape([1, 2, 3]), tf.float32)>")
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
with testing_utils.use_keras_tensors_scope(True):
inp = layers.Input(shape=(3, 5))
kt = layers.Dense(10)(inp)
expected_repr = (
"<KerasTensor: shape=(None, 3, 10) dtype=float32 (Symbolic value 0 "
"from symbolic call 0 of layer 'dense')>")
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
kt = array_ops.reshape(kt, shape=(3, 5, 2))
expected_repr = ("<KerasTensor: shape=(3, 5, 2) dtype=float32 (Symbolic "
"value 0 from symbolic call 0 of layer 'tf.reshape')>")
self.assertEqual(expected_repr, str(kt))
self.assertEqual(expected_repr, repr(kt))
kts = array_ops.unstack(kt)
for i in range(3):
expected_repr = ("<KerasTensor: shape=(5, 2) dtype=float32 "
"(Symbolic value %s from symbolic call 0 "
"of layer 'tf.unstack')>" % i)
self.assertEqual(expected_repr, str(kts[i]))
self.assertEqual(expected_repr, repr(kts[i]))
if __name__ == "__main__":
ops.enable_eager_execution()
tensor_shape.enable_v2_tensorshape()
test.main()

View File

@ -73,7 +73,7 @@ class EinsumDense(Layer):
>>> input_tensor = tf.keras.Input(shape=[32])
>>> output_tensor = layer(input_tensor)
>>> output_tensor
<tf.Tensor '...' shape=(None, 64) dtype=...>
<... shape=(None, 64) dtype=...>
**Applying a dense layer to a sequence**
@ -89,7 +89,7 @@ class EinsumDense(Layer):
>>> input_tensor = tf.keras.Input(shape=[32, 128])
>>> output_tensor = layer(input_tensor)
>>> output_tensor
<tf.Tensor '...' shape=(None, 32, 64) dtype=...>
<... shape=(None, 32, 64) dtype=...>
**Applying a dense layer to a sequence using ellipses**
@ -106,7 +106,7 @@ class EinsumDense(Layer):
>>> input_tensor = tf.keras.Input(shape=[32, 128])
>>> output_tensor = layer(input_tensor)
>>> output_tensor
<tf.Tensor '...' shape=(None, 32, 64) dtype=...>
<... shape=(None, 32, 64) dtype=...>
"""
def __init__(self,

View File

@ -591,7 +591,7 @@ def shape_v2(input, out_type=dtypes.int32, name=None):
>>> a = tf.keras.layers.Input((None, 10))
>>> tf.shape(a)
<tf.Tensor ... shape=(3,) dtype=int32>
<... shape=(3,) dtype=int32...>
In these cases, using `tf.Tensor.shape` will return more informative results.