Add EuclideanNorm kernel

Update allowlist

Update allowlist

Change ProcessData to ProcessInput

Change ProcessInput to PreprocessInput

Pass input to finalizer

Change op name
This commit is contained in:
Tzu-Wei Sung 2020-07-30 17:27:50 -07:00
parent 8145abe98c
commit 4e9d0b23aa
5 changed files with 81 additions and 4 deletions

View File

@ -1892,6 +1892,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"Einsum", "Einsum",
"EmptyTensorList", "EmptyTensorList",
"EnsureShape", "EnsureShape",
"EuclideanNorm",
"ExtractImagePatches", "ExtractImagePatches",
"Igamma", "Igamma",
"IgammaGradA", "IgammaGradA",

View File

@ -25,6 +25,7 @@ from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -50,7 +51,8 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
with self.test_scope(): with self.test_scope():
a = array_ops.placeholder(dtype) a = array_ops.placeholder(dtype)
index = array_ops.placeholder(index_dtype) index = array_ops.placeholder(index_dtype)
out = tf_reduce_fn(a, index) out = def_function.function(experimental_compile=True)(
tf_reduce_fn)(a, index)
result = sess.run(out, {a: test_input, index: [0]}) result = sess.run(out, {a: test_input, index: [0]})
self.assertAllClose( self.assertAllClose(
result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol) result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol)
@ -179,6 +181,26 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
'Axes contains duplicate dimension'): 'Axes contains duplicate dimension'):
sess.run(out, {a: [10, 20, 30], index: [0, 0]}) sess.run(out, {a: [10, 20, 30], index: [0, 0]})
def testReduceEuclideanNorm(self, index_dtype):
def reference_euclidean_norm(dtype, inp, axis):
inp = inp.astype(dtype)
return np.sqrt(np.sum(inp * np.conj(inp), axis)).astype(dtype)
for real_dtype in [np.int32, np.int64, np.float16,
np.float32, np.float64]:
self._testReduction(math_ops.reduce_euclidean_norm,
functools.partial(
reference_euclidean_norm, real_dtype),
real_dtype,
self.REAL_DATA, index_dtype)
for complex_dtype in [np.complex64, np.complex128]:
self._testReduction(math_ops.reduce_euclidean_norm,
functools.partial(
reference_euclidean_norm, complex_dtype),
complex_dtype,
self.COMPLEX_DATA, index_dtype)
class ReduceOpPrecisionTest(xla_test.XLATestCase): class ReduceOpPrecisionTest(xla_test.XLATestCase):

View File

@ -16,12 +16,15 @@ limitations under the License.
// XLA-specific reduction Ops. // XLA-specific reduction Ops.
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow { namespace tensorflow {
@ -184,5 +187,44 @@ class AnyOp : public XlaReductionOp {
REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"), REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"),
AnyOp); AnyOp);
class EuclideanNormOp : public XlaReductionOp {
public:
explicit EuclideanNormOp(OpKernelConstruction* ctx)
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
return xla::Zero(builder, xla_reduction_type_);
}
xla::XlaOp PreprocessInput(xla::XlaBuilder* /*builder*/,
const xla::XlaOp& data) override {
return xla::Mul(data, MaybeConjugate(data, true));
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
xla::Add(scalar_lhs, scalar_rhs);
}
xla::XlaOp BuildFinalizer(
xla::XlaBuilder* /*builder*/, const xla::XlaOp& input,
const xla::XlaOp& reduce_output,
const std::vector<int64>& dimensions_to_reduce) override {
if (xla::primitive_util::IsIntegralType(xla_reduction_type_)) {
// XLA only supports float and complex sqrt.
// Thus, cast integral type to F32 for computation.
return XlaHelpers::ConvertElementType(
xla::Sqrt(xla::ConvertElementType(reduce_output, xla::F32)),
input_type(0));
}
return XlaHelpers::ConvertElementType(xla::Sqrt(reduce_output),
input_type(0));
}
};
REGISTER_XLA_OP(
Name("EuclideanNorm").CompileTimeConstantInput("reduction_indices"),
EuclideanNormOp);
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -39,6 +39,10 @@ class XlaReductionOp : public XlaOpKernel {
// Return the base case for the reduction. // Return the base case for the reduction.
virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
// Preprocesses input before reduction.
virtual xla::XlaOp PreprocessInput(xla::XlaBuilder* builder,
const xla::XlaOp& data);
// Implement the (scalar,scalar)->scalar lambda that should be // Implement the (scalar,scalar)->scalar lambda that should be
// applied to each pair of elements to be reduced. The desired // applied to each pair of elements to be reduced. The desired
// computation should be added to 'builder' and // computation should be added to 'builder' and

View File

@ -35,6 +35,12 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
} }
// The default pre-processor directly returns the data. This can be overridden.
xla::XlaOp XlaReductionOp::PreprocessInput(xla::XlaBuilder* /*builder*/,
const xla::XlaOp& data) {
return data;
}
// The default finalizer converts the results back into the input type. This can // The default finalizer converts the results back into the input type. This can
// be overridden. // be overridden.
xla::XlaOp XlaReductionOp::BuildFinalizer( xla::XlaOp XlaReductionOp::BuildFinalizer(
@ -111,7 +117,8 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
xla::PrimitiveType type; xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
auto data = xla::ConvertElementType(ctx->Input(0), type); auto converted_input = xla::ConvertElementType(ctx->Input(0), type);
auto processed_input = PreprocessInput(b, converted_input);
// Call virtual method to get the initial value. // Call virtual method to get the initial value.
auto initial = xla::ConvertElementType(InitialValue(b), type); auto initial = xla::ConvertElementType(InitialValue(b), type);
// Make two scalar parameters of the desired type for the lambda. // Make two scalar parameters of the desired type for the lambda.
@ -121,8 +128,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
BuildReducer(&r, rx, ry); BuildReducer(&r, rx, ry);
xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); auto reduce =
auto finalized = BuildFinalizer(b, data, reduce, xla_axes); xla::Reduce(processed_input, initial, reduction_computation, xla_axes);
auto finalized = BuildFinalizer(b, converted_input, reduce, xla_axes);
auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized;
ctx->SetOutput(0, result); ctx->SetOutput(0, result);
} }