[XLA] Move ArgMin/ArgMax out of TF/XLA and into XLA client library. NFC intended.

PiperOrigin-RevId: 228373518
This commit is contained in:
Peter Hawkins 2019-01-08 12:01:56 -08:00 committed by TensorFlower Gardener
parent a849894f02
commit 68727289bc
9 changed files with 156 additions and 84 deletions

View File

@ -99,8 +99,8 @@ class CategoricalOp : public XlaOpKernel {
xla::PrimitiveType xla_output_type;
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(output_type(0), &xla_output_type));
xla::XlaOp argmax = XlaHelpers::ArgMax(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
if (num_samples == 1) {
argmax = xla::Reshape(argmax, {batch_size, 1});
}

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/index_ops.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
@ -66,9 +65,9 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
xla::XlaOp input = ctx->Input(0);
xla::XlaOp output;
if (is_min_) {
output = XlaHelpers::ArgMin(input, index_xla_type, axis);
output = xla::ArgMin(input, index_xla_type, axis);
} else {
output = XlaHelpers::ArgMax(input, index_xla_type, axis);
output = xla::ArgMax(input, index_xla_type, axis);
}
ctx->SetOutput(0, output);

View File

@ -16,9 +16,9 @@ limitations under the License.
// Native XLA implementations of indexing ops.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -74,7 +74,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
// shape isn't supported.
if (!ctx->compiler()->options().allow_cpu_custom_calls ||
(input_dims != 1 && input_dims != 2)) {
xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis);
xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis);
ctx->SetOutput(0, output);
return;
}

View File

@ -34,63 +34,6 @@ limitations under the License.
namespace tensorflow {
namespace {
// Implements argmin/argmax as a composition of XLA ops (there is no native
// argmin/argmax HLO). Returns the index along `axis` of the minimum element
// (if `is_min`) or maximum element (otherwise) of `input`, with result
// element type `output_type`. Ties are resolved in favor of the *highest*
// index — see the final Reduce below.
xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis,
bool is_min) {
xla::XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
// First reduce away `axis` to find the extremal value: seed with the
// type's max and reduce with scalar-min for argmin, and vice versa for
// argmax.
xla::XlaOp init_value;
xla::XlaComputation reducer;
if (is_min) {
init_value = xla::MaxValue(builder, input_shape.element_type());
reducer =
xla::CreateScalarMinComputation(input_shape.element_type(), builder);
} else {
init_value = xla::MinValue(builder, input_shape.element_type());
reducer =
xla::CreateScalarMaxComputation(input_shape.element_type(), builder);
}
xla::XlaOp input_max = xla::Reduce(input, init_value, reducer,
/*dimensions_to_reduce=*/{axis});
// Broadcast dimensions used to compare the rank-(r-1) reduction result
// back against the rank-r input: every input dimension except `axis`.
std::vector<int64> broadcast_dims(input_shape.rank() - 1);
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
// Compute a mask that has 1s for elements equal to the maximum.
xla::XlaOp partial_mask = xla::ConvertElementType(
xla::Eq(input, input_max, broadcast_dims), output_type);
// In order to make identity elements for a bitwise And, we:
// Left shift the 1 to the leftmost bit, yielding 0x10...0
// Arithmetic right shift the 1 back to the rightmost bit, yielding
// 0xFF...F
// Note: despite the name, this is the bit width of `output_type` *minus
// one*, i.e. the shift distance that moves bit 0 into the sign bit.
int32 bits_in_type =
xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
xla::XlaOp shift_amount =
xla::ConstantR0WithType(builder, output_type, bits_in_type);
xla::XlaOp full_mask = xla::ShiftRightArithmetic(
xla::ShiftLeft(partial_mask, shift_amount), shift_amount);
// And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
// index.
const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis);
xla::XlaOp iota = xla::Iota(builder, output_type, axis_size);
xla::XlaOp product =
xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis});
// If there are multiple maximum elements, choose the one with the highest
// index. (Non-extremal positions contribute 0 via the And above, which is
// safe because index 0 only wins when position 0 is itself extremal.)
return xla::Reduce(product, xla::MinValue(builder, output_type),
xla::CreateScalarMaxComputation(output_type, builder),
/*dimensions_to_reduce=*/{axis});
});
}
} // namespace
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
@ -148,16 +91,6 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
return linspace;
}
// Returns the index of the maximum element of `input` along `axis`, with the
// result produced in `output_type`.
xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
                              int axis) {
  constexpr bool kIsMin = false;
  return ArgMinMax(input, output_type, axis, kIsMin);
}
// Returns the index of the minimum element of `input` along `axis`, with the
// result produced in `output_type`.
xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type,
                              int axis) {
  constexpr bool kIsMin = true;
  return ArgMinMax(input, output_type, axis, kIsMin);
}
Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
DataType index_type, const TensorShape& indices_shape,
const xla::XlaOp& indices, const xla::XlaOp& on_value,

View File

@ -53,16 +53,6 @@ class XlaHelpers {
absl::Span<const int64> shape,
xla::Literal* output);
// Returns the argmax of `input` along `axis`. `output_type` is the type to
// use for the output.
static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
int axis);
// Returns the argmin of `input` along `axis`. `output_type` is the type to
// use for the output.
static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type,
int axis);
// Converts `indices` into a one-hot representation. `depth` is the size
// of the new axis to add. `axis` is the position at which to add the new
// axis. `indices_shape` is the shape of `indices`. `on_value` and

View File

@ -34,6 +34,21 @@ cc_library(
],
)
# Tests for the :arithmetic client library (covers the ArgMin/ArgMax helpers;
# see arithmetic_test.cc).
xla_test(
name = "arithmetic_test",
srcs = ["arithmetic_test.cc"],
deps = [
":arithmetic",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],

View File

@ -123,4 +123,64 @@ XlaOp Any(XlaOp predicates) {
});
}
namespace {
// Implements argmin/argmax as a composition of XLA ops (there is no native
// argmin/argmax HLO). Returns the index along `axis` of the minimum element
// (if `is_min`) or maximum element (otherwise) of `input`, with result
// element type `output_type`. Ties are resolved in favor of the *highest*
// index — see the final Reduce below.
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
// First reduce away `axis` to find the extremal value: seed with the
// type's max and reduce with scalar-min for argmin, and vice versa for
// argmax.
XlaOp init_value;
XlaComputation reducer;
if (is_min) {
init_value = MaxValue(builder, input_shape.element_type());
reducer = CreateScalarMinComputation(input_shape.element_type(), builder);
} else {
init_value = MinValue(builder, input_shape.element_type());
reducer = CreateScalarMaxComputation(input_shape.element_type(), builder);
}
XlaOp input_max = Reduce(input, init_value, reducer,
/*dimensions_to_reduce=*/{axis});
// Broadcast dimensions used to compare the rank-(r-1) reduction result
// back against the rank-r input: every input dimension except `axis`.
std::vector<int64> broadcast_dims(input_shape.rank() - 1);
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
// Compute a mask that has 1s for elements equal to the maximum.
XlaOp partial_mask =
ConvertElementType(Eq(input, input_max, broadcast_dims), output_type);
// In order to make identity elements for a bitwise And, we:
// Left shift the 1 to the leftmost bit, yielding 0x10...0
// Arithmetic right shift the 1 back to the rightmost bit, yielding
// 0xFF...F
// Note: despite the name, this is the bit width of `output_type` *minus
// one*, i.e. the shift distance that moves bit 0 into the sign bit.
int32 bits_in_type =
ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type);
XlaOp full_mask = ShiftRightArithmetic(
ShiftLeft(partial_mask, shift_amount), shift_amount);
// And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
// index.
const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis);
XlaOp iota = Iota(builder, output_type, axis_size);
XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis});
// If there are multiple maximum elements, choose the one with the highest
// index. (Non-extremal positions contribute 0 via the And above, which is
// safe because index 0 only wins when position 0 is itself extremal.)
return Reduce(product, MinValue(builder, output_type),
CreateScalarMaxComputation(output_type, builder),
/*dimensions_to_reduce=*/{axis});
});
}
} // namespace
// Returns the index of the largest element of `input` along `axis`, emitted
// with element type `output_type`.
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
  const bool find_min = false;
  return ArgMinMax(input, output_type, axis, find_min);
}
// Returns the index of the smallest element of `input` along `axis`, emitted
// with element type `output_type`.
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
  const bool find_min = true;
  return ArgMinMax(input, output_type, axis, find_min);
}
} // namespace xla

View File

@ -57,6 +57,14 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type,
// Note: if predicates is zero-sized, Any() vacuously returns false.
XlaOp Any(XlaOp predicates);
// Returns the argmax of `input` along `axis`. `output_type` is the type to
// use for the output.
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis);
// Returns the argmin of `input` along `axis`. `output_type` is the type to
// use for the output.
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_

View File

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
using ArithmeticTest = ClientLibraryTestBase;
// ArgMin over axis 0 (down the columns) of a 3x3 matrix.
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) {
  XlaBuilder b(TestName());
  auto input = ConstantR2<int32>(&b, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
  ArgMin(input, S32, /*axis=*/0);
  const std::vector<int32> expected_indices = {0, 2, 2};
  ComputeAndCompareR1<int32>(&b, expected_indices, {});
}
// ArgMin over axis 1 (along the rows) of a 3x3 matrix.
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) {
  XlaBuilder b(TestName());
  auto input = ConstantR2<int32>(&b, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
  ArgMin(input, S32, /*axis=*/1);
  const std::vector<int32> expected_indices = {0, 1, 2};
  ComputeAndCompareR1<int32>(&b, expected_indices, {});
}
// ArgMax over axis 0 (down the columns) of a 3x3 matrix.
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) {
  XlaBuilder b(TestName());
  auto input = ConstantR2<int32>(&b, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
  ArgMax(input, S32, /*axis=*/0);
  const std::vector<int32> expected_indices = {2, 0, 1};
  ComputeAndCompareR1<int32>(&b, expected_indices, {});
}
// ArgMax over axis 1 (along the rows) of a 3x3 matrix.
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) {
  XlaBuilder b(TestName());
  auto input = ConstantR2<int32>(&b, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
  ArgMax(input, S32, /*axis=*/1);
  const std::vector<int32> expected_indices = {1, 0, 0};
  ComputeAndCompareR1<int32>(&b, expected_indices, {});
}
} // namespace
} // namespace xla