Merge pull request #41916 from WindQAQ:tf2xla-euclidean-norm
PiperOrigin-RevId: 326304916 Change-Id: I1dcb9303e78c0f198de82cb1f0db9734469c7565
This commit is contained in:
commit
d487959ffc
@ -1892,6 +1892,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"Einsum",
|
||||
"EmptyTensorList",
|
||||
"EnsureShape",
|
||||
"EuclideanNorm",
|
||||
"ExtractImagePatches",
|
||||
"Igamma",
|
||||
"IgammaGradA",
|
||||
|
@ -25,6 +25,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -50,7 +51,8 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
with self.test_scope():
|
||||
a = array_ops.placeholder(dtype)
|
||||
index = array_ops.placeholder(index_dtype)
|
||||
out = tf_reduce_fn(a, index)
|
||||
out = def_function.function(experimental_compile=True)(tf_reduce_fn)(
|
||||
a, index)
|
||||
result = sess.run(out, {a: test_input, index: [0]})
|
||||
self.assertAllClose(
|
||||
result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol)
|
||||
@ -179,6 +181,24 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
'Axes contains duplicate dimension'):
|
||||
sess.run(out, {a: [10, 20, 30], index: [0, 0]})
|
||||
|
||||
def testReduceEuclideanNorm(self, index_dtype):
  """Compares `math_ops.reduce_euclidean_norm` with a NumPy reference.

  Every real dtype is exercised with `self.REAL_DATA` and every complex dtype
  with `self.COMPLEX_DATA`, in that order, through `self._testReduction`.

  Args:
    index_dtype: dtype forwarded to `_testReduction` for the reduction
      indices.
  """

  def numpy_euclidean_norm(dtype, values, axis):
    # The input is cast to `dtype` first and the result is cast back to it.
    # `axis` keeps its name because the caller passes it as a keyword.
    values = values.astype(dtype)
    squared_magnitudes = values * np.conj(values)
    return np.sqrt(np.sum(squared_magnitudes, axis)).astype(dtype)

  real_dtypes = (np.int32, np.int64, np.float16, np.float32, np.float64)
  complex_dtypes = (np.complex64,)

  # (dtype, test data) pairs, real cases first and complex cases afterwards.
  cases = [(real_dtype, self.REAL_DATA) for real_dtype in real_dtypes]
  cases += [(complex_dtype, self.COMPLEX_DATA)
            for complex_dtype in complex_dtypes]

  for dtype, test_data in cases:
    self._testReduction(
        math_ops.reduce_euclidean_norm,
        functools.partial(numpy_euclidean_norm, dtype), dtype, test_data,
        index_dtype)
|
||||
|
||||
|
||||
class ReduceOpPrecisionTest(xla_test.XLATestCase):
|
||||
|
||||
|
@ -16,12 +16,15 @@ limitations under the License.
|
||||
// XLA-specific reduction Ops.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -184,5 +187,44 @@ class AnyOp : public XlaReductionOp {
|
||||
REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"),
|
||||
AnyOp);
|
||||
|
||||
class EuclideanNormOp : public XlaReductionOp {
|
||||
public:
|
||||
explicit EuclideanNormOp(OpKernelConstruction* ctx)
|
||||
: XlaReductionOp(ctx,
|
||||
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
|
||||
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
|
||||
return xla::Zero(builder, xla_reduction_type_);
|
||||
}
|
||||
|
||||
xla::XlaOp PreprocessInput(xla::XlaBuilder* /*builder*/,
|
||||
const xla::XlaOp& data) override {
|
||||
return xla::Mul(data, MaybeConjugate(data, true));
|
||||
}
|
||||
|
||||
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
|
||||
const xla::XlaOp& scalar_rhs) override {
|
||||
xla::Add(scalar_lhs, scalar_rhs);
|
||||
}
|
||||
|
||||
xla::XlaOp BuildFinalizer(
|
||||
xla::XlaBuilder* /*builder*/, const xla::XlaOp& input,
|
||||
const xla::XlaOp& reduce_output,
|
||||
const std::vector<int64>& dimensions_to_reduce) override {
|
||||
if (xla::primitive_util::IsIntegralType(xla_reduction_type_)) {
|
||||
// XLA only supports float and complex sqrt.
|
||||
// Thus, cast integral type to F32 for computation.
|
||||
return XlaHelpers::ConvertElementType(
|
||||
xla::Sqrt(xla::ConvertElementType(reduce_output, xla::F32)),
|
||||
input_type(0));
|
||||
}
|
||||
return XlaHelpers::ConvertElementType(xla::Sqrt(reduce_output),
|
||||
input_type(0));
|
||||
}
|
||||
};
|
||||
|
||||
// Registers EuclideanNormOp as the XLA kernel for the "EuclideanNorm" op. The
// "reduction_indices" input is declared a compile-time constant.
REGISTER_XLA_OP(
    Name("EuclideanNorm").CompileTimeConstantInput("reduction_indices"),
    EuclideanNormOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -39,6 +39,10 @@ class XlaReductionOp : public XlaOpKernel {
|
||||
// Return the base case for the reduction.
|
||||
virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
|
||||
|
||||
// Preprocesses input before reduction.
|
||||
virtual xla::XlaOp PreprocessInput(xla::XlaBuilder* builder,
|
||||
const xla::XlaOp& data);
|
||||
|
||||
// Implement the (scalar,scalar)->scalar lambda that should be
|
||||
// applied to each pair of elements to be reduced. The desired
|
||||
// computation should be added to 'builder' and
|
||||
|
@ -35,6 +35,12 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
|
||||
ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
|
||||
}
|
||||
|
||||
// The default pre-processor directly returns the data. This can be overridden.
|
||||
xla::XlaOp XlaReductionOp::PreprocessInput(xla::XlaBuilder* /*builder*/,
|
||||
const xla::XlaOp& data) {
|
||||
return data;
|
||||
}
|
||||
|
||||
// The default finalizer converts the results back into the input type. This can
|
||||
// be overridden.
|
||||
xla::XlaOp XlaReductionOp::BuildFinalizer(
|
||||
@ -111,7 +117,8 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
|
||||
xla::PrimitiveType type;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
|
||||
|
||||
auto data = xla::ConvertElementType(ctx->Input(0), type);
|
||||
auto converted_input = xla::ConvertElementType(ctx->Input(0), type);
|
||||
auto processed_input = PreprocessInput(b, converted_input);
|
||||
// Call virtual method to get the initial value.
|
||||
auto initial = xla::ConvertElementType(InitialValue(b), type);
|
||||
// Make two scalar parameters of the desired type for the lambda.
|
||||
@ -121,8 +128,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
|
||||
BuildReducer(&r, rx, ry);
|
||||
xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
|
||||
|
||||
auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes);
|
||||
auto finalized = BuildFinalizer(b, data, reduce, xla_axes);
|
||||
auto reduce =
|
||||
xla::Reduce(processed_input, initial, reduction_computation, xla_axes);
|
||||
auto finalized = BuildFinalizer(b, converted_input, reduce, xla_axes);
|
||||
auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized;
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user