Fix sparse_softmax_cross_entropy_with_logits for empty tensor
If the batch size is zero, we need to avoid calling into Eigen because Eigen will explode. Zero classes is an error.
Change: 126359444
parent 69abcd7ec2
commit aabfe1d03e
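Before the diff, a minimal sketch of the behavior this change establishes, written against the TF 0.x-era positional API that the tests below use; the session setup and constants are illustrative assumptions, not part of the commit:

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
  # batch_size == 0 is now valid: the kernel allocates empty outputs and
  # skips the Eigen functor entirely, returning a loss of shape (0,).
  logits = tf.constant(np.zeros((0, 3), dtype=np.float32))
  labels = tf.constant(np.zeros((0,), dtype=np.int32))
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
  print(sess.run(loss).shape)  # (0,)

  # Zero classes (second logits dimension == 0) is rejected by the new
  # OP_REQUIRES check rather than being passed through to Eigen.
  bad = tf.nn.sparse_softmax_cross_entropy_with_logits(
      tf.constant(np.zeros((2, 0), dtype=np.float32)),
      tf.constant(np.zeros((2,), dtype=np.int32)))
  try:
    sess.run(bad)
  except tf.errors.InvalidArgumentError as e:
    print(e)  # "Must have at least one class, but got logits shape [2,0]"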
@@ -35,38 +35,42 @@ class SparseSoftmaxXentWithLogitsOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    const Tensor& logits_in = context->input(0);
-    const Tensor& labels_in = context->input(1);
-    OP_REQUIRES(context,
-                logits_in.shape().dim_size(0) == labels_in.NumElements(),
-                errors::InvalidArgument(
-                    "logits first dimension must match labels size. logits shape=",
-                    logits_in.shape().DebugString(), " labels shape=",
-                    labels_in.shape().DebugString()));
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
-                errors::InvalidArgument("logits must be 2-dimensional"));
-    // As we already tested that both inputs have the same shape no need to
-    // check that "labels" is a matrix too.
-
-    // loss is 1-D (one per example), and size is batch_size.
+    const Tensor& logits = context->input(0);
+    const Tensor& labels = context->input(1);
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
+                errors::InvalidArgument("logits must be 2-D, but got shape ",
+                                        logits.shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
+                errors::InvalidArgument("labels must be 1-D, but got shape ",
+                                        labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
+                errors::InvalidArgument(
+                    "logits and labels must have the same first dimension, "
+                    "got logits shape ",
+                    logits.shape().DebugString(), " and labels shape ",
+                    labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(1) > 0,
+                errors::InvalidArgument(
+                    "Must have at least one class, but got logits shape ",
+                    logits.shape().DebugString()));
+
     Tensor scratch;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataTypeToEnum<T>::value,
-                                        TensorShape({logits_in.dim_size(0)}),
-                                        &scratch));
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                   labels.shape(), &scratch));
 
     Tensor* loss_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(
-                       0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+                   context->allocate_output(0, labels.shape(), &loss_out));
     Tensor* back_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(1, logits_in.shape(), &back_out));
+                   context->allocate_output(1, logits.shape(), &back_out));
 
-    functor::SparseXentFunctor<Device, T, Index> functor;
-    functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
-            labels_in.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
-            back_out->matrix<T>());
+    if (logits.dim_size(0) > 0) {
+      functor::SparseXentFunctor<Device, T, Index> functor;
+      functor(context->eigen_device<Device>(), logits.matrix<T>(),
+              labels.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
+              back_out->matrix<T>());
+    }
   }
 };
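The reworked checks also give a clearer message when the batch dimensions disagree. A hedged sketch of how that surfaces from Python, modeled on the testLabelsPlaceholderScalar pattern added below (the placeholder defers shape validation to the kernel; the feed values are illustrative):

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
  labels = tf.placeholder(np.int32)  # shape unknown until feed time
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      [[1., 2.], [3., 4.]], labels)  # logits batch of 2
  try:
    # Feeding a single label against two logits rows trips the new
    # first-dimension check in the kernel.
    sess.run(loss, feed_dict={labels: np.array([0], dtype=np.int32)})
  except tf.errors.InvalidArgumentError as e:
    print(e)  # "logits and labels must have the same first dimension, ..."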
@@ -120,6 +120,13 @@ class SparseXentTest(tf.test.TestCase):
       tf.nn.sparse_softmax_cross_entropy_with_logits(
           tf.constant(1.0), tf.constant(0))
 
+  def testLabelsPlaceholderScalar(self):
+    with self.test_session():
+      labels = tf.placeholder(np.int32)
+      y = tf.nn.sparse_softmax_cross_entropy_with_logits([[7.]], labels)
+      with self.assertRaisesOpError("labels must be 1-D"):
+        y.eval(feed_dict={labels: 0})
+
   def testVector(self):
     with self.test_session():
       loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -145,6 +152,9 @@ class SparseXentTest(tf.test.TestCase):
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16),
         np.array([3, 0]).astype(label_dtype))
 
+  def testEmpty(self):
+    self._testXent(np.zeros((0, 3)), np.zeros((0,), dtype=np.int32))
+
   def testGradient(self):
     with self.test_session():
       l = tf.constant([3, 0, 1], name="l")