Fix Keras Input layer with sparse=True

PiperOrigin-RevId: 209631832
This commit is contained in:
A. Unique TensorFlower 2018-08-21 11:39:12 -07:00 committed by TensorFlower Gardener
parent d81e875dd6
commit 3c74f977c6
6 changed files with 68 additions and 2 deletions

View File

@ -183,11 +183,32 @@ class SparseTensor(_TensorLike):
"""A 1-D Tensor of int64 representing the shape of the dense tensor."""
return self._dense_shape
@property
def shape(self):
"""Get the `TensorShape` representing the shape of the dense tensor.
Returns:
A `TensorShape` object.
"""
return tensor_util.constant_value_as_shape(self._dense_shape)
@property
def graph(self):
"""The `Graph` that contains the index, value, and dense_shape tensors."""
return self._indices.graph
def consumers(self):
"""Returns a list of `Operation`s that consume this `SparseTensor`.
Returns:
A list of `Operation`s.
"""
values_consumers = set(self._values.consumers())
indices_consumers = set(self._indices.consumers())
dense_shape_consumers = set(self._dense_shape.consumers())
return list(values_consumers \
.union(indices_consumers, dense_shape_consumers))
def __str__(self):
return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
self._indices, self._values, self._dense_shape)

View File

@ -21,8 +21,10 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import googletest
@ -63,6 +65,18 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
sparse_tensor.is_sparse(
sparse_tensor.SparseTensorValue([[0]], [0], [1])))
def testConsumers(self):
sp = sparse_tensor.SparseTensor([[0, 0], [1, 2]], [1.0, 3.0], [3, 4])
w = ops.convert_to_tensor(np.ones([4, 1], np.float32))
out = sparse_ops.sparse_tensor_dense_matmul(sp, w)
self.assertEqual(len(sp.consumers()), 1)
self.assertEqual(sp.consumers()[0], out.op)
dense = sparse_ops.sparse_tensor_to_dense(sp)
self.assertEqual(len(sp.consumers()), 2)
self.assertTrue(dense.op in sp.consumers())
self.assertTrue(out.op in sp.consumers())
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):

View File

@ -35,6 +35,8 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine.training_utils import weighted_masked_objective
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@ -386,6 +388,18 @@ class TrainingTest(test.TestCase):
epochs=1, batch_size=2, validation_split=0.5)
model.evaluate(test_inputs, test_outputs, batch_size=2)
def test_compile_with_sparse_placeholders(self):
with self.test_session():
input_layer = keras.layers.Input(shape=(10,), sparse=True)
weights = variable_scope.get_variable(name='weights', shape=(10, 1))
weights_mult = lambda x: sparse_ops.sparse_tensor_dense_matmul(x, weights)
output_layer = keras.layers.Lambda(weights_mult)(input_layer)
model = keras.Model([input_layer], output_layer)
model.compile(
loss='binary_crossentropy',
optimizer=keras.optimizers.Adam(lr=0.0001),
metrics=['accuracy'])
def test_that_trainable_disables_updates(self):
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest
@ -109,10 +110,10 @@ def get_reachable_from_inputs(inputs, targets=None):
if isinstance(x, ops.Operation):
outputs = x.outputs[:] or []
outputs += x._control_outputs # pylint: disable=protected-access
elif isinstance(x, ops.Tensor):
outputs = x.consumers()
elif isinstance(x, variables.Variable):
outputs = [x.op]
elif tensor_util.is_tensor(x):
outputs = x.consumers()
else:
raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))

View File

@ -23,6 +23,10 @@ tf_class {
name: "op"
mtype: "<type \'property\'>"
}
member {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "values"
mtype: "<type \'property\'>"
@ -31,6 +35,10 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "consumers"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "eval"
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@ -23,6 +23,10 @@ tf_class {
name: "op"
mtype: "<type \'property\'>"
}
member {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "values"
mtype: "<type \'property\'>"
@ -31,6 +35,10 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "consumers"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "eval"
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "