diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index 8755be24c57..399726f82ef 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -213,7 +213,8 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers): for tensor in tensor_list: if getattr(tensor, '_keras_history', None) is not None: continue - if sparse_tensor.is_sparse(tensor): + if isinstance( + tensor, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): sparse_ops.append(tensor.op) continue if tf_utils.is_ragged(tensor): diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py index cbdf7b53e10..09fca11bd59 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py @@ -265,8 +265,6 @@ def convert_to_list(values, sparse_default_value=None): values = K.get_session(values).run(values) values = values.to_list() - # TODO(momernick): Add a sparse_tensor.is_sparse() method to replace this - # check. if isinstance(values, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): if sparse_default_value is None: