[XLA:CPU] Fuse non-vectorized reduces and remove argmax customcall
Fusing the variadic reduce version of argmax is not slower than the custom implementation on our benchmarks, so just do that instead. Saves a bit of space for everyone not using argmax. We could also fuse into the vectorized code, but that's a bit more involved and I'm not sure yet whether I see enough speedup to warrant that extra complexity.

PiperOrigin-RevId: 254187396
parent 9385325594
commit af4eb9c864
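Some context on the replacement path: the custom-call op deleted below already fell back to the HLO ArgMax builder (xla::ArgMax from xla/client/lib/arithmetic.h) whenever the custom call was not allowed, and after this change that lowering — which, per the description above, is built on a variadic reduce — is used unconditionally. A minimal sketch of that builder path, with illustrative shapes and function names that are not part of this commit:

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: build ArgMax over axis 1 of an f32[50,60] operand purely in HLO,
// instead of dispatching to a CPU custom call. xla::ArgMax emits the
// reduce-based formulation that the CPU backend can now fuse when the reduce
// is not vectorized.
xla::XlaOp BuildRowArgMax(xla::XlaBuilder* b) {
  xla::XlaOp input = xla::Parameter(
      b, /*parameter_number=*/0,
      xla::ShapeUtil::MakeShape(xla::F32, {50, 60}), "input");
  // Indices are produced as S64, matching the output type of the deleted
  // custom-call kernels.
  return xla::ArgMax(input, xla::S64, /*axis=*/1);
}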
@@ -34,7 +34,6 @@ cc_library(
         "//tensorflow/compiler/tf2xla:tf2xla_proto",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla:cpu_function_runtime",
@@ -286,8 +286,6 @@ def tf_library(
         ] or []) + (include_standard_runtime_deps and [
             # TODO(cwhipkey): only depend on kernel code that the model actually
             # needed.
-            "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
-            "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
             "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
             "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
             "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
@@ -51,7 +51,6 @@ cc_library(
     deps = [
         ":jit_compilation_passes",
         "//tensorflow/compiler/jit/kernels:xla_ops",
-        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:cpu_plugin",
@@ -40,7 +40,6 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":xla_compiler",
-        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -108,7 +107,6 @@ cc_library(
         ":tf2xla_proto",
         ":tf2xla_util",
         ":xla_compiler",
-        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/client",
         "//tensorflow/compiler/xla/client:xla_computation",
@@ -345,50 +345,3 @@ tf_kernel_library(
     ],
     alwayslink = 1,
 )
-
-# Kernels that only work on CPU, because they use XLA custom calls.
-# Only link this when using the CPU backend for XLA.
-tf_kernel_library(
-    name = "xla_cpu_only_ops",
-    srcs = ["index_ops_cpu.cc"],
-    deps = [
-        ":index_ops_kernel_argmax_float_1d",
-        ":index_ops_kernel_argmax_float_2d",
-        "//tensorflow/compiler/tf2xla:common",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/xla:literal",
-        "//tensorflow/compiler/xla:literal_util",
-        "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/client/lib:arithmetic",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:framework_bounds_check",
-        "//tensorflow/core:lib",
-    ],
-)
-
-cc_library(
-    name = "index_ops_kernel_argmax_float_1d",
-    srcs = ["index_ops_kernel_argmax_float_1d.cc"],
-    copts = tf_copts(),
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/compiler/xla/service:custom_call_target_registry",
-        "//tensorflow/core:framework_lite",
-        "//third_party/eigen3",
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
-    name = "index_ops_kernel_argmax_float_2d",
-    srcs = ["index_ops_kernel_argmax_float_2d.cc"],
-    copts = tf_copts(),
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/compiler/xla/service:custom_call_target_registry",
-        "//tensorflow/core:framework_lite",
-        "//third_party/eigen3",
-    ],
-    alwayslink = 1,
-)
@@ -1,142 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Native XLA implementations of indexing ops.
-
-#include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/core/framework/bounds_check.h"
-#include "tensorflow/core/framework/kernel_def_builder.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-
-namespace tensorflow {
-namespace {
-
-// The logic below uses a custom-call to implement argmax when possible. When
-// custom-call is not allowed or input shapes are not supported, this kernel
-// falls back to using XLA HLO native ArgMax.
-//
-// Also see b/29507024 for first-class XLA support for indexing ops.
-class ArgMaxCustomCallOp : public XlaOpKernel {
- public:
-  explicit ArgMaxCustomCallOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
-
-  void Compile(XlaOpKernelContext* ctx) override {
-    const TensorShape input_shape = ctx->InputShape(0);
-    const TensorShape dimension_shape = ctx->InputShape(1);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape),
-                errors::InvalidArgument(
-                    "dim must be a scalar, but received tensor of shape: ",
-                    dimension_shape.DebugString()));
-
-    // We require that the dimension argument is a constant, since it lets us
-    // dispatch to a specialized custom-call function without any run-time
-    // overhead, when compiling ahead-of-time.
-    int64 dim;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim));
-
-    const int input_dims = input_shape.dims();
-    const int axis = dim < 0 ? dim + input_dims : dim;
-    OP_REQUIRES(ctx, axis >= 0 && axis < input_dims,
-                errors::InvalidArgument("Expected dimension in the range [",
-                                        -input_dims, ", ", input_dims,
-                                        "), but got ", dim));
-
-    const int64 axis_size = input_shape.dim_size(axis);
-    OP_REQUIRES(ctx, axis_size > 0,
-                errors::InvalidArgument(
-                    "Reduction axis ", dim,
-                    " is empty in shape: ", input_shape.DebugString()));
-
-    const DataType dtype = output_type(0);
-    xla::PrimitiveType output_type;
-    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type));
-
-    // Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input
-    // shape isn't supported.
-    if (!ctx->compiler()->options().allow_cpu_custom_calls ||
-        (input_dims != 1 && input_dims != 2)) {
-      xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis);
-      ctx->SetOutput(0, output);
-      return;
-    }
-
-    xla::XlaOp output;
-    // The output shape is the input shape contracted along axis.
-    TensorShape output_shape;
-    for (int d = 0; d < input_shape.dims() - 1; ++d) {
-      output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
-    }
-
-    xla::XlaBuilder& b = *ctx->builder();
-
-    // XLA passes <out> to the function, so it is not included here.
-    std::vector<xla::XlaOp> args;
-    args.push_back(ctx->Input(0));
-    args.push_back(xla::ConstantLiteral(
-        &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
-    if (input_shape.dims() > 1) {
-      // Don't bother passing the output shape and dim for the 1d case, since
-      // the shape is always a scalar and the dim is always 0.
-      args.push_back(xla::ConstantLiteral(
-          &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
-      args.push_back(
-          xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(axis)));
-    }
-
-    // The argmax function expects row-major layout.
-    xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
-        xla::S64, output_shape.dim_sizes());
-    std::vector<xla::Shape> arg_shapes;
-    for (const xla::XlaOp& arg : args) {
-      auto shape_status = b.GetShape(arg);
-      OP_REQUIRES_OK(ctx, shape_status.status());
-      xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
-      *arg_shape.mutable_layout() =
-          xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
-      arg_shapes.push_back(std::move(arg_shape));
-    }
-
-    // Tell XLA to call the custom code, defined in
-    // index_ops_kernel_argmax_float_{1, 2}d.cc.
-    if (input_dims == 1) {
-      output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
-                                         xla_shape, arg_shapes);
-    } else {
-      output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
-                                         xla_shape, arg_shapes);
-    }
-    output = xla::ConvertElementType(output, output_type);
-    ctx->SetOutput(0, output);
-  }
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp);
-};
-
-REGISTER_XLA_OP(Name("ArgMax")
-                    .TypeConstraint("T", DT_FLOAT)
-                    .Device(DEVICE_CPU_XLA_JIT)
-                    .CompileTimeConstantInput("dimension"),
-                ArgMaxCustomCallOp);
-
-}  // namespace
-}  // namespace tensorflow
@@ -1,52 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/platform/dynamic_annotations.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) {
-  // Data is managed by the JIT code so msan can't tell it's initialized.
-  TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*));
-
-  float* input = static_cast<float*>(data[0]);
-  int64 input_size = *static_cast<int64*>(data[1]);
-
-  Eigen::DSizes<Eigen::DenseIndex, 1> in_eig_sizes(input_size);
-  TTypes<float, 1>::ConstTensor in_eig(input, in_eig_sizes);
-
-  Eigen::DSizes<Eigen::DenseIndex, 0> out_eig_sizes;
-  int64* out_t = static_cast<int64*>(out);
-  TTypes<int64, 0>::Tensor out_eig(out_t, out_eig_sizes);
-
-  out_eig = in_eig.argmax(0).cast<int64>();
-}
-
-}  // namespace tensorflow
-
-// Implements argmax on CPU. This is called by an XLA custom call, set up by
-// index_ops.cc.
-extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
-  tensorflow::argmax_float_1d_xla_impl(out, data);
-}
-
-XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
@@ -1,54 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/platform/dynamic_annotations.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) {
-  // data is managed by the JIT code so msan can't tell it's initialized.
-  TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 4 * sizeof(void*));
-
-  float* in = static_cast<float*>(data[0]);
-  int64* in_sizes = static_cast<int64*>(data[1]);
-  int64* out_sizes = static_cast<int64*>(data[2]);
-  int32 dim = *static_cast<int32*>(data[3]);
-
-  Eigen::DSizes<Eigen::DenseIndex, 2> in_eig_sizes(in_sizes[0], in_sizes[1]);
-  TTypes<float, 2>::ConstTensor in_eig(in, in_eig_sizes);
-
-  int64* out_t = static_cast<int64*>(out);
-  Eigen::DSizes<Eigen::DenseIndex, 1> out_eig_sizes(out_sizes[0]);
-  TTypes<int64, 1>::Tensor out_eig(out_t, out_eig_sizes);
-
-  out_eig = in_eig.argmax(dim).cast<int64>();
-}
-
-}  // namespace tensorflow
-
-// Implements argmax on CPU. This is called by an XLA custom call, set up by
-// index_ops.cc.
-extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
-  tensorflow::argmax_float_2d_xla_impl(out, data);
-}
-
-XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
@@ -33,6 +33,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
          hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
          hlo.opcode() == HloOpcode::kGather ||
          hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad ||
+         hlo.opcode() == HloOpcode::kReduce ||
          hlo.opcode() == HloOpcode::kReshape ||
          hlo.opcode() == HloOpcode::kReverse ||
          hlo.opcode() == HloOpcode::kSlice ||
@@ -151,6 +152,19 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
     }
   }
 
+  // Don't fuse reductions over the major dimensions. These have an efficient
+  // lowering that's only implemented for the unfused case.
+  if (consumer->opcode() == HloOpcode::kReduce) {
+    return absl::c_linear_search(
+        consumer->dimensions(),
+        LayoutUtil::Minor(consumer->operand(0)->shape().layout(), 0));
+  }
+  if (producer->opcode() == HloOpcode::kReduce) {
+    return absl::c_linear_search(
+        producer->dimensions(),
+        LayoutUtil::Minor(producer->operand(0)->shape().layout(), 0));
+  }
+
   if (consumer->IsLoopFusion()) {
     VLOG(2) << "Fusing: consumer is a fusion node.";
     return true;
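To make the new heuristic concrete: a reduce is treated as fusible exactly when its reduced dimensions include the minor-most dimension of its operand's layout, i.e. the case that does not get the vectorized unfused lowering. This is what the NoFuseReduceMajor and FuseReduceMinor tests below exercise. A small illustration of the same check outside of ShouldFuse (hypothetical helper, not part of the change):

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// Mirrors the check added above: for an f32[50,60]{1,0} operand,
// LayoutUtil::Minor(layout, 0) is dimension 1, so reduce(dimensions={0})
// keeps its efficient unfused (vectorized) lowering, while
// reduce(dimensions={1}) or reduce(dimensions={0,1}) touches the minor-most
// dimension and becomes a candidate for loop fusion.
bool ReduceTouchesMinorDimension(const xla::HloInstruction* reduce) {
  return absl::c_linear_search(
      reduce->dimensions(),
      xla::LayoutUtil::Minor(reduce->operand(0)->shape().layout(), 0));
}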
@@ -889,6 +889,61 @@ INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation,
                          GatherLoopFusionTest,
                          ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()),
                          GatherLoopFusionTestSpec::Name);
+
+TEST_F(InstructionFusionTest, NoFuseReduceMajor) {
+  absl::string_view module_string = R"(
+HloModule module
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[50,60]{1,0} parameter(1)
+  c = f32[50,60]{1,0} add(a, b)
+  init = f32[] constant(0)
+  ROOT r = f32[60]{0} reduce(c, init), dimensions={0}, to_apply=add
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
+                          CpuInstructionFusion().Run(module.get()));
+  EXPECT_FALSE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              Not(op::Fusion()));
+}
+
+TEST_F(InstructionFusionTest, FuseReduceMinor) {
+  absl::string_view module_string = R"(
+HloModule module
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[50,60]{1,0} parameter(1)
+  c = f32[50,60]{1,0} add(a, b)
+  init = f32[] constant(0)
+  ROOT r = f32[] reduce(c, init), dimensions={0,1}, to_apply=add
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
+                          CpuInstructionFusion().Run(module.get()));
+  EXPECT_TRUE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(), op::Fusion());
+}
 }  // namespace
 }  // namespace cpu
 }  // namespace xla
@@ -3125,7 +3125,7 @@ Status IrEmitter::EmitTargetElementLoop(
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
   llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
 
-  if (target_shape.IsTuple() && (target_op->IsMultiOutputFusion() ||
+  if (target_shape.IsTuple() && (target_op->opcode() == HloOpcode::kFusion ||
                                  target_op->opcode() == HloOpcode::kReduce)) {
     // For multiple outputs fusion, we need to emit each operand and the root.
     TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
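The condition change above matters because a variadic reduce, and a loop fusion rooted at one, produces a tuple-shaped result, while HloInstruction::IsMultiOutputFusion() only reports fusions whose fused root is an explicit tuple. A hypothetical predicate spelling out the widened condition (illustrative only, not code from this commit):

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

// True when the emitter must allocate one buffer per tuple element and then
// emit the tuple itself, as EmitTargetElementLoop does in the branch above.
// A fusion whose root is a variadic reduce is tuple-shaped even though its
// fused root is not a kTuple, so the old IsMultiOutputFusion() check would
// have skipped it.
bool NeedsPerElementTargetBuffers(const xla::HloInstruction* op) {
  return op->shape().IsTuple() &&
         (op->opcode() == xla::HloOpcode::kFusion ||
          op->opcode() == xla::HloOpcode::kReduce);
}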
@@ -169,11 +169,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
   auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
       cshape,
       builder.AddInstruction(HloInstruction::CreateReshape(
-          ShapeUtil::MakeShape(F32, {6, 1}), concatenate)),
+          ShapeUtil::MakeShape(F32, {1, 6}), concatenate)),
       /*init_value=*/
       builder.AddInstruction(
           HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
-      /*dimensions_to_reduce=*/{1}, add_f32));
+      /*dimensions_to_reduce=*/{0}, add_f32));
 
   auto exp = builder.AddInstruction(
       HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce));