Fix Keras Input layer with sparse=True
PiperOrigin-RevId: 209631832
This commit is contained in:
parent
d81e875dd6
commit
3c74f977c6
@ -183,11 +183,32 @@ class SparseTensor(_TensorLike):
|
||||
"""A 1-D Tensor of int64 representing the shape of the dense tensor."""
|
||||
return self._dense_shape
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""Get the `TensorShape` representing the shape of the dense tensor.
|
||||
|
||||
Returns:
|
||||
A `TensorShape` object.
|
||||
"""
|
||||
return tensor_util.constant_value_as_shape(self._dense_shape)
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
"""The `Graph` that contains the index, value, and dense_shape tensors."""
|
||||
return self._indices.graph
|
||||
|
||||
def consumers(self):
|
||||
"""Returns a list of `Operation`s that consume this `SparseTensor`.
|
||||
|
||||
Returns:
|
||||
A list of `Operation`s.
|
||||
"""
|
||||
values_consumers = set(self._values.consumers())
|
||||
indices_consumers = set(self._indices.consumers())
|
||||
dense_shape_consumers = set(self._dense_shape.consumers())
|
||||
return list(values_consumers \
|
||||
.union(indices_consumers, dense_shape_consumers))
|
||||
|
||||
def __str__(self):
|
||||
return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
|
||||
self._indices, self._values, self._dense_shape)
|
||||
|
@ -21,8 +21,10 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@ -63,6 +65,18 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
|
||||
sparse_tensor.is_sparse(
|
||||
sparse_tensor.SparseTensorValue([[0]], [0], [1])))
|
||||
|
||||
def testConsumers(self):
|
||||
sp = sparse_tensor.SparseTensor([[0, 0], [1, 2]], [1.0, 3.0], [3, 4])
|
||||
w = ops.convert_to_tensor(np.ones([4, 1], np.float32))
|
||||
out = sparse_ops.sparse_tensor_dense_matmul(sp, w)
|
||||
self.assertEqual(len(sp.consumers()), 1)
|
||||
self.assertEqual(sp.consumers()[0], out.op)
|
||||
|
||||
dense = sparse_ops.sparse_tensor_to_dense(sp)
|
||||
self.assertEqual(len(sp.consumers()), 2)
|
||||
self.assertTrue(dense.op in sp.consumers())
|
||||
self.assertTrue(out.op in sp.consumers())
|
||||
|
||||
|
||||
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -35,6 +35,8 @@ from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine.training_utils import weighted_masked_objective
|
||||
from tensorflow.python.keras.utils.generic_utils import slice_arrays
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||
@ -386,6 +388,18 @@ class TrainingTest(test.TestCase):
|
||||
epochs=1, batch_size=2, validation_split=0.5)
|
||||
model.evaluate(test_inputs, test_outputs, batch_size=2)
|
||||
|
||||
def test_compile_with_sparse_placeholders(self):
|
||||
with self.test_session():
|
||||
input_layer = keras.layers.Input(shape=(10,), sparse=True)
|
||||
weights = variable_scope.get_variable(name='weights', shape=(10, 1))
|
||||
weights_mult = lambda x: sparse_ops.sparse_tensor_dense_matmul(x, weights)
|
||||
output_layer = keras.layers.Lambda(weights_mult)(input_layer)
|
||||
model = keras.Model([input_layer], output_layer)
|
||||
model.compile(
|
||||
loss='binary_crossentropy',
|
||||
optimizer=keras.optimizers.Adam(lr=0.0001),
|
||||
metrics=['accuracy'])
|
||||
|
||||
def test_that_trainable_disables_updates(self):
|
||||
val_a = np.random.random((10, 4))
|
||||
val_out = np.random.random((10, 4))
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import smart_cond as smart_module
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.util import nest
|
||||
@ -109,10 +110,10 @@ def get_reachable_from_inputs(inputs, targets=None):
|
||||
if isinstance(x, ops.Operation):
|
||||
outputs = x.outputs[:] or []
|
||||
outputs += x._control_outputs # pylint: disable=protected-access
|
||||
elif isinstance(x, ops.Tensor):
|
||||
outputs = x.consumers()
|
||||
elif isinstance(x, variables.Variable):
|
||||
outputs = [x.op]
|
||||
elif tensor_util.is_tensor(x):
|
||||
outputs = x.consumers()
|
||||
else:
|
||||
raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))
|
||||
|
||||
|
@ -23,6 +23,10 @@ tf_class {
|
||||
name: "op"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "values"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -31,6 +35,10 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "consumers"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "eval"
|
||||
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
@ -23,6 +23,10 @@ tf_class {
|
||||
name: "op"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "values"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -31,6 +35,10 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "consumers"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "eval"
|
||||
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user