From aabfe1d03ec518269f1b28793f70e4a95eb52574 Mon Sep 17 00:00:00 2001
From: Geoffrey Irving
Date: Thu, 30 Jun 2016 15:49:37 -0800
Subject: [PATCH] Fix sparse_softmax_cross_entropy_with_logits for empty tensor

If the batch size is zero, we need to avoid calling into Eigen because
Eigen will explode. Zero classes is an error.

Change: 126359444
---
 tensorflow/core/kernels/sparse_xent_op.cc | 50 ++++++++++---------
 .../kernel_tests/sparse_xent_op_test.py   | 10 ++++
 2 files changed, 37 insertions(+), 23 deletions(-)

diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 34411c9bbb6..48124d20af9 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -35,38 +35,42 @@ class SparseSoftmaxXentWithLogitsOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    const Tensor& logits_in = context->input(0);
-    const Tensor& labels_in = context->input(1);
-    OP_REQUIRES(context, logits_in.shape().dim_size(0) == labels_in.NumElements(),
+    const Tensor& logits = context->input(0);
+    const Tensor& labels = context->input(1);
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
+                errors::InvalidArgument("logits must be 2-D, but got shape ",
+                                        logits.shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
+                errors::InvalidArgument("labels must be 1-D, but got shape ",
+                                        labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
                 errors::InvalidArgument(
-                    "logits first dimension must match labels size. logits shape=",
-                    logits_in.shape().DebugString(), " labels shape=",
-                    labels_in.shape().DebugString()));
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
-                errors::InvalidArgument("logits must be 2-dimensional"));
-    // As we already tested that both inputs have the same shape no need to
-    // check that "labels" is a matrix too.
-
-    // loss is 1-D (one per example), and size is batch_size.
+                    "logits and labels must have the same first dimension, "
+                    "got logits shape ",
+                    logits.shape().DebugString(), " and labels shape ",
+                    labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(1) > 0,
+                errors::InvalidArgument(
+                    "Must have at least one class, but got logits shape ",
+                    logits.shape().DebugString()));
 
     Tensor scratch;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataTypeToEnum<T>::value,
-                                        TensorShape({logits_in.dim_size(0)}),
-                                        &scratch));
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                   labels.shape(), &scratch));
 
     Tensor* loss_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(
-                       0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+                   context->allocate_output(0, labels.shape(), &loss_out));
     Tensor* back_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(1, logits_in.shape(), &back_out));
+                   context->allocate_output(1, logits.shape(), &back_out));
 
-    functor::SparseXentFunctor<Device, T, Index> functor;
-    functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
-            labels_in.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
-            back_out->matrix<T>());
+    if (logits.dim_size(0) > 0) {
+      functor::SparseXentFunctor<Device, T, Index> functor;
+      functor(context->eigen_device<Device>(), logits.matrix<T>(),
+              labels.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
+              back_out->matrix<T>());
+    }
   }
 };
 
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index eb6bdff8b5a..ea379fbac01 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -120,6 +120,13 @@ class SparseXentTest(tf.test.TestCase):
       tf.nn.sparse_softmax_cross_entropy_with_logits(
           tf.constant(1.0), tf.constant(0))
 
+  def testLabelsPlaceholderScalar(self):
+    with self.test_session():
+      labels = tf.placeholder(np.int32)
+      y = tf.nn.sparse_softmax_cross_entropy_with_logits([[7.]], labels)
+      with self.assertRaisesOpError("labels must be 1-D"):
+        y.eval(feed_dict={labels: 0})
+
   def testVector(self):
     with self.test_session():
       loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -145,6 +152,9 @@ class SparseXentTest(tf.test.TestCase):
           np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16),
           np.array([3, 0]).astype(label_dtype))
 
+  def testEmpty(self):
+    self._testXent(np.zeros((0, 3)), np.zeros((0,), dtype=np.int32))
+
   def testGradient(self):
     with self.test_session():
       l = tf.constant([3, 0, 1], name="l")