Merge pull request #41916 from WindQAQ:tf2xla-euclidean-norm

PiperOrigin-RevId: 326304916
Change-Id: I1dcb9303e78c0f198de82cb1f0db9734469c7565
Author: TensorFlower Gardener
Date:   2020-08-12 14:03:21 -07:00
Commit: d487959ffc

5 changed files with 79 additions and 4 deletions

tensorflow/compiler/jit/mark_for_compilation_pass.cc

@@ -1892,6 +1892,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
       "Einsum",
       "EmptyTensorList",
       "EnsureShape",
+      "EuclideanNorm",
       "ExtractImagePatches",
       "Igamma",
       "IgammaGradA",

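Adding "EuclideanNorm" to GetKnownXLAAllowlistOp() lets the auto-clustering pass treat the op as safe to place in XLA clusters now that tf2xla has a kernel for it. A minimal sketch of exercising the op under XLA compilation, assuming a TF 2.x build of this era where experimental_compile is still the spelling of today's jit_compile flag:

import tensorflow as tf

@tf.function(experimental_compile=True)  # compile the body with XLA
def norm_rows(x):
  # Lowers to the EuclideanNorm op that the new tf2xla kernel handles.
  return tf.math.reduce_euclidean_norm(x, axis=1)

print(norm_rows(tf.constant([[3.0, 4.0], [6.0, 8.0]])))  # [ 5. 10.]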
tensorflow/compiler/tests/reduce_ops_test.py

@@ -25,6 +25,7 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.ops import array_ops
@@ -50,7 +51,8 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
       with self.test_scope():
         a = array_ops.placeholder(dtype)
         index = array_ops.placeholder(index_dtype)
-        out = tf_reduce_fn(a, index)
+        out = def_function.function(experimental_compile=True)(tf_reduce_fn)(
+            a, index)
       result = sess.run(out, {a: test_input, index: [0]})
       self.assertAllClose(
           result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol)
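The test now wraps tf_reduce_fn in def_function.function(experimental_compile=True) so the reduction actually runs through the XLA kernel rather than the regular TF one. The same double-application pattern, sketched standalone with hypothetical inputs:

from tensorflow.python.eager import def_function
from tensorflow.python.ops import math_ops

# def_function.function(...) returns a decorator; applying it to the
# reduction fn yields a compiled callable that is then invoked as usual.
compiled_norm = def_function.function(experimental_compile=True)(
    math_ops.reduce_euclidean_norm)
print(compiled_norm([3.0, 4.0], 0))  # 5.0, computed by the XLA kernel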
@@ -179,6 +181,24 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
                                   'Axes contains duplicate dimension'):
         sess.run(out, {a: [10, 20, 30], index: [0, 0]})
 
+  def testReduceEuclideanNorm(self, index_dtype):
+
+    def reference_euclidean_norm(dtype, inp, axis):
+      inp = inp.astype(dtype)
+      return np.sqrt(np.sum(inp * np.conj(inp), axis)).astype(dtype)
+
+    for real_dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+      self._testReduction(
+          math_ops.reduce_euclidean_norm,
+          functools.partial(reference_euclidean_norm, real_dtype), real_dtype,
+          self.REAL_DATA, index_dtype)
+
+    for complex_dtype in [np.complex64]:
+      self._testReduction(
+          math_ops.reduce_euclidean_norm,
+          functools.partial(reference_euclidean_norm, complex_dtype),
+          complex_dtype, self.COMPLEX_DATA, index_dtype)
+
 
 class ReduceOpPrecisionTest(xla_test.XLATestCase):
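The reference implementation relies on the identity x * conj(x) == |x|**2, which makes one formula valid for real and complex inputs alike. A quick numpy check of that identity on a hypothetical complex vector:

import numpy as np

x = np.array([3 + 4j, 0j], dtype=np.complex64)
# Elementwise x * conj(x) yields squared magnitudes, so the reduction is
# sqrt(25 + 0) == 5, matching numpy's own L2 norm.
assert np.isclose(np.sqrt(np.sum(x * np.conj(x))), 5.0)
assert np.isclose(np.linalg.norm(x), 5.0)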

tensorflow/compiler/tf2xla/kernels/reduction_ops.cc

@@ -16,12 +16,15 @@ limitations under the License.
 
 // XLA-specific reduction Ops.
 
 #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
+
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 
 namespace tensorflow {
@@ -184,5 +187,44 @@ class AnyOp : public XlaReductionOp {
 REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"),
                 AnyOp);
 
+class EuclideanNormOp : public XlaReductionOp {
+ public:
+  explicit EuclideanNormOp(OpKernelConstruction* ctx)
+      : XlaReductionOp(ctx,
+                       XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
+
+  xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
+    return xla::Zero(builder, xla_reduction_type_);
+  }
+
+  xla::XlaOp PreprocessInput(xla::XlaBuilder* /*builder*/,
+                             const xla::XlaOp& data) override {
+    return xla::Mul(data, MaybeConjugate(data, true));
+  }
+
+  void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
+                    const xla::XlaOp& scalar_rhs) override {
+    xla::Add(scalar_lhs, scalar_rhs);
+  }
+
+  xla::XlaOp BuildFinalizer(
+      xla::XlaBuilder* /*builder*/, const xla::XlaOp& input,
+      const xla::XlaOp& reduce_output,
+      const std::vector<int64>& dimensions_to_reduce) override {
+    if (xla::primitive_util::IsIntegralType(xla_reduction_type_)) {
+      // XLA only supports float and complex sqrt.
+      // Thus, cast integral type to F32 for computation.
+      return XlaHelpers::ConvertElementType(
+          xla::Sqrt(xla::ConvertElementType(reduce_output, xla::F32)),
+          input_type(0));
+    }
+    return XlaHelpers::ConvertElementType(xla::Sqrt(reduce_output),
+                                          input_type(0));
+  }
+};
+
+REGISTER_XLA_OP(
+    Name("EuclideanNorm").CompileTimeConstantInput("reduction_indices"),
+    EuclideanNormOp);
 
 }  // namespace
 }  // namespace tensorflow
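EuclideanNormOp splits the norm across the framework's three hooks: PreprocessInput multiplies each element by its conjugate, BuildReducer sums the squares from an initial zero, and BuildFinalizer takes the square root, detouring integral accumulators through F32 because xla::Sqrt is only defined for floating-point and complex types. The same three-stage decomposition in numpy terms (an illustration of the math, not the actual XLA lowering):

import numpy as np

def euclidean_norm(x, axis):
  squared = x * np.conj(x)        # PreprocessInput
  summed = np.sum(squared, axis)  # Reduce with Add from an initial 0
  if np.issubdtype(x.dtype, np.integer):
    # Mirror the kernel: sqrt on a float32 copy, then cast back.
    return np.sqrt(summed.astype(np.float32)).astype(x.dtype)
  return np.sqrt(summed).astype(x.dtype)  # BuildFinalizer

print(euclidean_norm(np.array([[3, 4]], dtype=np.int32), axis=1))  # [5]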

tensorflow/compiler/tf2xla/kernels/reduction_ops.h

@@ -39,6 +39,10 @@ class XlaReductionOp : public XlaOpKernel {
   // Return the base case for the reduction.
   virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
 
+  // Preprocesses input before reduction.
+  virtual xla::XlaOp PreprocessInput(xla::XlaBuilder* builder,
+                                     const xla::XlaOp& data);
+
   // Implement the (scalar,scalar)->scalar lambda that should be
   // applied to each pair of elements to be reduced. The desired
   // computation should be added to 'builder' and

tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc

@@ -35,6 +35,12 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
       ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
 }
 
+// The default pre-processor directly returns the data. This can be overridden.
+xla::XlaOp XlaReductionOp::PreprocessInput(xla::XlaBuilder* /*builder*/,
+                                           const xla::XlaOp& data) {
+  return data;
+}
+
 // The default finalizer converts the results back into the input type. This can
 // be overridden.
 xla::XlaOp XlaReductionOp::BuildFinalizer(
@@ -111,7 +117,8 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
   xla::PrimitiveType type;
   TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
-  auto data = xla::ConvertElementType(ctx->Input(0), type);
+  auto converted_input = xla::ConvertElementType(ctx->Input(0), type);
+  auto processed_input = PreprocessInput(b, converted_input);
 
   // Call virtual method to get the initial value.
   auto initial = xla::ConvertElementType(InitialValue(b), type);
@@ -121,8 +128,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
   BuildReducer(&r, rx, ry);
   xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
 
-  auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes);
-  auto finalized = BuildFinalizer(b, data, reduce, xla_axes);
+  auto reduce =
+      xla::Reduce(processed_input, initial, reduction_computation, xla_axes);
+  auto finalized = BuildFinalizer(b, converted_input, reduce, xla_axes);
   auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized;
   ctx->SetOutput(0, result);
 }
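Note the asymmetry Compile() now carries: xla::Reduce consumes processed_input, while BuildFinalizer still receives converted_input, so a finalizer sees the original (unsquared) values and the input element type. A compact Python analogue of the hook ordering, with invented names standing in for the C++ virtuals:

import numpy as np

class ReductionOp:
  # Mirrors XlaReductionOp::Compile: convert, preprocess, reduce, finalize.
  def preprocess_input(self, data):        # default: identity
    return data
  def finalize(self, converted, reduced):  # default: pass through
    return reduced
  def run(self, x, axis):
    converted = np.asarray(x, dtype=np.float32)   # ConvertElementType
    processed = self.preprocess_input(converted)  # PreprocessInput hook
    reduced = np.sum(processed, axis)             # stand-in for xla::Reduce
    return self.finalize(converted, reduced)      # gets converted, not processed

class EuclideanNorm(ReductionOp):
  def preprocess_input(self, data):
    return data * np.conj(data)
  def finalize(self, converted, reduced):
    return np.sqrt(reduced)

print(EuclideanNorm().run([3.0, 4.0], axis=0))  # 5.0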