Merge pull request #41916 from WindQAQ:tf2xla-euclidean-norm
PiperOrigin-RevId: 326304916 Change-Id: I1dcb9303e78c0f198de82cb1f0db9734469c7565
This commit is contained in:
commit
d487959ffc
@ -1892,6 +1892,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
|||||||
"Einsum",
|
"Einsum",
|
||||||
"EmptyTensorList",
|
"EmptyTensorList",
|
||||||
"EnsureShape",
|
"EnsureShape",
|
||||||
|
"EuclideanNorm",
|
||||||
"ExtractImagePatches",
|
"ExtractImagePatches",
|
||||||
"Igamma",
|
"Igamma",
|
||||||
"IgammaGradA",
|
"IgammaGradA",
|
||||||
|
@ -25,6 +25,7 @@ from absl.testing import parameterized
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.compiler.tests import xla_test
|
from tensorflow.compiler.tests import xla_test
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -50,7 +51,8 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
a = array_ops.placeholder(dtype)
|
a = array_ops.placeholder(dtype)
|
||||||
index = array_ops.placeholder(index_dtype)
|
index = array_ops.placeholder(index_dtype)
|
||||||
out = tf_reduce_fn(a, index)
|
out = def_function.function(experimental_compile=True)(tf_reduce_fn)(
|
||||||
|
a, index)
|
||||||
result = sess.run(out, {a: test_input, index: [0]})
|
result = sess.run(out, {a: test_input, index: [0]})
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol)
|
result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol)
|
||||||
@ -179,6 +181,24 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
'Axes contains duplicate dimension'):
|
'Axes contains duplicate dimension'):
|
||||||
sess.run(out, {a: [10, 20, 30], index: [0, 0]})
|
sess.run(out, {a: [10, 20, 30], index: [0, 0]})
|
||||||
|
|
||||||
|
def testReduceEuclideanNorm(self, index_dtype):
|
||||||
|
|
||||||
|
def reference_euclidean_norm(dtype, inp, axis):
|
||||||
|
inp = inp.astype(dtype)
|
||||||
|
return np.sqrt(np.sum(inp * np.conj(inp), axis)).astype(dtype)
|
||||||
|
|
||||||
|
for real_dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
|
||||||
|
self._testReduction(
|
||||||
|
math_ops.reduce_euclidean_norm,
|
||||||
|
functools.partial(reference_euclidean_norm, real_dtype), real_dtype,
|
||||||
|
self.REAL_DATA, index_dtype)
|
||||||
|
|
||||||
|
for complex_dtype in [np.complex64]:
|
||||||
|
self._testReduction(
|
||||||
|
math_ops.reduce_euclidean_norm,
|
||||||
|
functools.partial(reference_euclidean_norm, complex_dtype),
|
||||||
|
complex_dtype, self.COMPLEX_DATA, index_dtype)
|
||||||
|
|
||||||
|
|
||||||
class ReduceOpPrecisionTest(xla_test.XLATestCase):
|
class ReduceOpPrecisionTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
|
@ -16,12 +16,15 @@ limitations under the License.
|
|||||||
// XLA-specific reduction Ops.
|
// XLA-specific reduction Ops.
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
|
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -184,5 +187,44 @@ class AnyOp : public XlaReductionOp {
|
|||||||
REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"),
|
REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"),
|
||||||
AnyOp);
|
AnyOp);
|
||||||
|
|
||||||
|
class EuclideanNormOp : public XlaReductionOp {
|
||||||
|
public:
|
||||||
|
explicit EuclideanNormOp(OpKernelConstruction* ctx)
|
||||||
|
: XlaReductionOp(ctx,
|
||||||
|
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
|
||||||
|
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
|
||||||
|
return xla::Zero(builder, xla_reduction_type_);
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::XlaOp PreprocessInput(xla::XlaBuilder* /*builder*/,
|
||||||
|
const xla::XlaOp& data) override {
|
||||||
|
return xla::Mul(data, MaybeConjugate(data, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
|
||||||
|
const xla::XlaOp& scalar_rhs) override {
|
||||||
|
xla::Add(scalar_lhs, scalar_rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::XlaOp BuildFinalizer(
|
||||||
|
xla::XlaBuilder* /*builder*/, const xla::XlaOp& input,
|
||||||
|
const xla::XlaOp& reduce_output,
|
||||||
|
const std::vector<int64>& dimensions_to_reduce) override {
|
||||||
|
if (xla::primitive_util::IsIntegralType(xla_reduction_type_)) {
|
||||||
|
// XLA only supports float and complex sqrt.
|
||||||
|
// Thus, cast integral type to F32 for computation.
|
||||||
|
return XlaHelpers::ConvertElementType(
|
||||||
|
xla::Sqrt(xla::ConvertElementType(reduce_output, xla::F32)),
|
||||||
|
input_type(0));
|
||||||
|
}
|
||||||
|
return XlaHelpers::ConvertElementType(xla::Sqrt(reduce_output),
|
||||||
|
input_type(0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(
|
||||||
|
Name("EuclideanNorm").CompileTimeConstantInput("reduction_indices"),
|
||||||
|
EuclideanNormOp);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -39,6 +39,10 @@ class XlaReductionOp : public XlaOpKernel {
|
|||||||
// Return the base case for the reduction.
|
// Return the base case for the reduction.
|
||||||
virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
|
virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
|
||||||
|
|
||||||
|
// Preprocesses input before reduction.
|
||||||
|
virtual xla::XlaOp PreprocessInput(xla::XlaBuilder* builder,
|
||||||
|
const xla::XlaOp& data);
|
||||||
|
|
||||||
// Implement the (scalar,scalar)->scalar lambda that should be
|
// Implement the (scalar,scalar)->scalar lambda that should be
|
||||||
// applied to each pair of elements to be reduced. The desired
|
// applied to each pair of elements to be reduced. The desired
|
||||||
// computation should be added to 'builder' and
|
// computation should be added to 'builder' and
|
||||||
|
@ -35,6 +35,12 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
|
|||||||
ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
|
ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The default pre-processor directly returns the data. This can be overridden.
|
||||||
|
xla::XlaOp XlaReductionOp::PreprocessInput(xla::XlaBuilder* /*builder*/,
|
||||||
|
const xla::XlaOp& data) {
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
// The default finalizer converts the results back into the input type. This can
|
// The default finalizer converts the results back into the input type. This can
|
||||||
// be overridden.
|
// be overridden.
|
||||||
xla::XlaOp XlaReductionOp::BuildFinalizer(
|
xla::XlaOp XlaReductionOp::BuildFinalizer(
|
||||||
@ -111,7 +117,8 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
|
|||||||
xla::PrimitiveType type;
|
xla::PrimitiveType type;
|
||||||
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
|
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
|
||||||
|
|
||||||
auto data = xla::ConvertElementType(ctx->Input(0), type);
|
auto converted_input = xla::ConvertElementType(ctx->Input(0), type);
|
||||||
|
auto processed_input = PreprocessInput(b, converted_input);
|
||||||
// Call virtual method to get the initial value.
|
// Call virtual method to get the initial value.
|
||||||
auto initial = xla::ConvertElementType(InitialValue(b), type);
|
auto initial = xla::ConvertElementType(InitialValue(b), type);
|
||||||
// Make two scalar parameters of the desired type for the lambda.
|
// Make two scalar parameters of the desired type for the lambda.
|
||||||
@ -121,8 +128,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
|
|||||||
BuildReducer(&r, rx, ry);
|
BuildReducer(&r, rx, ry);
|
||||||
xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
|
xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
|
||||||
|
|
||||||
auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes);
|
auto reduce =
|
||||||
auto finalized = BuildFinalizer(b, data, reduce, xla_axes);
|
xla::Reduce(processed_input, initial, reduction_computation, xla_axes);
|
||||||
|
auto finalized = BuildFinalizer(b, converted_input, reduce, xla_axes);
|
||||||
auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized;
|
auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized;
|
||||||
ctx->SetOutput(0, result);
|
ctx->SetOutput(0, result);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user