Fix Keras Input layer with sparse=True
PiperOrigin-RevId: 209631832
This commit is contained in:
parent
d81e875dd6
commit
3c74f977c6
@ -183,11 +183,32 @@ class SparseTensor(_TensorLike):
|
|||||||
"""A 1-D Tensor of int64 representing the shape of the dense tensor."""
|
"""A 1-D Tensor of int64 representing the shape of the dense tensor."""
|
||||||
return self._dense_shape
|
return self._dense_shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
"""Get the `TensorShape` representing the shape of the dense tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `TensorShape` object.
|
||||||
|
"""
|
||||||
|
return tensor_util.constant_value_as_shape(self._dense_shape)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def graph(self):
|
def graph(self):
|
||||||
"""The `Graph` that contains the index, value, and dense_shape tensors."""
|
"""The `Graph` that contains the index, value, and dense_shape tensors."""
|
||||||
return self._indices.graph
|
return self._indices.graph
|
||||||
|
|
||||||
|
def consumers(self):
|
||||||
|
"""Returns a list of `Operation`s that consume this `SparseTensor`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of `Operation`s.
|
||||||
|
"""
|
||||||
|
values_consumers = set(self._values.consumers())
|
||||||
|
indices_consumers = set(self._indices.consumers())
|
||||||
|
dense_shape_consumers = set(self._dense_shape.consumers())
|
||||||
|
return list(values_consumers \
|
||||||
|
.union(indices_consumers, dense_shape_consumers))
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
|
return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
|
||||||
self._indices, self._values, self._dense_shape)
|
self._indices, self._values, self._dense_shape)
|
||||||
|
@ -21,8 +21,10 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import sparse_ops
|
||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
|
|
||||||
@ -63,6 +65,18 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
|
|||||||
sparse_tensor.is_sparse(
|
sparse_tensor.is_sparse(
|
||||||
sparse_tensor.SparseTensorValue([[0]], [0], [1])))
|
sparse_tensor.SparseTensorValue([[0]], [0], [1])))
|
||||||
|
|
||||||
|
def testConsumers(self):
|
||||||
|
sp = sparse_tensor.SparseTensor([[0, 0], [1, 2]], [1.0, 3.0], [3, 4])
|
||||||
|
w = ops.convert_to_tensor(np.ones([4, 1], np.float32))
|
||||||
|
out = sparse_ops.sparse_tensor_dense_matmul(sp, w)
|
||||||
|
self.assertEqual(len(sp.consumers()), 1)
|
||||||
|
self.assertEqual(sp.consumers()[0], out.op)
|
||||||
|
|
||||||
|
dense = sparse_ops.sparse_tensor_to_dense(sp)
|
||||||
|
self.assertEqual(len(sp.consumers()), 2)
|
||||||
|
self.assertTrue(dense.op in sp.consumers())
|
||||||
|
self.assertTrue(out.op in sp.consumers())
|
||||||
|
|
||||||
|
|
||||||
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
|
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@ -35,6 +35,8 @@ from tensorflow.python.keras import testing_utils
|
|||||||
from tensorflow.python.keras.engine.training_utils import weighted_masked_objective
|
from tensorflow.python.keras.engine.training_utils import weighted_masked_objective
|
||||||
from tensorflow.python.keras.utils.generic_utils import slice_arrays
|
from tensorflow.python.keras.utils.generic_utils import slice_arrays
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import sparse_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||||
@ -386,6 +388,18 @@ class TrainingTest(test.TestCase):
|
|||||||
epochs=1, batch_size=2, validation_split=0.5)
|
epochs=1, batch_size=2, validation_split=0.5)
|
||||||
model.evaluate(test_inputs, test_outputs, batch_size=2)
|
model.evaluate(test_inputs, test_outputs, batch_size=2)
|
||||||
|
|
||||||
|
def test_compile_with_sparse_placeholders(self):
|
||||||
|
with self.test_session():
|
||||||
|
input_layer = keras.layers.Input(shape=(10,), sparse=True)
|
||||||
|
weights = variable_scope.get_variable(name='weights', shape=(10, 1))
|
||||||
|
weights_mult = lambda x: sparse_ops.sparse_tensor_dense_matmul(x, weights)
|
||||||
|
output_layer = keras.layers.Lambda(weights_mult)(input_layer)
|
||||||
|
model = keras.Model([input_layer], output_layer)
|
||||||
|
model.compile(
|
||||||
|
loss='binary_crossentropy',
|
||||||
|
optimizer=keras.optimizers.Adam(lr=0.0001),
|
||||||
|
metrics=['accuracy'])
|
||||||
|
|
||||||
def test_that_trainable_disables_updates(self):
|
def test_that_trainable_disables_updates(self):
|
||||||
val_a = np.random.random((10, 4))
|
val_a = np.random.random((10, 4))
|
||||||
val_out = np.random.random((10, 4))
|
val_out = np.random.random((10, 4))
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import smart_cond as smart_module
|
from tensorflow.python.framework import smart_cond as smart_module
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -109,10 +110,10 @@ def get_reachable_from_inputs(inputs, targets=None):
|
|||||||
if isinstance(x, ops.Operation):
|
if isinstance(x, ops.Operation):
|
||||||
outputs = x.outputs[:] or []
|
outputs = x.outputs[:] or []
|
||||||
outputs += x._control_outputs # pylint: disable=protected-access
|
outputs += x._control_outputs # pylint: disable=protected-access
|
||||||
elif isinstance(x, ops.Tensor):
|
|
||||||
outputs = x.consumers()
|
|
||||||
elif isinstance(x, variables.Variable):
|
elif isinstance(x, variables.Variable):
|
||||||
outputs = [x.op]
|
outputs = [x.op]
|
||||||
|
elif tensor_util.is_tensor(x):
|
||||||
|
outputs = x.consumers()
|
||||||
else:
|
else:
|
||||||
raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))
|
raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))
|
||||||
|
|
||||||
|
@ -23,6 +23,10 @@ tf_class {
|
|||||||
name: "op"
|
name: "op"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "shape"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "values"
|
name: "values"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
@ -31,6 +35,10 @@ tf_class {
|
|||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "consumers"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "eval"
|
name: "eval"
|
||||||
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
@ -23,6 +23,10 @@ tf_class {
|
|||||||
name: "op"
|
name: "op"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "shape"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "values"
|
name: "values"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
@ -31,6 +35,10 @@ tf_class {
|
|||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "consumers"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "eval"
|
name: "eval"
|
||||||
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user