[XLA:CPU] Fuse non-vectorized reduces and remove argmax customcall

Fusing the variadic reduce version of argmax is not slower than the custom
implementation on our benchmarks, so just do that instead. Saves a bit of
space for everyone not using argmax.

We could also fuse into the vectorized code, but that's a bit more involved
and I'm not sure yet whether I see enough speedup to warrant that extra
complexity.

PiperOrigin-RevId: 254187396
This commit is contained in:
Benjamin Kramer 2019-06-20 06:51:42 -07:00 committed by TensorFlower Gardener
parent 9385325594
commit af4eb9c864
12 changed files with 72 additions and 304 deletions

View File

@ -34,7 +34,6 @@ cc_library(
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:cpu_function_runtime",

View File

@ -286,8 +286,6 @@ def tf_library(
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually
# needed.
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",

View File

@ -51,7 +51,6 @@ cc_library(
deps = [
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",

View File

@ -40,7 +40,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@ -108,7 +107,6 @@ cc_library(
":tf2xla_proto",
":tf2xla_util",
":xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:xla_computation",

View File

@ -345,50 +345,3 @@ tf_kernel_library(
],
alwayslink = 1,
)
# Kernels that only work on CPU, because they use XLA custom calls.
# Only link this when using the CPU backend for XLA.
tf_kernel_library(
name = "xla_cpu_only_ops",
srcs = ["index_ops_cpu.cc"],
deps = [
":index_ops_kernel_argmax_float_1d",
":index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core:lib",
],
)
cc_library(
name = "index_ops_kernel_argmax_float_1d",
srcs = ["index_ops_kernel_argmax_float_1d.cc"],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
alwayslink = 1,
)
cc_library(
name = "index_ops_kernel_argmax_float_2d",
srcs = ["index_ops_kernel_argmax_float_2d.cc"],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
alwayslink = 1,
)

View File

@ -1,142 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Native XLA implementations of indexing ops.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace {
// The logic below uses a custom-call to implement argmax when possible. When
// custom-call is not allowed or input shapes are not supported, this kernel
// falls back to using XLA HLO native ArgMax.
//
// Also see b/29507024 for first-class XLA support for indexing ops.
class ArgMaxCustomCallOp : public XlaOpKernel {
public:
explicit ArgMaxCustomCallOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
const TensorShape dimension_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape),
errors::InvalidArgument(
"dim must be a scalar, but received tensor of shape: ",
dimension_shape.DebugString()));
// We require that the dimension argument is a constant, since it lets us
// dispatch to a specialized custom-call function without any run-time
// overhead, when compiling ahead-of-time.
int64 dim;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim));
const int input_dims = input_shape.dims();
const int axis = dim < 0 ? dim + input_dims : dim;
OP_REQUIRES(ctx, axis >= 0 && axis < input_dims,
errors::InvalidArgument("Expected dimension in the range [",
-input_dims, ", ", input_dims,
"), but got ", dim));
const int64 axis_size = input_shape.dim_size(axis);
OP_REQUIRES(ctx, axis_size > 0,
errors::InvalidArgument(
"Reduction axis ", dim,
" is empty in shape: ", input_shape.DebugString()));
const DataType dtype = output_type(0);
xla::PrimitiveType output_type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type));
// Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input
// shape isn't supported.
if (!ctx->compiler()->options().allow_cpu_custom_calls ||
(input_dims != 1 && input_dims != 2)) {
xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis);
ctx->SetOutput(0, output);
return;
}
xla::XlaOp output;
// The output shape is the input shape contracted along axis.
TensorShape output_shape;
for (int d = 0; d < input_shape.dims() - 1; ++d) {
output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
}
xla::XlaBuilder& b = *ctx->builder();
// XLA passes <out> to the function, so it is not included here.
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(xla::ConstantLiteral(
&b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(xla::ConstantLiteral(
&b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(axis)));
}
// The argmax function expects row-major layout.
xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
xla::S64, output_shape.dim_sizes());
std::vector<xla::Shape> arg_shapes;
for (const xla::XlaOp& arg : args) {
auto shape_status = b.GetShape(arg);
OP_REQUIRES_OK(ctx, shape_status.status());
xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
*arg_shape.mutable_layout() =
xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
arg_shapes.push_back(std::move(arg_shape));
}
// Tell XLA to call the custom code, defined in
// index_ops_kernel_argmax_float_{1, 2}d.cc.
if (input_dims == 1) {
output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
xla_shape, arg_shapes);
} else {
output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
xla_shape, arg_shapes);
}
output = xla::ConvertElementType(output, output_type);
ctx->SetOutput(0, output);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp);
};
REGISTER_XLA_OP(Name("ArgMax")
.TypeConstraint("T", DT_FLOAT)
.Device(DEVICE_CPU_XLA_JIT)
.CompileTimeConstantInput("dimension"),
ArgMaxCustomCallOp);
} // namespace
} // namespace tensorflow

View File

@ -1,52 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) {
// Data is managed by the JIT code so msan can't tell it's initialized.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*));
float* input = static_cast<float*>(data[0]);
int64 input_size = *static_cast<int64*>(data[1]);
Eigen::DSizes<Eigen::DenseIndex, 1> in_eig_sizes(input_size);
TTypes<float, 1>::ConstTensor in_eig(input, in_eig_sizes);
Eigen::DSizes<Eigen::DenseIndex, 0> out_eig_sizes;
int64* out_t = static_cast<int64*>(out);
TTypes<int64, 0>::Tensor out_eig(out_t, out_eig_sizes);
out_eig = in_eig.argmax(0).cast<int64>();
}
} // namespace tensorflow
// Implements argmax on CPU. This is called by an XLA custom call, set up by
// index_ops.cc.
extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_1d_xla_impl(out, data);
}
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);

View File

@ -1,54 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) {
// data is managed by the JIT code so msan can't tell it's initialized.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 4 * sizeof(void*));
float* in = static_cast<float*>(data[0]);
int64* in_sizes = static_cast<int64*>(data[1]);
int64* out_sizes = static_cast<int64*>(data[2]);
int32 dim = *static_cast<int32*>(data[3]);
Eigen::DSizes<Eigen::DenseIndex, 2> in_eig_sizes(in_sizes[0], in_sizes[1]);
TTypes<float, 2>::ConstTensor in_eig(in, in_eig_sizes);
int64* out_t = static_cast<int64*>(out);
Eigen::DSizes<Eigen::DenseIndex, 1> out_eig_sizes(out_sizes[0]);
TTypes<int64, 1>::Tensor out_eig(out_t, out_eig_sizes);
out_eig = in_eig.argmax(dim).cast<int64>();
}
} // namespace tensorflow
// Implements argmax on CPU. This is called by an XLA custom call, set up by
// index_ops.cc.
extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_2d_xla_impl(out, data);
}
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);

View File

@ -33,6 +33,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
hlo.opcode() == HloOpcode::kGather ||
hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad ||
hlo.opcode() == HloOpcode::kReduce ||
hlo.opcode() == HloOpcode::kReshape ||
hlo.opcode() == HloOpcode::kReverse ||
hlo.opcode() == HloOpcode::kSlice ||
@ -151,6 +152,19 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
}
}
// Don't fuse reductions over the major dimensions. These have an efficient
// lowering that's only implemented for the unfused case.
if (consumer->opcode() == HloOpcode::kReduce) {
return absl::c_linear_search(
consumer->dimensions(),
LayoutUtil::Minor(consumer->operand(0)->shape().layout(), 0));
}
if (producer->opcode() == HloOpcode::kReduce) {
return absl::c_linear_search(
producer->dimensions(),
LayoutUtil::Minor(producer->operand(0)->shape().layout(), 0));
}
if (consumer->IsLoopFusion()) {
VLOG(2) << "Fusing: consumer is a fusion node.";
return true;

View File

@ -889,6 +889,61 @@ INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation,
GatherLoopFusionTest,
::testing::ValuesIn(GetGatherLoopFusionTestSpecs()),
GatherLoopFusionTestSpec::Name);
TEST_F(InstructionFusionTest, NoFuseReduceMajor) {
absl::string_view module_string = R"(
HloModule module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY main {
a = f32[50,60]{1,0} parameter(0)
b = f32[50,60]{1,0} parameter(1)
c = f32[50,60]{1,0} add(a, b)
init = f32[] constant(0)
ROOT r = f32[60]{0} reduce(c, init), dimensions={0}, to_apply=add
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
CpuInstructionFusion().Run(module.get()));
EXPECT_FALSE(fused_something);
EXPECT_THAT(module->entry_computation()->root_instruction(),
Not(op::Fusion()));
}
TEST_F(InstructionFusionTest, FuseReduceMinor) {
absl::string_view module_string = R"(
HloModule module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY main {
a = f32[50,60]{1,0} parameter(0)
b = f32[50,60]{1,0} parameter(1)
c = f32[50,60]{1,0} add(a, b)
init = f32[] constant(0)
ROOT r = f32[] reduce(c, init), dimensions={0,1}, to_apply=add
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
CpuInstructionFusion().Run(module.get()));
EXPECT_TRUE(fused_something);
EXPECT_THAT(module->entry_computation()->root_instruction(), op::Fusion());
}
} // namespace
} // namespace cpu
} // namespace xla

View File

@ -3125,7 +3125,7 @@ Status IrEmitter::EmitTargetElementLoop(
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
if (target_shape.IsTuple() && (target_op->IsMultiOutputFusion() ||
if (target_shape.IsTuple() && (target_op->opcode() == HloOpcode::kFusion ||
target_op->opcode() == HloOpcode::kReduce)) {
// For multiple outputs fusion, we need to emit each operand and the root.
TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);

View File

@ -169,11 +169,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
cshape,
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {6, 1}), concatenate)),
ShapeUtil::MakeShape(F32, {1, 6}), concatenate)),
/*init_value=*/
builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
/*dimensions_to_reduce=*/{1}, add_f32));
/*dimensions_to_reduce=*/{0}, add_f32));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce));