From 4e9d0b23aad82817527167cdfd8613567ca64d9c Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Thu, 30 Jul 2020 17:27:50 -0700 Subject: [PATCH] Add EuclideanNorm kernel Update allowlist Update allowlist Change ProcessData to ProcessInput Change ProcessInput to PreprocessInput Pass input to finalizer Change op name --- .../compiler/jit/mark_for_compilation_pass.cc | 1 + tensorflow/compiler/tests/reduce_ops_test.py | 24 ++++++++++- .../compiler/tf2xla/kernels/reduction_ops.cc | 42 +++++++++++++++++++ .../compiler/tf2xla/kernels/reduction_ops.h | 4 ++ .../tf2xla/kernels/reduction_ops_common.cc | 14 +++++-- 5 files changed, 81 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 19eb61b6f72..43619eca1fa 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1892,6 +1892,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "Einsum", "EmptyTensorList", "EnsureShape", + "EuclideanNorm", "ExtractImagePatches", "Igamma", "IgammaGradA", diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index eb46c536e07..5199cb972ee 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -25,6 +25,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -50,7 +51,8 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): with self.test_scope(): a = array_ops.placeholder(dtype) index = array_ops.placeholder(index_dtype) - out = tf_reduce_fn(a, index) + out = def_function.function(experimental_compile=True)( + tf_reduce_fn)(a, index) result = sess.run(out, {a: test_input, index: [0]}) self.assertAllClose( result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol) @@ -179,6 +181,26 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): 'Axes contains duplicate dimension'): sess.run(out, {a: [10, 20, 30], index: [0, 0]}) + def testReduceEuclideanNorm(self, index_dtype): + def reference_euclidean_norm(dtype, inp, axis): + inp = inp.astype(dtype) + return np.sqrt(np.sum(inp * np.conj(inp), axis)).astype(dtype) + + for real_dtype in [np.int32, np.int64, np.float16, + np.float32, np.float64]: + self._testReduction(math_ops.reduce_euclidean_norm, + functools.partial( + reference_euclidean_norm, real_dtype), + real_dtype, + self.REAL_DATA, index_dtype) + + for complex_dtype in [np.complex64, np.complex128]: + self._testReduction(math_ops.reduce_euclidean_norm, + functools.partial( + reference_euclidean_norm, complex_dtype), + complex_dtype, + self.COMPLEX_DATA, index_dtype) + class ReduceOpPrecisionTest(xla_test.XLATestCase): diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 4f63c0d1b66..f95d58fd96a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -16,12 +16,15 @@ limitations under the License. // XLA-specific reduction Ops. #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -184,5 +187,44 @@ class AnyOp : public XlaReductionOp { REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"), AnyOp); +class EuclideanNormOp : public XlaReductionOp { + public: + explicit EuclideanNormOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return xla::Zero(builder, xla_reduction_type_); + } + + xla::XlaOp PreprocessInput(xla::XlaBuilder* /*builder*/, + const xla::XlaOp& data) override { + return xla::Mul(data, MaybeConjugate(data, true)); + } + + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { + xla::Add(scalar_lhs, scalar_rhs); + } + + xla::XlaOp BuildFinalizer( + xla::XlaBuilder* /*builder*/, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce) override { + if (xla::primitive_util::IsIntegralType(xla_reduction_type_)) { + // XLA only supports float and complex sqrt. + // Thus, cast integral type to F32 for computation. + return XlaHelpers::ConvertElementType( + xla::Sqrt(xla::ConvertElementType(reduce_output, xla::F32)), + input_type(0)); + } + return XlaHelpers::ConvertElementType(xla::Sqrt(reduce_output), + input_type(0)); + } +}; + +REGISTER_XLA_OP( + Name("EuclideanNorm").CompileTimeConstantInput("reduction_indices"), + EuclideanNormOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index af716eab798..2091b496ddb 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -39,6 +39,10 @@ class XlaReductionOp : public XlaOpKernel { // Return the base case for the reduction. virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; + // Preprocesses input before reduction. + virtual xla::XlaOp PreprocessInput(xla::XlaBuilder* builder, + const xla::XlaOp& data); + // Implement the (scalar,scalar)->scalar lambda that should be // applied to each pair of elements to be reduced. The desired // computation should be added to 'builder' and diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index b4284a5498c..58d53dfea58 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -35,6 +35,12 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } +// The default pre-processor directly returns the data. This can be overridden. +xla::XlaOp XlaReductionOp::PreprocessInput(xla::XlaBuilder* /*builder*/, + const xla::XlaOp& data) { + return data; +} + // The default finalizer converts the results back into the input type. This can // be overridden. xla::XlaOp XlaReductionOp::BuildFinalizer( @@ -111,7 +117,8 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); - auto data = xla::ConvertElementType(ctx->Input(0), type); + auto converted_input = xla::ConvertElementType(ctx->Input(0), type); + auto processed_input = PreprocessInput(b, converted_input); // Call virtual method to get the initial value. auto initial = xla::ConvertElementType(InitialValue(b), type); // Make two scalar parameters of the desired type for the lambda. @@ -121,8 +128,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { BuildReducer(&r, rx, ry); xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); - auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); - auto finalized = BuildFinalizer(b, data, reduce, xla_axes); + auto reduce = + xla::Reduce(processed_input, initial, reduction_computation, xla_axes); + auto finalized = BuildFinalizer(b, converted_input, reduce, xla_axes); auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); }