Add a meaningful repr to KerasTensors, Update the KerasTensor docstring, and update some of the doctests to not fail w/ the new KerasTensor repr.
PiperOrigin-RevId: 321822118 Change-Id: Iad58dddac5362301b6a06c532e58378f56e3b9ac
This commit is contained in:
parent
af89635f3f
commit
391ebea266
@ -322,6 +322,21 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "keras_tensor_test",
|
||||
size = "small",
|
||||
srcs = ["keras_tensor_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"nomac", # TODO(mihaimaruseac): b/127695564
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "input_spec_test",
|
||||
size = "small",
|
||||
|
||||
@ -51,16 +51,15 @@ def keras_tensors_enabled():
|
||||
class KerasTensor(object):
|
||||
"""A representation of a Keras in/output during Functional API construction.
|
||||
|
||||
`KerasTensor`s are an alternative representation for Keras `Inputs`
|
||||
and for intermediate outputs of layers during Functional API construction of
|
||||
models. They are a lightweight data structure comprised of only the
|
||||
`tf.TypeSpec` of the Tensor that will be consumed/produced in the
|
||||
corresponding position of the model.
|
||||
`KerasTensor`s are tensor-like objects that represent the symbolic inputs
|
||||
and outputs of Keras layers during Functional model construction. They are
|
||||
compromised of the `tf.TypeSpec` of the Tensor that will be
|
||||
consumed/produced in the corresponding position of the model.
|
||||
|
||||
They implement just small subset of `tf.Tensor`'s attributes and
|
||||
methods, and also overload
|
||||
the same operators as `tf.Tensor` and automatically turn them into
|
||||
Keras layers in the model.
|
||||
They implement `tf.Tensor`'s attributes and methods, and also overload
|
||||
the same operators as `tf.Tensor`. Passing a KerasTensor to a TF API that
|
||||
supports dispatching will automatically turn that API call into a lambda
|
||||
layer in the Functional model.
|
||||
|
||||
`KerasTensor`s are still internal-only and are a work in progress, but they
|
||||
have several advantages over using a graph `tf.Tensor` to represent
|
||||
@ -150,6 +149,27 @@ class KerasTensor(object):
|
||||
else:
|
||||
self._type_spec._shape = shape # pylint: disable=protected-access
|
||||
|
||||
def __repr__(self):
|
||||
symbolic_description = ''
|
||||
inferred_value_string = ''
|
||||
if isinstance(self.type_spec, tensor_spec.TensorSpec):
|
||||
type_spec_string = 'shape=%s dtype=%s' % (self.shape, self.dtype.name)
|
||||
else:
|
||||
type_spec_string = 'type_spec=%s' % self.type_spec
|
||||
|
||||
if hasattr(self, '_keras_history'):
|
||||
layer = self._keras_history.layer
|
||||
node_index = self._keras_history.node_index
|
||||
tensor_index = self._keras_history.tensor_index
|
||||
symbolic_description = (
|
||||
' (Symbolic value %s from symbolic call %s of layer \'%s\')' % (
|
||||
tensor_index, node_index, layer.name))
|
||||
if self._inferred_shape_value is not None:
|
||||
inferred_value_string = (
|
||||
' inferred_value=\'%s\'' % self._inferred_shape_value)
|
||||
return '<KerasTensor: %s%s%s>' % (
|
||||
type_spec_string, inferred_value_string, symbolic_description)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""Returns the `dtype` of elements in the tensor."""
|
||||
|
||||
85
tensorflow/python/keras/engine/keras_tensor_test.py
Normal file
85
tensorflow/python/keras/engine/keras_tensor_test.py
Normal file
@ -0,0 +1,85 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""InputSpec tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import layers
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import keras_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class KerasTensorTest(test.TestCase):
|
||||
|
||||
def test_repr(self):
|
||||
kt = keras_tensor.KerasTensor(
|
||||
type_spec=tensor_spec.TensorSpec(shape=(1, 2, 3), dtype=dtypes.float32))
|
||||
expected_repr = "<KerasTensor: shape=(1, 2, 3) dtype=float32>"
|
||||
self.assertEqual(expected_repr, str(kt))
|
||||
self.assertEqual(expected_repr, repr(kt))
|
||||
|
||||
kt = keras_tensor.KerasTensor(
|
||||
type_spec=tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32),
|
||||
inferred_shape_value=[2, 3])
|
||||
expected_repr = (
|
||||
"<KerasTensor: shape=(2,) dtype=int32 inferred_value='[2, 3]'>")
|
||||
self.assertEqual(expected_repr, str(kt))
|
||||
self.assertEqual(expected_repr, repr(kt))
|
||||
|
||||
kt = keras_tensor.KerasTensor(
|
||||
type_spec=sparse_tensor.SparseTensorSpec(
|
||||
shape=(1, 2, 3), dtype=dtypes.float32))
|
||||
expected_repr = (
|
||||
"<KerasTensor: type_spec=SparseTensorSpec("
|
||||
"TensorShape([1, 2, 3]), tf.float32)>")
|
||||
self.assertEqual(expected_repr, str(kt))
|
||||
self.assertEqual(expected_repr, repr(kt))
|
||||
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
inp = layers.Input(shape=(3, 5))
|
||||
kt = layers.Dense(10)(inp)
|
||||
expected_repr = (
|
||||
"<KerasTensor: shape=(None, 3, 10) dtype=float32 (Symbolic value 0 "
|
||||
"from symbolic call 0 of layer 'dense')>")
|
||||
self.assertEqual(expected_repr, str(kt))
|
||||
self.assertEqual(expected_repr, repr(kt))
|
||||
|
||||
kt = array_ops.reshape(kt, shape=(3, 5, 2))
|
||||
expected_repr = ("<KerasTensor: shape=(3, 5, 2) dtype=float32 (Symbolic "
|
||||
"value 0 from symbolic call 0 of layer 'tf.reshape')>")
|
||||
self.assertEqual(expected_repr, str(kt))
|
||||
self.assertEqual(expected_repr, repr(kt))
|
||||
|
||||
kts = array_ops.unstack(kt)
|
||||
for i in range(3):
|
||||
expected_repr = ("<KerasTensor: shape=(5, 2) dtype=float32 "
|
||||
"(Symbolic value %s from symbolic call 0 "
|
||||
"of layer 'tf.unstack')>" % i)
|
||||
self.assertEqual(expected_repr, str(kts[i]))
|
||||
self.assertEqual(expected_repr, repr(kts[i]))
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
tensor_shape.enable_v2_tensorshape()
|
||||
test.main()
|
||||
@ -73,7 +73,7 @@ class EinsumDense(Layer):
|
||||
>>> input_tensor = tf.keras.Input(shape=[32])
|
||||
>>> output_tensor = layer(input_tensor)
|
||||
>>> output_tensor
|
||||
<tf.Tensor '...' shape=(None, 64) dtype=...>
|
||||
<... shape=(None, 64) dtype=...>
|
||||
|
||||
**Applying a dense layer to a sequence**
|
||||
|
||||
@ -89,7 +89,7 @@ class EinsumDense(Layer):
|
||||
>>> input_tensor = tf.keras.Input(shape=[32, 128])
|
||||
>>> output_tensor = layer(input_tensor)
|
||||
>>> output_tensor
|
||||
<tf.Tensor '...' shape=(None, 32, 64) dtype=...>
|
||||
<... shape=(None, 32, 64) dtype=...>
|
||||
|
||||
**Applying a dense layer to a sequence using ellipses**
|
||||
|
||||
@ -106,7 +106,7 @@ class EinsumDense(Layer):
|
||||
>>> input_tensor = tf.keras.Input(shape=[32, 128])
|
||||
>>> output_tensor = layer(input_tensor)
|
||||
>>> output_tensor
|
||||
<tf.Tensor '...' shape=(None, 32, 64) dtype=...>
|
||||
<... shape=(None, 32, 64) dtype=...>
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@ -591,7 +591,7 @@ def shape_v2(input, out_type=dtypes.int32, name=None):
|
||||
|
||||
>>> a = tf.keras.layers.Input((None, 10))
|
||||
>>> tf.shape(a)
|
||||
<tf.Tensor ... shape=(3,) dtype=int32>
|
||||
<... shape=(3,) dtype=int32...>
|
||||
|
||||
In these cases, using `tf.Tensor.shape` will return more informative results.
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user