Further refine KerasTensors to match behavior in head, and fix several docstring code snippets so they will be valid when we enable KerasTensors

PiperOrigin-RevId: 320491351
Change-Id: I9b91727cd719c92e10171708b25341888bab4d98
This commit is contained in:
Tomer Kaftan 2020-07-09 15:54:47 -07:00 committed by TensorFlower Gardener
parent b75de2d6ef
commit 3304c4a4a8
4 changed files with 22 additions and 17 deletions

View File

@ -170,10 +170,9 @@ def constant(value, dtype=None, shape=None, name="Const"):
Note: All eager `tf.Tensor` values are immutable (in contrast to
`tf.Variable`). There is nothing especially _constant_ about the value
returned from `tf.constant`. This function it is not fundamentally different
from `tf.convert_to_tensor`. The name `tf.constant` comes from the symbolic
APIs (like `tf.data` or keras functional models) where the `value` is embeded
in a `Const` node in the `tf.Graph`. `tf.constant` is useful for asserting
that the value can be embedded that way.
from `tf.convert_to_tensor`. The name `tf.constant` comes from the `value`
being embeded in a `Const` node in the `tf.Graph`. `tf.constant` is useful
for asserting that the value can be embedded that way.
If the argument `dtype` is not specified, then the type is inferred from
the type of `value`.
@ -220,11 +219,12 @@ def constant(value, dtype=None, shape=None, name="Const"):
But, since `tf.constant` embeds the value in the `tf.Graph` this fails for
symbolic tensors:
>>> i = tf.keras.layers.Input(shape=[None, None])
>>> t = tf.constant(i)
>>> with tf.compat.v1.Graph().as_default():
... i = tf.compat.v1.placeholder(shape=[None, None], dtype=tf.float32)
... t = tf.constant(i)
Traceback (most recent call last):
...
NotImplementedError: ...
TypeError: ...
`tf.constant` will _always_ create CPU (host) tensors. In order to create
tensors on other devices, use `tf.identity`. (If the `value` is an eager
@ -236,8 +236,9 @@ def constant(value, dtype=None, shape=None, name="Const"):
* It has no `shape` argument.
* Symbolic tensors are allowed to pass through.
>>> i = tf.keras.layers.Input(shape=[None, None])
>>> t = tf.convert_to_tensor(i)
>>> with tf.compat.v1.Graph().as_default():
... i = tf.compat.v1.placeholder(shape=[None, None], dtype=tf.float32)
... t = tf.convert_to_tensor(i)
* `tf.fill`: differs in a few ways:
* `tf.constant` supports arbitrary constants, not just uniform scalar

View File

@ -853,6 +853,9 @@ def is_sparse(tensor):
True
"""
spec = getattr(tensor, '_type_spec', None)
if spec is not None:
return isinstance(spec, sparse_tensor.SparseTensorSpec)
return isinstance(tensor, sparse_tensor.SparseTensor)
@ -1128,13 +1131,14 @@ def is_keras_tensor(x):
True
"""
if keras_tensor.keras_tensors_enabled():
return isinstance(x, keras_tensor.KerasTensor)
if not isinstance(x,
(ops.Tensor, variables_module.Variable,
sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor,
keras_tensor.KerasTensor)):
raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
'`. Expected a symbolic tensor instance.')
if keras_tensor.keras_tensors_enabled():
return isinstance(x, keras_tensor.KerasTensor)
return hasattr(x, '_keras_history')

View File

@ -1400,9 +1400,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
>>> outputs = tf.keras.layers.Dense(1)(x)
>>> model = tf.keras.Model(inputs, outputs)
>>> # Activity regularization.
>>> len(model.losses)
0
>>> model.add_loss(tf.abs(tf.reduce_mean(x)))
>>> model.losses
[<tf.Tensor 'Abs:0' shape=() dtype=float32>]
>>> len(model.losses)
1
>>> inputs = tf.keras.Input(shape=(10,))
>>> d = tf.keras.layers.Dense(10, kernel_initializer='ones')

View File

@ -102,8 +102,6 @@ class KerasTensor(object):
self._type_spec = type_spec
self._inferred_shape_value = inferred_shape_value
if name is None and hasattr(type_spec, 'name'):
name = type_spec.name
self._name = name
@property
@ -150,7 +148,7 @@ class KerasTensor(object):
"compatible with supplied shape %s" %
(self.shape, shape))
else:
self._internal_type_spec._shape = shape # pylint: disable=protected-access
self._type_spec._shape = shape # pylint: disable=protected-access
@property
def dtype(self):