[XLA] Move ArgMin/ArgMax out of TF/XLA and into XLA client library. NFC intended.
PiperOrigin-RevId: 228373518
This commit is contained in:
parent
a849894f02
commit
68727289bc
@ -99,8 +99,8 @@ class CategoricalOp : public XlaOpKernel {
|
||||
xla::PrimitiveType xla_output_type;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
DataTypeToPrimitiveType(output_type(0), &xla_output_type));
|
||||
xla::XlaOp argmax = XlaHelpers::ArgMax(softmax_entries, xla_output_type,
|
||||
/*axis=*/class_dimension);
|
||||
xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type,
|
||||
/*axis=*/class_dimension);
|
||||
if (num_samples == 1) {
|
||||
argmax = xla::Reshape(argmax, {batch_size, 1});
|
||||
}
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/kernels/index_ops.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
@ -66,9 +65,9 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
xla::XlaOp output;
|
||||
if (is_min_) {
|
||||
output = XlaHelpers::ArgMin(input, index_xla_type, axis);
|
||||
output = xla::ArgMin(input, index_xla_type, axis);
|
||||
} else {
|
||||
output = XlaHelpers::ArgMax(input, index_xla_type, axis);
|
||||
output = xla::ArgMax(input, index_xla_type, axis);
|
||||
}
|
||||
|
||||
ctx->SetOutput(0, output);
|
||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
||||
// Native XLA implementations of indexing ops.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -74,7 +74,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
|
||||
// shape isn't supported.
|
||||
if (!ctx->compiler()->options().allow_cpu_custom_calls ||
|
||||
(input_dims != 1 && input_dims != 2)) {
|
||||
xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis);
|
||||
xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis);
|
||||
ctx->SetOutput(0, output);
|
||||
return;
|
||||
}
|
||||
|
@ -34,63 +34,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Shared implementation behind ArgMin and ArgMax. Emits XLA ops that compute,
// along dimension `axis` of `input`, the index of the smallest (`is_min`) or
// largest (`!is_min`) element. Indices are produced in `output_type` and the
// `axis` dimension is reduced away. When several elements tie for the
// extremum, the highest index is returned. Failures (e.g. from the shape
// lookup) are reported through the builder that owns `input`.
// NOTE(review): `axis` is used unchecked for iterator arithmetic below, so it
// is assumed to satisfy 0 <= axis < rank(input) -- confirm callers validate it.
xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis,
                     bool is_min) {
  xla::XlaBuilder* const b = input.builder();
  return b->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape operand_shape, b->GetShape(input));
    const xla::PrimitiveType operand_type = operand_shape.element_type();

    // Step 1: reduce along `axis` to find the extreme value itself. The
    // reduction is seeded with the identity of the chosen combiner.
    xla::XlaOp identity = is_min ? xla::MaxValue(b, operand_type)
                                 : xla::MinValue(b, operand_type);
    xla::XlaComputation combiner =
        is_min ? xla::CreateScalarMinComputation(operand_type, b)
               : xla::CreateScalarMaxComputation(operand_type, b);
    xla::XlaOp extremum = xla::Reduce(input, identity, combiner,
                                      /*dimensions_to_reduce=*/{axis});

    // Step 2: list every dimension of `input` except `axis`, in order. These
    // map the dimensions of the reduced result back onto `input`.
    std::vector<int64> kept_dims(operand_shape.rank() - 1);
    const auto axis_pos = kept_dims.begin() + axis;
    std::iota(kept_dims.begin(), axis_pos, 0);
    std::iota(axis_pos, kept_dims.end(), axis + 1);

    // Step 3: build a 0/1 mask (in `output_type`) that marks the elements
    // equal to the extremum.
    xla::XlaOp is_extremum = xla::Eq(input, extremum, kept_dims);
    xla::XlaOp ones_mask = xla::ConvertElementType(is_extremum, output_type);

    // Step 4: turn each 1 into an identity element for a bitwise And.
    // Shifting left by (bit width - 1) moves the 1 into the top bit
    // (0x80...0); the arithmetic right shift by the same amount then spreads
    // that bit over the whole word (0xFF...F). Zeros stay zero.
    int32 top_bit_index =
        xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
    xla::XlaOp shift = xla::ConstantR0WithType(b, output_type, top_bit_index);
    xla::XlaOp top_bit_mask = xla::ShiftLeft(ones_mask, shift);
    xla::XlaOp all_ones_mask = xla::ShiftRightArithmetic(top_bit_mask, shift);

    // Step 5: And the mask with [0, 1, 2, ...] laid out along `axis`, so each
    // matching element holds its own index and every other element holds 0.
    const int64 axis_extent = xla::ShapeUtil::GetDimension(operand_shape, axis);
    xla::XlaOp indices = xla::Iota(b, output_type, axis_extent);
    xla::XlaOp masked_indices =
        xla::And(all_ones_mask, indices, /*broadcast_dimensions=*/{axis});

    // Step 6: a max-reduction over `axis` selects the surviving index; with
    // several matching elements this is the highest matching index.
    return xla::Reduce(masked_indices, xla::MinValue(b, output_type),
                       xla::CreateScalarMaxComputation(output_type, b),
                       /*dimensions_to_reduce=*/{axis});
  });
}
|
||||
|
||||
} // namespace
|
||||
|
||||
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
|
||||
xla::PrimitiveType type;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
|
||||
@ -148,16 +91,6 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
|
||||
return linspace;
|
||||
}
|
||||
|
||||
// Emits ops computing the index of the maximum of `input` along `axis`, with
// indices of type `output_type`. Thin wrapper that forwards to ArgMinMax.
xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
                              int axis) {
  constexpr bool kSelectMinimum = false;
  return ArgMinMax(input, output_type, axis, kSelectMinimum);
}
|
||||
|
||||
// Emits ops computing the index of the minimum of `input` along `axis`, with
// indices of type `output_type`. Thin wrapper that forwards to ArgMinMax.
xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type,
                              int axis) {
  constexpr bool kSelectMinimum = true;
  return ArgMinMax(input, output_type, axis, kSelectMinimum);
}
|
||||
|
||||
Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
|
||||
DataType index_type, const TensorShape& indices_shape,
|
||||
const xla::XlaOp& indices, const xla::XlaOp& on_value,
|
||||
|
@ -53,16 +53,6 @@ class XlaHelpers {
|
||||
absl::Span<const int64> shape,
|
||||
xla::Literal* output);
|
||||
|
||||
// Returns the index of the maximum element of `input` along dimension `axis`.
// `output_type` is the primitive type used for the returned indices.
|
||||
// NOTE(review): when several elements tie for the maximum, the implementation
// appears to return the highest matching index -- confirm before relying on it.
|
||||
static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
|
||||
int axis);
|
||||
|
||||
// Returns the index of the minimum element of `input` along dimension `axis`.
// `output_type` is the primitive type used for the returned indices.
|
||||
// NOTE(review): when several elements tie for the minimum, the implementation
// appears to return the highest matching index -- confirm before relying on it.
|
||||
static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type,
|
||||
int axis);
|
||||
|
||||
// Converts `indices` into a one-hot representation. `depth` is the size
|
||||
// of the new axis to add. `axis` is the position at which to add the new
|
||||
// axis. `indices_shape` is the shape of `indices`. `on_value` and
|
||||
|
@ -34,6 +34,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Test target for the `:arithmetic` client library: builds arithmetic_test.cc
# against the library under test plus the XLA builder, literal/type utilities
# and the client-library test harness listed in `deps`.
xla_test(
|
||||
name = "arithmetic_test",
|
||||
srcs = ["arithmetic_test.cc"],
|
||||
deps = [
|
||||
":arithmetic",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cholesky",
|
||||
srcs = ["cholesky.cc"],
|
||||
|
@ -123,4 +123,64 @@ XlaOp Any(XlaOp predicates) {
|
||||
});
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Shared implementation behind ArgMin and ArgMax. Emits XLA ops that compute,
// along dimension `axis` of `input`, the index of the smallest (`is_min`) or
// largest (`!is_min`) element. Indices are produced in `output_type` and the
// `axis` dimension is reduced away. When several elements tie for the
// extremum, the highest index is returned. Failures (e.g. from the shape
// lookup) are reported through the builder that owns `input`.
// NOTE(review): `axis` is used unchecked for iterator arithmetic below, so it
// is assumed to satisfy 0 <= axis < rank(input) -- confirm callers validate it.
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
  XlaBuilder* const b = input.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(input));
    const PrimitiveType operand_type = operand_shape.element_type();

    // Step 1: reduce along `axis` to find the extreme value itself. The
    // reduction is seeded with the identity of the chosen combiner.
    XlaOp identity =
        is_min ? MaxValue(b, operand_type) : MinValue(b, operand_type);
    XlaComputation combiner =
        is_min ? CreateScalarMinComputation(operand_type, b)
               : CreateScalarMaxComputation(operand_type, b);
    XlaOp extremum = Reduce(input, identity, combiner,
                            /*dimensions_to_reduce=*/{axis});

    // Step 2: list every dimension of `input` except `axis`, in order. These
    // map the dimensions of the reduced result back onto `input`.
    std::vector<int64> kept_dims(operand_shape.rank() - 1);
    const auto axis_pos = kept_dims.begin() + axis;
    std::iota(kept_dims.begin(), axis_pos, 0);
    std::iota(axis_pos, kept_dims.end(), axis + 1);

    // Step 3: build a 0/1 mask (in `output_type`) that marks the elements
    // equal to the extremum.
    XlaOp is_extremum = Eq(input, extremum, kept_dims);
    XlaOp ones_mask = ConvertElementType(is_extremum, output_type);

    // Step 4: turn each 1 into an identity element for a bitwise And.
    // Shifting left by (bit width - 1) moves the 1 into the top bit
    // (0x80...0); the arithmetic right shift by the same amount then spreads
    // that bit over the whole word (0xFF...F). Zeros stay zero.
    int32 top_bit_index =
        ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
    XlaOp shift = ConstantR0WithType(b, output_type, top_bit_index);
    XlaOp top_bit_mask = ShiftLeft(ones_mask, shift);
    XlaOp all_ones_mask = ShiftRightArithmetic(top_bit_mask, shift);

    // Step 5: And the mask with [0, 1, 2, ...] laid out along `axis`, so each
    // matching element holds its own index and every other element holds 0.
    const int64 axis_extent = ShapeUtil::GetDimension(operand_shape, axis);
    XlaOp indices = Iota(b, output_type, axis_extent);
    XlaOp masked_indices =
        And(all_ones_mask, indices, /*broadcast_dimensions=*/{axis});

    // Step 6: a max-reduction over `axis` selects the surviving index; with
    // several matching elements this is the highest matching index.
    return Reduce(masked_indices, MinValue(b, output_type),
                  CreateScalarMaxComputation(output_type, b),
                  /*dimensions_to_reduce=*/{axis});
  });
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Emits ops computing the index of the maximum of `input` along `axis`, with
// indices of type `output_type`. Thin wrapper that forwards to ArgMinMax.
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
  constexpr bool kSelectMinimum = false;
  return ArgMinMax(input, output_type, axis, kSelectMinimum);
}
|
||||
|
||||
// Emits ops computing the index of the minimum of `input` along `axis`, with
// indices of type `output_type`. Thin wrapper that forwards to ArgMinMax.
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
  constexpr bool kSelectMinimum = true;
  return ArgMinMax(input, output_type, axis, kSelectMinimum);
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -57,6 +57,14 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type,
|
||||
// Note: if predicates is zero-sized, Any() vacuously returns false.
|
||||
XlaOp Any(XlaOp predicates);
|
||||
|
||||
// Returns the index of the maximum element of `input` along dimension `axis`.
// `output_type` is the primitive type used for the returned indices.
|
||||
// NOTE(review): when several elements tie for the maximum, the implementation
// appears to return the highest matching index -- confirm before relying on it.
|
||||
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis);
|
||||
|
||||
// Returns the index of the minimum element of `input` along dimension `axis`.
// `output_type` is the primitive type used for the returned indices.
|
||||
// NOTE(review): when several elements tie for the minimum, the implementation
// appears to return the highest matching index -- confirm before relying on it.
|
||||
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_
|
||||
|
67
tensorflow/compiler/xla/client/lib/arithmetic_test.cc
Normal file
67
tensorflow/compiler/xla/client/lib/arithmetic_test.cc
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ArithmeticTest = ClientLibraryTestBase;
|
||||
|
||||
// ArgMin down the columns (axis 0) of a 3x3 matrix. Column 1 is {7, 3, 3}, so
// its minimum occurs at rows 1 and 2; the expected result for it is 2.
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) {
  XlaBuilder b(TestName());
  const XlaOp matrix =
      ConstantR2<int32>(&b, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
  ArgMin(matrix, S32, /*axis=*/0);

  const std::vector<int32> want = {0, 2, 2};
  ComputeAndCompareR1<int32>(&b, want, {});
}
|
||||
|
||||
// ArgMin along the rows (axis 1) of a 3x3 matrix. Row 2 is {8, 3, 3}, so its
// minimum occurs at columns 1 and 2; the expected result for it is 2.
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) {
  XlaBuilder b(TestName());
  const XlaOp matrix =
      ConstantR2<int32>(&b, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
  ArgMin(matrix, S32, /*axis=*/1);

  const std::vector<int32> want = {0, 1, 2};
  ComputeAndCompareR1<int32>(&b, want, {});
}
|
||||
|
||||
// ArgMax down the columns (axis 0) of a 3x3 matrix. Each column has a unique
// maximum: 8 in row 2, 7 in row 0 and 5 in row 1.
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) {
  XlaBuilder b(TestName());
  const XlaOp matrix =
      ConstantR2<int32>(&b, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
  ArgMax(matrix, S32, /*axis=*/0);

  const std::vector<int32> want = {2, 0, 1};
  ComputeAndCompareR1<int32>(&b, want, {});
}
|
||||
|
||||
// ArgMax along the rows (axis 1) of a 3x3 matrix. Each row has a unique
// maximum: 7 in column 1, 6 in column 0 and 8 in column 0.
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) {
  XlaBuilder b(TestName());
  const XlaOp matrix =
      ConstantR2<int32>(&b, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
  ArgMax(matrix, S32, /*axis=*/1);

  const std::vector<int32> want = {1, 0, 0};
  ComputeAndCompareR1<int32>(&b, want, {});
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user