From 3c74f977c6fc33486c21f5bdc8145e951474433e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 21 Aug 2018 11:39:12 -0700 Subject: [PATCH] Fix Keras Input layer with sparse=True PiperOrigin-RevId: 209631832 --- tensorflow/python/framework/sparse_tensor.py | 21 +++++++++++++++++++ .../python/framework/sparse_tensor_test.py | 14 +++++++++++++ .../python/keras/engine/training_test.py | 14 +++++++++++++ tensorflow/python/keras/utils/tf_utils.py | 5 +++-- .../golden/v1/tensorflow.-sparse-tensor.pbtxt | 8 +++++++ .../golden/v2/tensorflow.-sparse-tensor.pbtxt | 8 +++++++ 6 files changed, 68 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index a45581190fc..4823ba541d7 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -183,11 +183,32 @@ class SparseTensor(_TensorLike): """A 1-D Tensor of int64 representing the shape of the dense tensor.""" return self._dense_shape + @property + def shape(self): + """Get the `TensorShape` representing the shape of the dense tensor. + + Returns: + A `TensorShape` object. + """ + return tensor_util.constant_value_as_shape(self._dense_shape) + @property def graph(self): """The `Graph` that contains the index, value, and dense_shape tensors.""" return self._indices.graph + def consumers(self): + """Returns a list of `Operation`s that consume this `SparseTensor`. + + Returns: + A list of `Operation`s. + """ + values_consumers = set(self._values.consumers()) + indices_consumers = set(self._indices.consumers()) + dense_shape_consumers = set(self._dense_shape.consumers()) + return list(values_consumers \ + .union(indices_consumers, dense_shape_consumers)) + def __str__(self): return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % ( self._indices, self._values, self._dense_shape) diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py index c001fed3b05..2bcfbc17dfe 100644 --- a/tensorflow/python/framework/sparse_tensor_test.py +++ b/tensorflow/python/framework/sparse_tensor_test.py @@ -21,8 +21,10 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import googletest @@ -63,6 +65,18 @@ class SparseTensorTest(test_util.TensorFlowTestCase): sparse_tensor.is_sparse( sparse_tensor.SparseTensorValue([[0]], [0], [1]))) + def testConsumers(self): + sp = sparse_tensor.SparseTensor([[0, 0], [1, 2]], [1.0, 3.0], [3, 4]) + w = ops.convert_to_tensor(np.ones([4, 1], np.float32)) + out = sparse_ops.sparse_tensor_dense_matmul(sp, w) + self.assertEqual(len(sp.consumers()), 1) + self.assertEqual(sp.consumers()[0], out.op) + + dense = sparse_ops.sparse_tensor_to_dense(sp) + self.assertEqual(len(sp.consumers()), 2) + self.assertTrue(dense.op in sp.consumers()) + self.assertTrue(out.op in sp.consumers()) + class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 8d835ed5a98..449e4028530 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -35,6 +35,8 @@ from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine.training_utils import weighted_masked_objective from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.rmsprop import RMSPropOptimizer @@ -386,6 +388,18 @@ class TrainingTest(test.TestCase): epochs=1, batch_size=2, validation_split=0.5) model.evaluate(test_inputs, test_outputs, batch_size=2) + def test_compile_with_sparse_placeholders(self): + with self.test_session(): + input_layer = keras.layers.Input(shape=(10,), sparse=True) + weights = variable_scope.get_variable(name='weights', shape=(10, 1)) + weights_mult = lambda x: sparse_ops.sparse_tensor_dense_matmul(x, weights) + output_layer = keras.layers.Lambda(weights_mult)(input_layer) + model = keras.Model([input_layer], output_layer) + model.compile( + loss='binary_crossentropy', + optimizer=keras.optimizers.Adam(lr=0.0001), + metrics=['accuracy']) + def test_that_trainable_disables_updates(self): val_a = np.random.random((10, 4)) val_out = np.random.random((10, 4)) diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 162e5b2cd65..cfdb3de2aa7 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond as smart_module from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.util import nest @@ -109,10 +110,10 @@ def get_reachable_from_inputs(inputs, targets=None): if isinstance(x, ops.Operation): outputs = x.outputs[:] or [] outputs += x._control_outputs # pylint: disable=protected-access - elif isinstance(x, ops.Tensor): - outputs = x.consumers() elif isinstance(x, variables.Variable): outputs = [x.op] + elif tensor_util.is_tensor(x): + outputs = x.consumers() else: raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x)) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt index eac236d4982..3add49e90d7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "op" mtype: "" } + member { + name: "shape" + mtype: "" + } member { name: "values" mtype: "" @@ -31,6 +35,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "consumers" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "eval" argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt index eac236d4982..3add49e90d7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "op" mtype: "" } + member { + name: "shape" + mtype: "" + } member { name: "values" mtype: "" @@ -31,6 +35,10 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "consumers" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "eval" argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "