diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 34411c9bbb6..48124d20af9 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -35,38 +35,46 @@ class SparseSoftmaxXentWithLogitsOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    const Tensor& logits_in = context->input(0);
-    const Tensor& labels_in = context->input(1);
-    OP_REQUIRES(context, logits_in.shape().dim_size(0) == labels_in.NumElements(),
+    const Tensor& logits = context->input(0);
+    const Tensor& labels = context->input(1);
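+    // Validate shapes: logits must be a [batch_size, num_classes] matrix and
+    // labels a [batch_size] vector, with at least one class in logits.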
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
+                errors::InvalidArgument("logits must be 2-D, but got shape ",
+                                        logits.shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
+                errors::InvalidArgument("labels must be 1-D, but got shape ",
+                                        labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
                 errors::InvalidArgument(
-                    "logits first dimension must match labels size.  logits shape=",
-                    logits_in.shape().DebugString(), " labels shape=",
-                    labels_in.shape().DebugString()));
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
-                errors::InvalidArgument("logits must be 2-dimensional"));
-    // As we already tested that both inputs have the same shape no need to
-    // check that "labels" is a matrix too.
-
-    // loss is 1-D (one per example), and size is batch_size.
+                    "logits and labels must have the same first dimension, "
+                    "got logits shape ",
+                    logits.shape().DebugString(), " and labels shape ",
+                    labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(1) > 0,
+                errors::InvalidArgument(
+                    "Must have at least one class, but got logits shape ",
+                    logits.shape().DebugString()));
 
     Tensor scratch;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataTypeToEnum<T>::value,
-                                        TensorShape({logits_in.dim_size(0)}),
-                                        &scratch));
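+    // Scratch and loss are 1-D with one entry per example, so reuse labels' shape.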
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                   labels.shape(), &scratch));
 
     Tensor* loss_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(
-                       0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+                   context->allocate_output(0, labels.shape(), &loss_out));
     Tensor* back_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(1, logits_in.shape(), &back_out));
+                   context->allocate_output(1, logits.shape(), &back_out));
 
-    functor::SparseXentFunctor<Device, T, Index> functor;
-    functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
-            labels_in.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
-            back_out->matrix<T>());
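+    // An empty batch needs no work; the zero-sized outputs are already valid.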
+    if (logits.dim_size(0) > 0) {
+      functor::SparseXentFunctor<Device, T, Index> functor;
+      functor(context->eigen_device<Device>(), logits.matrix<T>(),
+              labels.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
+              back_out->matrix<T>());
+    }
   }
 };
 
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index eb6bdff8b5a..ea379fbac01 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -120,6 +120,14 @@ class SparseXentTest(tf.test.TestCase):
         tf.nn.sparse_softmax_cross_entropy_with_logits(
             tf.constant(1.0), tf.constant(0))
 
+  def testLabelsPlaceholderScalar(self):
+    with self.test_session():
+      labels = tf.placeholder(np.int32)
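+      # Shape is unknown at graph construction, so the 1-D check fires when the op runs.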
+      y = tf.nn.sparse_softmax_cross_entropy_with_logits([[7.]], labels)
+      with self.assertRaisesOpError("labels must be 1-D"):
+        y.eval(feed_dict={labels: 0})
+
   def testVector(self):
     with self.test_session():
       loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -145,6 +153,10 @@ class SparseXentTest(tf.test.TestCase):
           np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16),
           np.array([3, 0]).astype(label_dtype))
 
+  def testEmpty(self):
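+    # A batch of zero examples should produce empty outputs without error.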
+    self._testXent(np.zeros((0, 3)), np.zeros((0,), dtype=np.int32))
+
   def testGradient(self):
     with self.test_session():
       l = tf.constant([3, 0, 1], name="l")