From 8e01ae829bb88ff197c2e6b8c3ad1668fd2b9fa5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 12 Aug 2020 16:56:05 -0700
Subject: [PATCH] PR #41916: [TF2XLA] Add EuclideanNorm kernel

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/41916

PiperOrigin-RevId: 326343156
Change-Id: I9810a3301570bf5a25e97b3004fe0043f8ee01db
---
 .../compiler/jit/mark_for_compilation_pass.cc |  1 -
 tensorflow/compiler/tests/reduce_ops_test.py  | 22 +---------
 .../compiler/tf2xla/kernels/reduction_ops.cc  | 42 -------------------
 .../compiler/tf2xla/kernels/reduction_ops.h   |  4 --
 .../tf2xla/kernels/reduction_ops_common.cc    | 14 ++-----
 5 files changed, 4 insertions(+), 79 deletions(-)

diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 43619eca1fa..19eb61b6f72 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -1892,7 +1892,6 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
                                      "Einsum",
                                      "EmptyTensorList",
                                      "EnsureShape",
-                                     "EuclideanNorm",
                                      "ExtractImagePatches",
                                      "Igamma",
                                      "IgammaGradA",
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index a6844375c61..eb46c536e07 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -25,7 +25,6 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
-from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.ops import array_ops
@@ -51,8 +50,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
         with self.test_scope():
           a = array_ops.placeholder(dtype)
           index = array_ops.placeholder(index_dtype)
-          out = def_function.function(experimental_compile=True)(tf_reduce_fn)(
-              a, index)
+          out = tf_reduce_fn(a, index)
         result = sess.run(out, {a: test_input, index: [0]})
         self.assertAllClose(
             result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol)
@@ -181,24 +179,6 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
           'Axes contains duplicate dimension'):
         sess.run(out, {a: [10, 20, 30], index: [0, 0]})
 
-  def testReduceEuclideanNorm(self, index_dtype):
-
-    def reference_euclidean_norm(dtype, inp, axis):
-      inp = inp.astype(dtype)
-      return np.sqrt(np.sum(inp * np.conj(inp), axis)).astype(dtype)
-
-    for real_dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
-      self._testReduction(
-          math_ops.reduce_euclidean_norm,
-          functools.partial(reference_euclidean_norm, real_dtype), real_dtype,
-          self.REAL_DATA, index_dtype)
-
-    for complex_dtype in [np.complex64]:
-      self._testReduction(
-          math_ops.reduce_euclidean_norm,
-          functools.partial(reference_euclidean_norm, complex_dtype),
-          complex_dtype, self.COMPLEX_DATA, index_dtype)
-
 
 class ReduceOpPrecisionTest(xla_test.XLATestCase):
 
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index f95d58fd96a..4f63c0d1b66 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -16,15 +16,12 @@ limitations under the License.
 // XLA-specific reduction Ops.
 
 #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
-
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 
 namespace tensorflow {
@@ -187,44 +184,5 @@ class AnyOp : public XlaReductionOp {
 REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"),
                 AnyOp);
 
-class EuclideanNormOp : public XlaReductionOp {
- public:
-  explicit EuclideanNormOp(OpKernelConstruction* ctx)
-      : XlaReductionOp(ctx,
-                       XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
-  xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
-    return xla::Zero(builder, xla_reduction_type_);
-  }
-
-  xla::XlaOp PreprocessInput(xla::XlaBuilder* /*builder*/,
-                             const xla::XlaOp& data) override {
-    return xla::Mul(data, MaybeConjugate(data, true));
-  }
-
-  void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
-                    const xla::XlaOp& scalar_rhs) override {
-    xla::Add(scalar_lhs, scalar_rhs);
-  }
-
-  xla::XlaOp BuildFinalizer(
-      xla::XlaBuilder* /*builder*/, const xla::XlaOp& input,
-      const xla::XlaOp& reduce_output,
-      const std::vector<int64>& dimensions_to_reduce) override {
-    if (xla::primitive_util::IsIntegralType(xla_reduction_type_)) {
-      // XLA only supports float and complex sqrt.
-      // Thus, cast integral type to F32 for computation.
-      return XlaHelpers::ConvertElementType(
-          xla::Sqrt(xla::ConvertElementType(reduce_output, xla::F32)),
-          input_type(0));
-    }
-    return XlaHelpers::ConvertElementType(xla::Sqrt(reduce_output),
-                                          input_type(0));
-  }
-};
-
-REGISTER_XLA_OP(
-    Name("EuclideanNorm").CompileTimeConstantInput("reduction_indices"),
-    EuclideanNormOp);
-
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
index 2091b496ddb..af716eab798 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
@@ -39,10 +39,6 @@ class XlaReductionOp : public XlaOpKernel {
   // Return the base case for the reduction.
   virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
 
-  // Preprocesses input before reduction.
-  virtual xla::XlaOp PreprocessInput(xla::XlaBuilder* builder,
-                                     const xla::XlaOp& data);
-
   // Implement the (scalar,scalar)->scalar lambda that should be
   // applied to each pair of elements to be reduced. The desired
   // computation should be added to 'builder' and
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 58d53dfea58..b4284a5498c 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -35,12 +35,6 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
       ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
 }
 
-// The default pre-processor directly returns the data. This can be overridden.
-xla::XlaOp XlaReductionOp::PreprocessInput(xla::XlaBuilder* /*builder*/,
-                                           const xla::XlaOp& data) {
-  return data;
-}
-
 // The default finalizer converts the results back into the input type. This can
 // be overridden.
 xla::XlaOp XlaReductionOp::BuildFinalizer(
@@ -117,8 +111,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
   xla::PrimitiveType type;
   TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
 
-  auto converted_input = xla::ConvertElementType(ctx->Input(0), type);
-  auto processed_input = PreprocessInput(b, converted_input);
+  auto data = xla::ConvertElementType(ctx->Input(0), type);
   // Call virtual method to get the initial value.
   auto initial = xla::ConvertElementType(InitialValue(b), type);
   // Make two scalar parameters of the desired type for the lambda.
@@ -128,9 +121,8 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
   BuildReducer(&r, rx, ry);
   xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
 
-  auto reduce =
-      xla::Reduce(processed_input, initial, reduction_computation, xla_axes);
-  auto finalized = BuildFinalizer(b, converted_input, reduce, xla_axes);
+  auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes);
+  auto finalized = BuildFinalizer(b, data, reduce, xla_axes);
   auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized;
   ctx->SetOutput(0, result);
 }