diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h index 164359f601a..3c62dcfc081 100644 --- a/tensorflow/core/kernels/reduction_ops.h +++ b/tensorflow/core/kernels/reduction_ops.h @@ -26,6 +26,11 @@ limitations under the License. namespace tensorflow { namespace functor { +template <typename Reducer> +struct ReducerTraits { + enum { IsScalarIdentity = true }; +}; + // Dummy class used for template specialization for mean reduction, which is // accomplished by SumReducer and on-the-fly division by the reduction factor. template <typename Scalar> @@ -39,6 +44,11 @@ struct EuclideanNormReducer { Scalar initialize() const { return Scalar(0); } }; +template <typename Scalar> +struct ReducerTraits<EuclideanNormReducer<Scalar>> { + enum { IsScalarIdentity = false }; +}; + template <typename Device, typename OUT_T, typename IN_T, typename ReductionAxes, typename Reducer> struct ReduceEigenImpl { diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index c6c36ec29a7..072699288db 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -155,12 +155,12 @@ class ReductionOp : public OpKernel { OP_REQUIRES_OK(ctx, helper.Simplify(data, axes, keep_dims_)); CHECK_GE(helper.ndims(), 0); - if (helper.ndims() == 0 || - (helper.ndims() == 1 && !helper.reduce_first_axis())) { - // Special case. Reduces nothing. It is unclear why this is - // necessary, but tests fail without it. Look into why this - // case occurs. + bool is_scalar_identity = functor::ReducerTraits<Reducer>::IsScalarIdentity; + bool is_trivial = helper.ndims() == 0 || + (helper.ndims() == 1 && !helper.reduce_first_axis()); + if (is_scalar_identity && is_trivial) { Tensor out; + // Special case. Reduces nothing and does not alter the input values. if (!out.CopyFrom(data, helper.out_shape())) { ctx->SetStatus(errors::Internal("Error during reduction copy.")); } @@ -172,73 +172,83 @@ class ReductionOp : public OpKernel { // output(0) because it is returned as output(0) in the end. const AllocatorAttributes alloc_attr = ctx->output_alloc_attr(0); - // A temporary tensor whose size matches the size of the reduced - // output. Tensor tmp_out; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(ctx->expected_output_dtype(0), - helper.out_reshape(), &tmp_out, alloc_attr)); - typedef functor::ReduceFunctor<Device, Reducer> Functor; Constants<Device> constants; const Device& d = ctx->eigen_device<Device>(); Reducer reducer; - if (tmp_out.NumElements() == 0) { - // Nothing to do, fall through to final reshaping. - } else if (data.NumElements() == 0) { - // Degenerate reduction where the input is empty but the output is - // nonempty (thus tmp_out.NumElements() > 0), and we must fill the output - // with identity elements. Example: tf.reduce_sum(tf.zeros((0, 3)), [0]). - // Eigen sometimes crashes in this case, so we do it manually. - Functor::FillIdentity(d, tmp_out.flat<T>(), reducer); - } else if ((helper.ndims() == 1) && helper.reduce_first_axis()) { - // Reduce to a scalar. - Functor::Reduce(ctx, helper.out<T, 0>(&tmp_out), helper.in<T, 1>(data), - constants.kZero, reducer); - } else if ((helper.ndims() == 2) && helper.reduce_first_axis()) { - // Can be viewed as a reduction of a matrix along 1st dimension. - Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data), - constants.kZero, reducer); - } else if ((helper.ndims() == 2) && !helper.reduce_first_axis()) { - // Can be viewed as a reduction of a matrix along 2nd dimension. - Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data), - constants.kOne, reducer); - } else if ((helper.ndims() == 3) && helper.reduce_first_axis()) { - // Can be viewed as a reduction of a 3D tensor along 1st and 3rd - // dimensions. - Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 3>(data), - constants.kZeroTwo, reducer); - } else if ((helper.ndims() == 3) && !helper.reduce_first_axis()) { - // Can be viewed as a reduction of a 3D tensor along 2nd dimension. - Functor::Reduce(ctx, helper.out<T, 2>(&tmp_out), helper.in<T, 3>(data), - constants.kOne, reducer); - } else { - // If we don't hit one of the cases above, transpose the data so that - // all reduced dimensions are last and reuse the 2-D -> 1-D case. - Tensor data_reshaped; - CHECK(data_reshaped.CopyFrom(data, helper.data_reshape())); - Tensor shuffled; - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, - helper.shuffled_shape(), &shuffled, - alloc_attr)); - OP_REQUIRES_OK( - ctx, DoTranspose(d, data_reshaped, helper.permutation(), &shuffled)); - const int64 unreduced = tmp_out.NumElements(); - const int64 reduced = shuffled.NumElements() / unreduced; - const Tensor& const_shuffled = shuffled; + if (data.NumElements() > 0 && is_trivial && !is_scalar_identity) { + OP_REQUIRES_OK(ctx, ctx->allocate_temp(ctx->expected_output_dtype(0), + TensorShape({data.NumElements()}), + &tmp_out, alloc_attr)); Functor::Reduce(ctx, tmp_out.flat<T>(), - const_shuffled.shaped<T, 2>({unreduced, reduced}), - constants.kOne, reducer); + data.shaped<T, 2>({1, data.NumElements()}), + constants.kZero, reducer); + } else { + // A temporary tensor whose size matches the size of the reduced + // output. + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(ctx->expected_output_dtype(0), + helper.out_reshape(), &tmp_out, alloc_attr)); + + if (tmp_out.NumElements() == 0) { + // Nothing to do, fall through to final reshaping. + } else if (data.NumElements() == 0) { + // Degenerate reduction where the input is empty but the output is + // nonempty (thus tmp_out.NumElements() > 0), and we must fill the + // output with identity elements. Example: tf.reduce_sum(tf.zeros((0, + // 3)), [0]). Eigen sometimes crashes in this case, so we do it + // manually. + Functor::FillIdentity(d, tmp_out.flat<T>(), reducer); + } else if ((helper.ndims() == 1) && helper.reduce_first_axis()) { + // Reduce to a scalar. + Functor::Reduce(ctx, helper.out<T, 0>(&tmp_out), helper.in<T, 1>(data), + constants.kZero, reducer); + } else if ((helper.ndims() == 2) && helper.reduce_first_axis()) { + // Can be viewed as a reduction of a matrix along 1st dimension. + Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data), + constants.kZero, reducer); + } else if ((helper.ndims() == 2) && !helper.reduce_first_axis()) { + // Can be viewed as a reduction of a matrix along 2nd dimension. + Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data), + constants.kOne, reducer); + } else if ((helper.ndims() == 3) && helper.reduce_first_axis()) { + // Can be viewed as a reduction of a 3D tensor along 1st and 3rd + // dimensions. + Functor::Reduce(ctx, helper.out<T, 1>(&tmp_out), helper.in<T, 3>(data), + constants.kZeroTwo, reducer); + } else if ((helper.ndims() == 3) && !helper.reduce_first_axis()) { + // Can be viewed as a reduction of a 3D tensor along 2nd dimension. + Functor::Reduce(ctx, helper.out<T, 2>(&tmp_out), helper.in<T, 3>(data), + constants.kOne, reducer); + } else { + // If we don't hit one of the cases above, transpose the data so that + // all reduced dimensions are last and reuse the 2-D -> 1-D case. + Tensor data_reshaped; + OP_REQUIRES(ctx, data_reshaped.CopyFrom(data, helper.data_reshape()), + errors::Internal("Error during reduction copy.")); + Tensor shuffled; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, + helper.shuffled_shape(), + &shuffled, alloc_attr)); + OP_REQUIRES_OK(ctx, DoTranspose(d, data_reshaped, helper.permutation(), + &shuffled)); + const int64 unreduced = tmp_out.NumElements(); + const int64 reduced = shuffled.NumElements() / unreduced; + const Tensor& const_shuffled = shuffled; + Functor::Reduce(ctx, tmp_out.flat<T>(), + const_shuffled.shaped<T, 2>({unreduced, reduced}), + constants.kOne, reducer); + } } // Set the real output using the contents of the reduction but the // real expected output shape. The number of elements should // match between the two shapes. Tensor out; - if (!out.CopyFrom(tmp_out, helper.out_shape())) { - ctx->SetStatus(errors::Internal("Error during reduction copy.")); - } + OP_REQUIRES(ctx, out.CopyFrom(tmp_out, helper.out_shape()), + errors::Internal("Error during reduction copy.")); ctx->set_output(0, out); } diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index ca052ab9445..acc66b7c3e6 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -497,11 +497,8 @@ class EuclideanNormReductionTest(BaseReductionTest): if isinstance(reduction_axes, list) or isinstance(reduction_axes, np.ndarray): reduction_axes = tuple(reduction_axes) - if reduction_axes is None or reduction_axes != tuple(): - np_fro = np.sqrt( - np.sum(x * np.conj(x), axis=reduction_axes, keepdims=keepdims)) - else: - np_fro = x + np_fro = np.sqrt( + np.sum(x * np.conj(x), axis=reduction_axes, keepdims=keepdims)) if np.issubdtype(x.dtype, np.integer): np_fro = np.floor(np_fro) return np_fro @@ -522,6 +519,12 @@ class EuclideanNormReductionTest(BaseReductionTest): np_arr = np.array([special_value_x, special_value_y]).astype(dtype) self._compareAll(np_arr, None) + @test_util.run_deprecated_v1 + def testSingleton(self): + for dtype in [np.float32, np.float64]: + np_arr = np.array([-1.]).astype(dtype) + self._compareAll(np_arr, None) + @test_util.run_deprecated_v1 def testInt32(self): for rank in range(1, _MAX_RANK + 1):