[XLA:CPU] Fuse non-vectorized reduces and remove argmax customcall
Fusing the variadic reduce version of argmax is not slower than the custom implementation on our benchmarks, so just do that instead. Saves a bit of space for everyone not using argmax. We could also fuse into the vectorized code, but that's a bit more involved and I'm not sure yet whether I see enough speedup to warrant that extra complexity.

PiperOrigin-RevId: 254187396
parent 9385325594
commit af4eb9c864
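For context, the variadic-reduce form of argmax that now gets fused is what the fallback path in the deleted index_ops_cpu.cc below already built, via the client-library helper xla::ArgMax from xla/client/lib/arithmetic.h. A minimal hedged sketch of that path (builder name and input shape are illustrative, not taken from this change):

    // Builds argmax over axis 0 of an f32[100] input. xla::ArgMax expands to a
    // variadic (value, index) reduce; with this change the CPU backend can
    // loop-fuse that reduce instead of dispatching to a custom call.
    xla::XlaBuilder b("argmax_example");  // illustrative builder name
    xla::XlaOp input = xla::Parameter(
        &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {100}), "input");
    xla::XlaOp indices = xla::ArgMax(input, xla::S64, /*axis=*/0);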
@@ -34,7 +34,6 @@ cc_library(
         "//tensorflow/compiler/tf2xla:tf2xla_proto",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla:cpu_function_runtime",
@@ -286,8 +286,6 @@ def tf_library(
        ] or []) + (include_standard_runtime_deps and [
            # TODO(cwhipkey): only depend on kernel code that the model actually
            # needed.
-           "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
-           "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
            "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
            "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
            "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
@@ -51,7 +51,6 @@ cc_library(
     deps = [
         ":jit_compilation_passes",
         "//tensorflow/compiler/jit/kernels:xla_ops",
-        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:cpu_plugin",
@@ -40,7 +40,6 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":xla_compiler",
-        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -108,7 +107,6 @@ cc_library(
         ":tf2xla_proto",
         ":tf2xla_util",
         ":xla_compiler",
-        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/client",
         "//tensorflow/compiler/xla/client:xla_computation",
@@ -345,50 +345,3 @@ tf_kernel_library(
     ],
     alwayslink = 1,
 )
-
-# Kernels that only work on CPU, because they use XLA custom calls.
-# Only link this when using the CPU backend for XLA.
-tf_kernel_library(
-    name = "xla_cpu_only_ops",
-    srcs = ["index_ops_cpu.cc"],
-    deps = [
-        ":index_ops_kernel_argmax_float_1d",
-        ":index_ops_kernel_argmax_float_2d",
-        "//tensorflow/compiler/tf2xla:common",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/xla:literal",
-        "//tensorflow/compiler/xla:literal_util",
-        "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/client/lib:arithmetic",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:framework_bounds_check",
-        "//tensorflow/core:lib",
-    ],
-)
-
-cc_library(
-    name = "index_ops_kernel_argmax_float_1d",
-    srcs = ["index_ops_kernel_argmax_float_1d.cc"],
-    copts = tf_copts(),
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/compiler/xla/service:custom_call_target_registry",
-        "//tensorflow/core:framework_lite",
-        "//third_party/eigen3",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "index_ops_kernel_argmax_float_2d",
-    srcs = ["index_ops_kernel_argmax_float_2d.cc"],
-    copts = tf_copts(),
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/compiler/xla/service:custom_call_target_registry",
-        "//tensorflow/core:framework_lite",
-        "//third_party/eigen3",
-    ],
-    alwayslink = 1,
-)
@@ -1,142 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Native XLA implementations of indexing ops.
-
-#include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/core/framework/bounds_check.h"
-#include "tensorflow/core/framework/kernel_def_builder.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-
-namespace tensorflow {
-namespace {
-
-// The logic below uses a custom-call to implement argmax when possible. When
-// custom-call is not allowed or input shapes are not supported, this kernel
-// falls back to using XLA HLO native ArgMax.
-//
-// Also see b/29507024 for first-class XLA support for indexing ops.
-class ArgMaxCustomCallOp : public XlaOpKernel {
- public:
-  explicit ArgMaxCustomCallOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
-
-  void Compile(XlaOpKernelContext* ctx) override {
-    const TensorShape input_shape = ctx->InputShape(0);
-    const TensorShape dimension_shape = ctx->InputShape(1);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape),
-                errors::InvalidArgument(
-                    "dim must be a scalar, but received tensor of shape: ",
-                    dimension_shape.DebugString()));
-
-    // We require that the dimension argument is a constant, since it lets us
-    // dispatch to a specialized custom-call function without any run-time
-    // overhead, when compiling ahead-of-time.
-    int64 dim;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim));
-
-    const int input_dims = input_shape.dims();
-    const int axis = dim < 0 ? dim + input_dims : dim;
-    OP_REQUIRES(ctx, axis >= 0 && axis < input_dims,
-                errors::InvalidArgument("Expected dimension in the range [",
-                                        -input_dims, ", ", input_dims,
-                                        "), but got ", dim));
-
-    const int64 axis_size = input_shape.dim_size(axis);
-    OP_REQUIRES(ctx, axis_size > 0,
-                errors::InvalidArgument(
-                    "Reduction axis ", dim,
-                    " is empty in shape: ", input_shape.DebugString()));
-
-    const DataType dtype = output_type(0);
-    xla::PrimitiveType output_type;
-    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type));
-
-    // Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input
-    // shape isn't supported.
-    if (!ctx->compiler()->options().allow_cpu_custom_calls ||
-        (input_dims != 1 && input_dims != 2)) {
-      xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis);
-      ctx->SetOutput(0, output);
-      return;
-    }
-
-    xla::XlaOp output;
-    // The output shape is the input shape contracted along axis.
-    TensorShape output_shape;
-    for (int d = 0; d < input_shape.dims() - 1; ++d) {
-      output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
-    }
-
-    xla::XlaBuilder& b = *ctx->builder();
-
-    // XLA passes <out> to the function, so it is not included here.
-    std::vector<xla::XlaOp> args;
-    args.push_back(ctx->Input(0));
-    args.push_back(xla::ConstantLiteral(
-        &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
-    if (input_shape.dims() > 1) {
-      // Don't bother passing the output shape and dim for the 1d case, since
-      // the shape is always a scalar and the dim is always 0.
-      args.push_back(xla::ConstantLiteral(
-          &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
-      args.push_back(
-          xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(axis)));
-    }
-
-    // The argmax function expects row-major layout.
-    xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
-        xla::S64, output_shape.dim_sizes());
-    std::vector<xla::Shape> arg_shapes;
-    for (const xla::XlaOp& arg : args) {
-      auto shape_status = b.GetShape(arg);
-      OP_REQUIRES_OK(ctx, shape_status.status());
-      xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
-      *arg_shape.mutable_layout() =
-          xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
-      arg_shapes.push_back(std::move(arg_shape));
-    }
-
-    // Tell XLA to call the custom code, defined in
-    // index_ops_kernel_argmax_float_{1, 2}d.cc.
-    if (input_dims == 1) {
-      output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
-                                         xla_shape, arg_shapes);
-    } else {
-      output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
-                                         xla_shape, arg_shapes);
-    }
-    output = xla::ConvertElementType(output, output_type);
-    ctx->SetOutput(0, output);
-  }
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp);
-};
-
-REGISTER_XLA_OP(Name("ArgMax")
-                    .TypeConstraint("T", DT_FLOAT)
-                    .Device(DEVICE_CPU_XLA_JIT)
-                    .CompileTimeConstantInput("dimension"),
-                ArgMaxCustomCallOp);
-
-}  // namespace
-}  // namespace tensorflow
@@ -1,52 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/platform/dynamic_annotations.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) {
-  // Data is managed by the JIT code so msan can't tell it's initialized.
-  TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*));
-
-  float* input = static_cast<float*>(data[0]);
-  int64 input_size = *static_cast<int64*>(data[1]);
-
-  Eigen::DSizes<Eigen::DenseIndex, 1> in_eig_sizes(input_size);
-  TTypes<float, 1>::ConstTensor in_eig(input, in_eig_sizes);
-
-  Eigen::DSizes<Eigen::DenseIndex, 0> out_eig_sizes;
-  int64* out_t = static_cast<int64*>(out);
-  TTypes<int64, 0>::Tensor out_eig(out_t, out_eig_sizes);
-
-  out_eig = in_eig.argmax(0).cast<int64>();
-}
-
-}  // namespace tensorflow
-
-// Implements argmax on CPU. This is called by an XLA custom call, set up by
-// index_ops.cc.
-extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
-  tensorflow::argmax_float_1d_xla_impl(out, data);
-}
-
-XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
@@ -1,54 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/platform/dynamic_annotations.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) {
-  // data is managed by the JIT code so msan can't tell it's initialized.
-  TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 4 * sizeof(void*));
-
-  float* in = static_cast<float*>(data[0]);
-  int64* in_sizes = static_cast<int64*>(data[1]);
-  int64* out_sizes = static_cast<int64*>(data[2]);
-  int32 dim = *static_cast<int32*>(data[3]);
-
-  Eigen::DSizes<Eigen::DenseIndex, 2> in_eig_sizes(in_sizes[0], in_sizes[1]);
-  TTypes<float, 2>::ConstTensor in_eig(in, in_eig_sizes);
-
-  int64* out_t = static_cast<int64*>(out);
-  Eigen::DSizes<Eigen::DenseIndex, 1> out_eig_sizes(out_sizes[0]);
-  TTypes<int64, 1>::Tensor out_eig(out_t, out_eig_sizes);
-
-  out_eig = in_eig.argmax(dim).cast<int64>();
-}
-
-}  // namespace tensorflow
-
-// Implements argmax on CPU. This is called by an XLA custom call, set up by
-// index_ops.cc.
-extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
-  tensorflow::argmax_float_2d_xla_impl(out, data);
-}
-
-XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
@@ -33,6 +33,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
          hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
          hlo.opcode() == HloOpcode::kGather ||
          hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad ||
+         hlo.opcode() == HloOpcode::kReduce ||
          hlo.opcode() == HloOpcode::kReshape ||
          hlo.opcode() == HloOpcode::kReverse ||
          hlo.opcode() == HloOpcode::kSlice ||
@@ -151,6 +152,19 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
     }
   }
 
+  // Don't fuse reductions over the major dimensions. These have an efficient
+  // lowering that's only implemented for the unfused case.
+  if (consumer->opcode() == HloOpcode::kReduce) {
+    return absl::c_linear_search(
+        consumer->dimensions(),
+        LayoutUtil::Minor(consumer->operand(0)->shape().layout(), 0));
+  }
+  if (producer->opcode() == HloOpcode::kReduce) {
+    return absl::c_linear_search(
+        producer->dimensions(),
+        LayoutUtil::Minor(producer->operand(0)->shape().layout(), 0));
+  }
+
   if (consumer->IsLoopFusion()) {
     VLOG(2) << "Fusing: consumer is a fusion node.";
     return true;
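A hedged restatement of the guard added above; ReduceIsFusible is a hypothetical helper name used only for illustration, but the calls mirror the diff:

    // Fuse a reduce only when it touches the physically minor-most
    // (fastest-varying) dimension of its input; reductions over only the
    // major dimensions keep the efficient unfused lowering.
    bool ReduceIsFusible(const HloInstruction* reduce) {
      const int64 minor_dim =
          LayoutUtil::Minor(reduce->operand(0)->shape().layout(), 0);
      return absl::c_linear_search(reduce->dimensions(), minor_dim);
    }

For an f32[50,60]{1,0} operand the minor-most logical dimension is 1, so reduce(..., dimensions={0}) stays unfused while dimensions={0,1} is fusible; the two new tests below exercise exactly these cases.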
@@ -889,6 +889,61 @@ INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation,
                          GatherLoopFusionTest,
                          ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()),
                          GatherLoopFusionTestSpec::Name);
+
+TEST_F(InstructionFusionTest, NoFuseReduceMajor) {
+  absl::string_view module_string = R"(
+HloModule module
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[50,60]{1,0} parameter(1)
+  c = f32[50,60]{1,0} add(a, b)
+  init = f32[] constant(0)
+  ROOT r = f32[60]{0} reduce(c, init), dimensions={0}, to_apply=add
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
+                          CpuInstructionFusion().Run(module.get()));
+  EXPECT_FALSE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              Not(op::Fusion()));
+}
+
+TEST_F(InstructionFusionTest, FuseReduceMinor) {
+  absl::string_view module_string = R"(
+HloModule module
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[50,60]{1,0} parameter(1)
+  c = f32[50,60]{1,0} add(a, b)
+  init = f32[] constant(0)
+  ROOT r = f32[] reduce(c, init), dimensions={0,1}, to_apply=add
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
+                          CpuInstructionFusion().Run(module.get()));
+  EXPECT_TRUE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(), op::Fusion());
+}
+
 }  // namespace
 }  // namespace cpu
 }  // namespace xla
@@ -3125,7 +3125,7 @@ Status IrEmitter::EmitTargetElementLoop(
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
   llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
 
-  if (target_shape.IsTuple() && (target_op->IsMultiOutputFusion() ||
+  if (target_shape.IsTuple() && (target_op->opcode() == HloOpcode::kFusion ||
                                  target_op->opcode() == HloOpcode::kReduce)) {
     // For multiple outputs fusion, we need to emit each operand and the root.
     TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
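Hedged note on the widened condition above: a fused variadic reduce is a plain kFusion whose root is a tuple-shaped reduce rather than a kTuple, so IsMultiOutputFusion() did not cover it; the shapes in this sketch are illustrative, not taken from the change:

    // A tuple-shaped loop target, e.g. the variadic-reduce form of argmax:
    //   r = (f32[60], s64[60]) reduce(vals, idxs, vinit, iinit),
    //       dimensions={0}, to_apply=max_and_argmax
    // The emitter must produce one output IrArray per tuple element, exactly
    // as it already did for multi-output fusions.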
@@ -169,11 +169,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
   auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
       cshape,
       builder.AddInstruction(HloInstruction::CreateReshape(
-          ShapeUtil::MakeShape(F32, {6, 1}), concatenate)),
+          ShapeUtil::MakeShape(F32, {1, 6}), concatenate)),
       /*init_value=*/
       builder.AddInstruction(
           HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
-      /*dimensions_to_reduce=*/{1}, add_f32));
+      /*dimensions_to_reduce=*/{0}, add_f32));
 
   auto exp = builder.AddInstruction(
       HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce));