diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
index 8755be24c57..399726f82ef 100644
--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -213,7 +213,8 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
   for tensor in tensor_list:
     if getattr(tensor, '_keras_history', None) is not None:
       continue
-    if sparse_tensor.is_sparse(tensor):
+    if isinstance(
+        tensor, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
       sparse_ops.append(tensor.op)
       continue
     if tf_utils.is_ragged(tensor):
diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py
index cbdf7b53e10..09fca11bd59 100644
--- a/tensorflow/python/keras/engine/base_preprocessing_layer.py
+++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py
@@ -265,8 +265,6 @@ def convert_to_list(values, sparse_default_value=None):
       values = K.get_session(values).run(values)
     values = values.to_list()
 
-  # TODO(momernick): Add a sparse_tensor.is_sparse() method to replace this
-  # check.
   if isinstance(values,
                 (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
     if sparse_default_value is None: