From af4eb9c864563a98cd12c2f731b06b722f17141d Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Thu, 20 Jun 2019 06:51:42 -0700 Subject: [PATCH] [XLA:CPU] Fuse non-vectorized reduces and remove argmax customcall Fusing the variadic reduce version of argmax is not slower than the custom implementation on our benchmarks, so just do that instead. Saves a bit of space for everyone not using argmax. We could also fuse into the vectorized code, but that's a bit more involved and I'm not sure yet whether I see enough speedup to warrant that extra complexity. PiperOrigin-RevId: 254187396 --- tensorflow/compiler/aot/BUILD | 1 - tensorflow/compiler/aot/tfcompile.bzl | 2 - tensorflow/compiler/jit/BUILD | 1 - tensorflow/compiler/tf2xla/BUILD | 2 - tensorflow/compiler/tf2xla/kernels/BUILD | 47 ------ .../compiler/tf2xla/kernels/index_ops_cpu.cc | 142 ------------------ .../index_ops_kernel_argmax_float_1d.cc | 52 ------- .../index_ops_kernel_argmax_float_2d.cc | 54 ------- .../xla/service/cpu/cpu_instruction_fusion.cc | 14 ++ .../cpu/cpu_instruction_fusion_test.cc | 55 +++++++ .../compiler/xla/service/cpu/ir_emitter.cc | 2 +- .../xla/service/cpu/tests/cpu_fusion_test.cc | 4 +- 12 files changed, 72 insertions(+), 304 deletions(-) delete mode 100644 tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc delete mode 100644 tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc delete mode 100644 tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 22c360fe765..f871115a131 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -34,7 +34,6 @@ cc_library( "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:cpu_function_runtime", diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index e7f3c0aebdd..d9f871dc2e5 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -286,8 +286,6 @@ def tf_library( ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually # needed. - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort", "//tensorflow/compiler/xla/service/cpu:runtime_matmul", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index bed17835cf4..9db30b2b9bf 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -51,7 +51,6 @@ cc_library( deps = [ ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", - "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 82a80daf4f0..fb256a6ca70 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -40,7 +40,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -108,7 +107,6 @@ cc_library( ":tf2xla_proto", ":tf2xla_util", ":xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:xla_computation", diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 1b30fba0645..9818fbefd14 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -345,50 +345,3 @@ tf_kernel_library( ], alwayslink = 1, ) - -# Kernels that only work on CPU, because they use XLA custom calls. -# Only link this when using the CPU backend for XLA. -tf_kernel_library( - name = "xla_cpu_only_ops", - srcs = ["index_ops_cpu.cc"], - deps = [ - ":index_ops_kernel_argmax_float_1d", - ":index_ops_kernel_argmax_float_2d", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/core:framework", - "//tensorflow/core:framework_bounds_check", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "index_ops_kernel_argmax_float_1d", - srcs = ["index_ops_kernel_argmax_float_1d.cc"], - copts = tf_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/xla/service:custom_call_target_registry", - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], - alwayslink = 1, -) - -cc_library( - name = "index_ops_kernel_argmax_float_2d", - srcs = ["index_ops_kernel_argmax_float_2d.cc"], - copts = tf_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/xla/service:custom_call_target_registry", - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc deleted file mode 100644 index e4bbdef6480..00000000000 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Native XLA implementations of indexing ops. - -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" - -namespace tensorflow { -namespace { - -// The logic below uses a custom-call to implement argmax when possible. When -// custom-call is not allowed or input shapes are not supported, this kernel -// falls back to using XLA HLO native ArgMax. -// -// Also see b/29507024 for first-class XLA support for indexing ops. -class ArgMaxCustomCallOp : public XlaOpKernel { - public: - explicit ArgMaxCustomCallOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - - void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape dimension_shape = ctx->InputShape(1); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape), - errors::InvalidArgument( - "dim must be a scalar, but received tensor of shape: ", - dimension_shape.DebugString())); - - // We require that the dimension argument is a constant, since it lets us - // dispatch to a specialized custom-call function without any run-time - // overhead, when compiling ahead-of-time. - int64 dim; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim)); - - const int input_dims = input_shape.dims(); - const int axis = dim < 0 ? dim + input_dims : dim; - OP_REQUIRES(ctx, axis >= 0 && axis < input_dims, - errors::InvalidArgument("Expected dimension in the range [", - -input_dims, ", ", input_dims, - "), but got ", dim)); - - const int64 axis_size = input_shape.dim_size(axis); - OP_REQUIRES(ctx, axis_size > 0, - errors::InvalidArgument( - "Reduction axis ", dim, - " is empty in shape: ", input_shape.DebugString())); - - const DataType dtype = output_type(0); - xla::PrimitiveType output_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type)); - - // Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input - // shape isn't supported. - if (!ctx->compiler()->options().allow_cpu_custom_calls || - (input_dims != 1 && input_dims != 2)) { - xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis); - ctx->SetOutput(0, output); - return; - } - - xla::XlaOp output; - // The output shape is the input shape contracted along axis. - TensorShape output_shape; - for (int d = 0; d < input_shape.dims() - 1; ++d) { - output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1)); - } - - xla::XlaBuilder& b = *ctx->builder(); - - // XLA passes to the function, so it is not included here. - std::vector args; - args.push_back(ctx->Input(0)); - args.push_back(xla::ConstantLiteral( - &b, xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); - if (input_shape.dims() > 1) { - // Don't bother passing the output shape and dim for the 1d case, since - // the shape is always a scalar and the dim is always 0. - args.push_back(xla::ConstantLiteral( - &b, xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); - args.push_back( - xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(axis))); - } - - // The argmax function expects row-major layout. - xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( - xla::S64, output_shape.dim_sizes()); - std::vector arg_shapes; - for (const xla::XlaOp& arg : args) { - auto shape_status = b.GetShape(arg); - OP_REQUIRES_OK(ctx, shape_status.status()); - xla::Shape arg_shape = shape_status.ConsumeValueOrDie(); - *arg_shape.mutable_layout() = - xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank()); - arg_shapes.push_back(std::move(arg_shape)); - } - - // Tell XLA to call the custom code, defined in - // index_ops_kernel_argmax_float_{1, 2}d.cc. - if (input_dims == 1) { - output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args, - xla_shape, arg_shapes); - } else { - output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args, - xla_shape, arg_shapes); - } - output = xla::ConvertElementType(output, output_type); - ctx->SetOutput(0, output); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp); -}; - -REGISTER_XLA_OP(Name("ArgMax") - .TypeConstraint("T", DT_FLOAT) - .Device(DEVICE_CPU_XLA_JIT) - .CompileTimeConstantInput("dimension"), - ArgMaxCustomCallOp); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc deleted file mode 100644 index 379c8d4ec0c..00000000000 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/platform/dynamic_annotations.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { - // Data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*)); - - float* input = static_cast(data[0]); - int64 input_size = *static_cast(data[1]); - - Eigen::DSizes in_eig_sizes(input_size); - TTypes::ConstTensor in_eig(input, in_eig_sizes); - - Eigen::DSizes out_eig_sizes; - int64* out_t = static_cast(out); - TTypes::Tensor out_eig(out_t, out_eig_sizes); - - out_eig = in_eig.argmax(0).cast(); -} - -} // namespace tensorflow - -// Implements argmax on CPU. This is called by an XLA custom call, set up by -// index_ops.cc. -extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { - tensorflow::argmax_float_1d_xla_impl(out, data); -} - -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc deleted file mode 100644 index 6e1c1226321..00000000000 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/platform/dynamic_annotations.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 4 * sizeof(void*)); - - float* in = static_cast(data[0]); - int64* in_sizes = static_cast(data[1]); - int64* out_sizes = static_cast(data[2]); - int32 dim = *static_cast(data[3]); - - Eigen::DSizes in_eig_sizes(in_sizes[0], in_sizes[1]); - TTypes::ConstTensor in_eig(in, in_eig_sizes); - - int64* out_t = static_cast(out); - Eigen::DSizes out_eig_sizes(out_sizes[0]); - TTypes::Tensor out_eig(out_t, out_eig_sizes); - - out_eig = in_eig.argmax(dim).cast(); -} - -} // namespace tensorflow - -// Implements argmax on CPU. This is called by an XLA custom call, set up by -// index_ops.cc. -extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { - tensorflow::argmax_float_2d_xla_impl(out, data); -} - -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index a37b5e9e51b..6620a9620b5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -33,6 +33,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kGather || hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || + hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kReverse || hlo.opcode() == HloOpcode::kSlice || @@ -151,6 +152,19 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } } + // Don't fuse reductions over the major dimensions. These have an efficient + // lowering that's only implemented for the unfused case. + if (consumer->opcode() == HloOpcode::kReduce) { + return absl::c_linear_search( + consumer->dimensions(), + LayoutUtil::Minor(consumer->operand(0)->shape().layout(), 0)); + } + if (producer->opcode() == HloOpcode::kReduce) { + return absl::c_linear_search( + producer->dimensions(), + LayoutUtil::Minor(producer->operand(0)->shape().layout(), 0)); + } + if (consumer->IsLoopFusion()) { VLOG(2) << "Fusing: consumer is a fusion node."; return true; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index e1f5ab11456..2492a1db12a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -889,6 +889,61 @@ INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation, GatherLoopFusionTest, ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), GatherLoopFusionTestSpec::Name); + +TEST_F(InstructionFusionTest, NoFuseReduceMajor) { + absl::string_view module_string = R"( +HloModule module + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY main { + a = f32[50,60]{1,0} parameter(0) + b = f32[50,60]{1,0} parameter(1) + c = f32[50,60]{1,0} add(a, b) + init = f32[] constant(0) + ROOT r = f32[60]{0} reduce(c, init), dimensions={0}, to_apply=add +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + +TEST_F(InstructionFusionTest, FuseReduceMinor) { + absl::string_view module_string = R"( +HloModule module + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY main { + a = f32[50,60]{1,0} parameter(0) + b = f32[50,60]{1,0} parameter(1) + c = f32[50,60]{1,0} add(a, b) + init = f32[] constant(0) + ROOT r = f32[] reduce(c, init), dimensions={0,1}, to_apply=add +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_TRUE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), op::Fusion()); +} } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 06ea62d552c..cfe76fee3e5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -3125,7 +3125,7 @@ Status IrEmitter::EmitTargetElementLoop( TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op)); llvm_ir::IrArray target_array = GetIrArrayFor(target_op); - if (target_shape.IsTuple() && (target_op->IsMultiOutputFusion() || + if (target_shape.IsTuple() && (target_op->opcode() == HloOpcode::kFusion || target_op->opcode() == HloOpcode::kReduce)) { // For multiple outputs fusion, we need to emit each operand and the root. TF_RET_CHECK(num_dynamic_loop_bounds_ == 0); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index a72ebe2beea..9eb14543a87 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -169,11 +169,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( cshape, builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {6, 1}), concatenate)), + ShapeUtil::MakeShape(F32, {1, 6}), concatenate)), /*init_value=*/ builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), - /*dimensions_to_reduce=*/{1}, add_f32)); + /*dimensions_to_reduce=*/{0}, add_f32)); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce));