Variadic reduce implementation on CPU
Implements variadic reduce on the CPU backend. Before this change, thread-local functions could only return scalars. The biggest part of this change is allowing thread-local functions to return tuples of scalars, which required changes to function generation, to the allocation of space for the returned values on the caller's side, and to the generation of the function epilogue.

PiperOrigin-RevId: 239022395
parent 1a690e6c75
commit 803b0f5151
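For context, a variadic reduce applies a reducer computation that takes one running accumulator plus one input element per operand and returns a tuple of scalars, one per operand. A minimal HLO sketch of such a reduce (copied from the simpler of the tests added at the end of this change) looks like:

add {
  op1 = f32[] parameter(0)
  op2 = f32[] parameter(1)
  acc1 = f32[] parameter(2)
  acc2 = f32[] parameter(3)
  out1 = f32[] add(acc1, op1)
  out2 = f32[] add(acc2, op2)
  ROOT result = (f32[], f32[]) tuple(out1, out2)
}

ENTRY main {
  inp1 = f32[100] parameter(0)
  inp2 = f32[100] parameter(1)
  zero = f32[] constant(0)
  ROOT out = (f32[], f32[]) reduce(inp1, inp2, zero, zero), dimensions={0}, to_apply=add
}

The reducer `add` is exactly the kind of thread-local callee this change targets: it returns a tuple of scalars, so the caller now has to reserve stack space for each returned element and the callee has to emit an epilogue that copies every element of its root tuple into that space.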
@@ -20,6 +20,7 @@ limitations under the License.
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -135,10 +136,19 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
       };
     case HloOpcode::kReduce:
       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+        auto reduce_instr = Cast<HloReduceInstruction>(hlo);
+        std::vector<llvm_ir::ElementGenerator> input_generators;
+        for (const HloInstruction* instr : reduce_instr->inputs()) {
+          input_generators.push_back(operand_to_generator.at(instr));
+        }
+
+        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
+        for (const HloInstruction* instr : reduce_instr->init_values()) {
+          initial_value_generators.push_back(operand_to_generator.at(instr));
+        }
         return ir_emitter_->EmitElementalReduce(
-            Cast<HloReduceInstruction>(hlo),
-            operand_to_generator.at(hlo->operand(0)),
-            operand_to_generator.at(hlo->operand(1)), index);
+            reduce_instr, std::move(input_generators),
+            std::move(initial_value_generators), index);
       };
     default:
       return ElementalIrEmitter::MakeElementGenerator(hlo,
@@ -58,8 +58,10 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -106,6 +108,42 @@ IrEmitter::IrEmitter(
   TF_CHECK_OK(s) << "Should have failed buffer assignment.";
 }
 
+void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) {
+  llvm::Argument* out_parameter = compute_function_->result_arg();
+  llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction());
+  const Shape& return_shape = computation->root_instruction()->shape();
+
+  if (ShapeUtil::IsScalar(return_shape)) {
+    llvm::Value* ret_value =
+        Load(root_value.GetBasePointer(), "load_ret_value");
+    Store(ret_value,
+          BitCast(out_parameter, root_value.GetBasePointer()->getType()));
+  } else {
+    CHECK(return_shape.IsTuple());
+
+    llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
+    llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo();
+    llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue);
+
+    for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
+      const Shape& element_shape = return_shape.tuple_shapes(i);
+      llvm::Value* destination = llvm_ir::EmitGetTupleElement(
+          element_shape,
+          /*index=*/i,
+          /*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue,
+          &b_);
+
+      llvm::Value* source = llvm_ir::EmitGetTupleElement(
+          element_shape,
+          /*index=*/i,
+          /*alignment=*/MinimumAlignmentForShape(element_shape),
+          root_value.GetBasePointer(), &b_);
+
+      Store(Load(source), destination);
+    }
+  }
+}
+
 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
     HloComputation* computation, const string& function_name_prefix,
     bool is_top_level_computation,
@@ -143,6 +181,16 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
   InsertOrDie(&emitted_functions_, computation, ir_function);
   // Delete 'compute_function', finalizing 'ir_function' and restoring caller
   // IR insert point.
+
+  // Function epilogue: copying the value over to either the return register,
+  // or values pointing from the return register.
+  const BufferAllocation* root_allocation =
+      computation_root_allocation_.allocation();
+  if (root_allocation && root_allocation->is_thread_local()) {
+    EmitThreadLocalFunctionEpilogue(computation);
+  }
+
+  // Destructor for compute_function_ emits the "ret void" instruction.
   compute_function_.reset();
   computation_root_allocation_ = BufferAllocation::Slice();
   computation_parameter_allocations_.clear();
@@ -634,7 +682,8 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
 llvm::Value* IrEmitter::EmitElementalMap(
     const HloMapInstruction& map_instr,
     absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
-  return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
+  return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(),
+                                            elemental_operands, name);
 }
 
 StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
@@ -716,7 +765,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
                                        b_.getInt64Ty());
     TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
                         input_generator(input_index));
-    llvm::Value* result = EmitThreadLocalCall(
+    llvm::Value* result = EmitScalarReturningThreadLocalCall(
         *reduce_window->to_apply(), {Load(accumulator_address), input_value},
         "reducer_function");
     Store(result, accumulator_address);
@@ -868,7 +917,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
     llvm::Value* operand_address =
         operand_array.EmitArrayElementAddress(operand_index, &b_);
     llvm::Value* operand_element = Load(operand_address);
-    llvm::Value* result = EmitThreadLocalCall(
+    llvm::Value* result = EmitScalarReturningThreadLocalCall(
        *select_and_scatter->select(),
        {Load(selected_value_address), operand_element}, "select_function");
 
@@ -903,9 +952,9 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
       selected_multi_index, output_array.GetShape(), source_index.GetType());
   llvm::Value* output_value =
       output_array.EmitReadArrayElement(selected_index, &b_);
-  llvm::Value* scatter_value =
-      EmitThreadLocalCall(*select_and_scatter->scatter(),
-                          {output_value, source_value}, "scatter_function");
+  llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
+      *select_and_scatter->scatter(), {output_value, source_value},
+      "scatter_function");
   output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
 
   SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -1665,6 +1714,11 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
     HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
     absl::Span<const int64> dimensions, HloComputation* function,
     string* failure_reason) {
+  if (!reduce->shape().IsArray()) {
+    *failure_reason = "vectorization of variadic reduce not implemented";
+    return false;
+  }
+
   if (!ReductionPreservesLayout(*reduce)) {
     return false;
   }
@@ -1813,21 +1867,39 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
 
 StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
     const HloReduceInstruction* reduce,
-    const llvm_ir::ElementGenerator& input_generator,
-    const llvm_ir::ElementGenerator& initial_value_generator,
+    std::vector<llvm_ir::ElementGenerator> input_generators,
+    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
     const llvm_ir::IrArray::Index& index) {
-  const HloInstruction* arg = reduce->operand(0);
-  absl::Span<const int64> dimensions(reduce->dimensions());
-
-  // Initialize an accumulator with init_value.
-  PrimitiveType accumulator_type = reduce->shape().element_type();
-  llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
-      llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
-      &b_, MinimumAlignmentForPrimitiveType(accumulator_type));
-  TF_ASSIGN_OR_RETURN(
-      llvm::Value* const init_value,
-      initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
-  Store(init_value, accumulator_addr);
+  const Shape& out_shape = reduce->shape();
+  bool is_variadic = !out_shape.IsArray();
+  int accumulators_count = 1;
+  if (is_variadic) {
+    CHECK(out_shape.IsTuple());
+    accumulators_count = out_shape.tuple_shapes_size();
+  }
+
+  absl::Span<const int64> reduced_dimensions(reduce->dimensions());
+
+  std::vector<llvm::Value*> accumulator_addrs;
+  std::vector<llvm::Type*> accumulator_types;
+  for (int i = 0; i < accumulators_count; i++) {
+    const Shape& element_shape =
+        is_variadic ? out_shape.tuple_shapes(i) : out_shape;
+    PrimitiveType accumulator_type = element_shape.element_type();
+    llvm::Type* accumulator_llvm_type =
+        llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
+    accumulator_types.push_back(accumulator_llvm_type);
+
+    // Initialize an accumulator with init_value.
+    llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+        accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_,
+        MinimumAlignmentForPrimitiveType(accumulator_type));
+    TF_ASSIGN_OR_RETURN(
+        llvm::Value* const init_value,
+        initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType())));
+    Store(init_value, accumulator_addr);
+    accumulator_addrs.push_back(accumulator_addr);
+  }
 
   // The enclosing loops go over all the target elements. Now we have to compute
   // the actual target element. For this, we build a new loop nest to iterate
@@ -1835,14 +1907,15 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
   // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
   // are placed for each dimension in dimensions, and all the rest are nullptrs.
   llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
+  const HloInstruction* arg = reduce->operand(0);
   std::vector<llvm::Value*> input_multi_index =
-      loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
+      loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
                                          "reduction_dim");
 
   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
 
-  // Build a full index for the input argument, using reduced_dims_index as the
-  // base. In reduced_dims_index only the reduction dimensions are filled in. We
+  // Build a full index for the input argument, using input_multi_index as the
+  // base. In input_multi_index only the reduction dimensions are filled in. We
   // fill in the rest of the dimensions with induction Value*s taken from
   // 'index' which iterates over the target array. See the high-level
   // description in the XLA documentation for details.
@@ -1857,23 +1930,44 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
   llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
                                       b_.getInt64Ty());
 
-  // Apply the reduction function to the loaded value.
-  TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
-                      input_generator(input_index));
-  llvm::Value* result = EmitThreadLocalCall(
-      *reduce->to_apply(), {Load(accumulator_addr), input_element},
-      "reduce_function");
-  Store(result, accumulator_addr);
+  std::vector<llvm::Value*> reduction_operands;
+  for (llvm::Value* accum : accumulator_addrs) {
+    llvm::Value* accum_value = Load(accum);
+    reduction_operands.push_back(accum_value);
+  }
+
+  for (int i = 0; i < accumulators_count; i++) {
+    TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
+                        input_generators[i](input_index));
+    reduction_operands.push_back(input_element);
+  }
+
+  std::vector<llvm::Value*> results = EmitThreadLocalCall(
+      *reduce->to_apply(), reduction_operands, "reduce_function");
+
+  CHECK(results.size() == accumulators_count);
+  for (int i = 0; i < accumulators_count; i++) {
+    Store(results[i], accumulator_addrs[i]);
+  }
 
   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-  return Load(accumulator_addr);
+
+  if (is_variadic) {
+    // Emit a structure, as that what the LoopEmitter expects.
+    llvm::Value* returned_structure = llvm::UndefValue::get(
+        llvm::StructType::get(b_.getContext(), accumulator_types));
+    for (int i = 0; i < accumulators_count; i++) {
+      llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
+      returned_structure =
+          b_.CreateInsertValue(returned_structure, accumulator_value, i);
+    }
+    return returned_structure;
+  } else {
+    CHECK_EQ(accumulator_addrs.size(), 1);
+    return Load(accumulator_addrs[0]);
+  }
 }
 
 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
-  // TODO(b/118333695): Support variadic reduce.
-  if (!reduce->shape().IsArray()) {
-    return Unimplemented("Variadic reduce is not supported on CPU");
-  }
   auto arg = reduce->mutable_operand(0);
   auto init_value = reduce->mutable_operand(1);
   absl::Span<const int64> dimensions(reduce->dimensions());
@@ -2848,15 +2942,6 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
     const BufferAllocation::Slice& slice, const Shape& target_shape) {
   const BufferAllocation& allocation = *slice.allocation();
   llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
-    if (slice == computation_root_allocation_) {
-      llvm::Argument* retval = compute_function_->result_arg();
-      llvm::AttrBuilder attr_builder;
-      attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
-      attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
-      retval->addAttrs(attr_builder);
-      return retval;
-    }
-
     auto param_it =
         computation_parameter_allocations_.find(slice.allocation()->index());
     if (param_it != computation_parameter_allocations_.end()) {
@@ -2966,7 +3051,8 @@ Status IrEmitter::EmitTargetElementLoop(
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
   llvm_ir::IrArray target_array = GetIrArrayFor(target_op);
 
-  if (target_op->IsMultiOutputFusion()) {
+  if (target_shape.IsTuple() && (target_op->IsMultiOutputFusion() ||
+                                 target_op->opcode() == HloOpcode::kReduce)) {
     // For multiple outputs fusion, we need to emit each operand and the root.
     TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
     std::vector<llvm_ir::IrArray> output_arrays;
@@ -3048,19 +3134,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
       hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
 }
 
-llvm::Value* IrEmitter::EmitThreadLocalCall(
+llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
+    const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+    absl::string_view name) {
+  std::vector<llvm::Value*> return_value =
+      EmitThreadLocalCall(callee, parameters, name);
+  CHECK_EQ(return_value.size(), 1);
+  return return_value[0];
+}
+
+std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
     const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
     absl::string_view name) {
   CHECK(absl::c_binary_search(thread_local_computations_, &callee));
 
   const Shape& return_shape = callee.root_instruction()->shape();
-  // Lifting this restriction to allow "small" arrays should be easy. Allowing
-  // larger arrays is difficult because we allocate the buffer for this return
-  // value on the stack.
-  CHECK(ShapeUtil::IsScalar(return_shape));
-
-  PrimitiveType return_type = return_shape.element_type();
+  bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
+  bool is_tuple_of_scalars_return =
+      return_shape.IsTuple() &&
+      absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
+        return ShapeUtil::IsScalar(shape);
+      });
+  CHECK(is_scalar_return || is_tuple_of_scalars_return);
 
   std::vector<llvm::Value*> parameter_addrs;
   for (llvm::Value* parameter : parameters) {
@@ -3071,10 +3165,30 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
     parameter_addrs.push_back(parameter_addr);
   }
 
+  llvm::Type* return_value_buffer_type =
+      llvm_ir::ShapeToIrType(return_shape, module_);
+  std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
+  int retval_alignment =
+      is_scalar_return
+          ? MinimumAlignmentForPrimitiveType(return_shape.element_type())
+          : 0;
   llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
-      llvm_ir::PrimitiveTypeToIrType(return_type, module_),
-      absl::StrCat(name, "_retval_addr"), &b_,
-      MinimumAlignmentForPrimitiveType(return_type));
+      return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);
+
+  std::vector<llvm::Value*> allocas_for_returned_scalars;
+  if (is_scalar_return) {
+    allocas_for_returned_scalars.push_back(return_value_buffer);
+  } else {
+    constexpr int max_tuple_size = 1000;
+    CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
+        << "Multivalue function can not return more than 1000 elements to avoid"
+        << " stack smashing";
+    allocas_for_returned_scalars =
+        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
+    llvm_ir::IrArray tuple_array(return_value_buffer, return_shape);
+
+    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
+  }
 
   Call(FindOrDie(emitted_functions_, &callee),
        GetArrayFunctionCallArguments(
@@ -3085,7 +3199,12 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
            llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
            /*profile_counters_arg=*/GetProfileCountersArgument()));
 
-  return Load(return_value_buffer);
+  std::vector<llvm::Value*> returned_scalars;
+  returned_scalars.reserve(allocas_for_returned_scalars.size());
+  for (llvm::Value* addr : allocas_for_returned_scalars) {
+    returned_scalars.push_back(Load(addr));
+  }
+  return returned_scalars;
 }
 
 void IrEmitter::EmitGlobalCall(const HloComputation& callee,
@@ -132,8 +132,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   // Emit code to emit the element at `index` for a reduce instruction.
   StatusOr<llvm::Value*> EmitElementalReduce(
       const HloReduceInstruction* reduce,
-      const llvm_ir::ElementGenerator& input_generator,
-      const llvm_ir::ElementGenerator& initial_value_generator,
+      std::vector<llvm_ir::ElementGenerator> input_generators,
+      std::vector<llvm_ir::ElementGenerator> initial_value_generator,
      const llvm_ir::IrArray::Index& index);
 
  protected:
@@ -197,6 +197,14 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   // Private helper to initialize an IR function for the computation.
   void InitializeIrFunction(const string& function_name);
 
+  // Emits the copying epilogue for the function,
+  // where it copies the returned value to the reserved alloca.
+  // This is only necessary for thread-local functions.
+  // Note that since the call graph is flattened, if the same function is
+  // called in both thread-local and non-thread-local it would be codegen'd
+  // twice, and we would know whether it's thread-local at codegen time.
+  void EmitThreadLocalFunctionEpilogue(HloComputation* computation);
+
   // Convenience functions to generate a GEP into the profile counter parameter
   // which would correspond to the index for a given HLO instruction or
   // computation.
@@ -267,12 +275,18 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   // Emits a call to a thread local function (e.g. to the computation nested
   // within a reduce or a map). Thread local callees (by definition) only write
   // to and read from thread local allocations.
+  // Supports only functions returning scalars or tuples of scalars.
   //
   // `parameters` holds the *scalar values* that need to be passed to the
   // callee. The return value is the scalar returned by the callee.
-  llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
-                                   absl::Span<llvm::Value* const> parameters,
-                                   absl::string_view name);
+  std::vector<llvm::Value*> EmitThreadLocalCall(
+      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+      absl::string_view name);
+
+  // Similar to EmitThreadLocal, yet assumes that the function returns a scalar.
+  llvm::Value* EmitScalarReturningThreadLocalCall(
+      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+      absl::string_view name);
 
   // Emits a call to a "global" function (e.g. to the computation nested within
   // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to
@@ -47,14 +47,14 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
       shape_(target_array.GetShape()),
       b_(b) {}
 
-static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
+static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutput(
     const ElementGenerator& target_element_generator,
     const std::vector<IrArray>& target_arrays, llvm::IRBuilder<>* b) {
   return [=](const llvm_ir::IrArray::Index array_index) {
     TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
                         target_element_generator(array_index));
     CHECK(target_element->getType()->isStructTy())
-        << "This BodyEmitter is for multi-output fusion, but target element "
+        << "This BodyEmitter is for multi-output, but target element "
           "generator does not produce values of struct type.";
     CHECK_EQ(target_element->getType()->getStructNumElements(),
              target_arrays.size());
@@ -70,7 +70,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
 LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
                          absl::Span<const IrArray> target_arrays,
                          llvm::IRBuilder<>* b)
-    : body_emitter_(MakeBodyEmitterForMultiOutputFusion(
+    : body_emitter_(MakeBodyEmitterForMultiOutput(
          target_element_generator,
          std::vector<IrArray>(target_arrays.begin(), target_arrays.end()), b)),
      shape_(target_arrays[0].GetShape()),
@@ -91,6 +91,32 @@ void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
   llvm_ir::EmitTuple(tuple, buffer_ptrs, b);
 }
 
+std::vector<llvm::Value*> EmitTupleAllocasAtFunctionEntry(
+    const Shape& tuple_shape, llvm::IRBuilder<>* b) {
+  llvm::Module* module = b->GetInsertBlock()->getModule();
+
+  llvm::IRBuilder<>::InsertPointGuard guard(*b);
+  llvm::Function* function = b->GetInsertBlock()->getParent();
+  b->SetInsertPoint(&function->getEntryBlock(),
+                    function->getEntryBlock().getFirstInsertionPt());
+  CHECK(tuple_shape.IsTuple());
+  int tuple_size = tuple_shape.tuple_shapes_size();
+
+  std::vector<llvm::Value*> generated_allocas;
+  for (int i = 0; i < tuple_size; i++) {
+    const Shape& element_shape = tuple_shape.tuple_shapes(i);
+    CHECK(ShapeUtil::IsScalar(element_shape));
+    llvm::Type* type =
+        llvm_ir::PrimitiveTypeToIrType(element_shape.element_type(), module);
+    llvm::AllocaInst* alloca = b->CreateAlloca(
+        type,
+        /*ArraySize=*/nullptr, AsStringRef(absl::StrCat("tuple_element_", i)));
+    generated_allocas.push_back(alloca);
+  }
+
+  return generated_allocas;
+}
+
 llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
                                  int alignment, llvm::Value* operand,
                                  llvm::IRBuilder<>* b) {
@@ -68,6 +68,12 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
 void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
                llvm::IRBuilder<>* b);
 
+// Emits one alloca for each element in the tuple of shape tuple_shape,
+// returns the emitted allocas.
+// Precondition: tuple_shape should be a tuple of scalars.
+std::vector<llvm::Value*> EmitTupleAllocasAtFunctionEntry(
+    const Shape& tuple_shape, llvm::IRBuilder<>* b);
+
 // Similar to EmitTuple above, except that the output buffers are provided in
 // the form of IrArray.
 void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
@@ -1168,6 +1168,7 @@ xla_test(
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
@@ -51,6 +51,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -1002,5 +1003,100 @@ XLA_TEST_F(ReduceTest, R0ReduceInDisguise) {
                              ErrorSpec(0.001));
 }
 
+class VariadicReduceTest : public HloTestBase {};
+
+XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R3x2_to_R1x2_simple)) {
+  absl::string_view hlo_string = R"(
+  HloModule Reduce_R3x2_to_R1x2_simple
+
+  add {
+    op1 = f32[] parameter(0)
+    op2 = f32[] parameter(1)
+    acc1 = f32[] parameter(2)
+    acc2 = f32[] parameter(3)
+    out1 = f32[] add(acc1, op1)
+    out2 = f32[] add(acc2, op2)
+    ROOT result = (f32[], f32[]) tuple(out1, out2)
+  }
+
+  ENTRY main {
+    inp1 = f32[10,20,3] parameter(0)
+    inp2 = f32[10,20,3] parameter(1)
+    zero = f32[] constant(0)
+
+    ROOT out = (f32[10], f32[10]) reduce(inp1, inp2, zero, zero),
+      dimensions={1,2},
+      to_apply=add
+  }
+  )";
+
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
+}
+
+XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_simple)) {
+  absl::string_view hlo_string = R"(
+  HloModule Reduce_R1x2_to_R0x2_simple
+
+  add {
+    op1 = f32[] parameter(0)
+    op2 = f32[] parameter(1)
+    acc1 = f32[] parameter(2)
+    acc2 = f32[] parameter(3)
+    out1 = f32[] add(acc1, op1)
+    out2 = f32[] add(acc2, op2)
+    ROOT result = (f32[], f32[]) tuple(out1, out2)
+  }
+
+  ENTRY main {
+    inp1 = f32[100] parameter(0)
+    inp2 = f32[100] parameter(1)
+    zero = f32[] constant(0)
+
+    ROOT out = (f32[], f32[]) reduce(inp1, inp2, zero, zero),
+      dimensions={0},
+      to_apply=add
+  }
+  )";
+
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
+}
+
+XLA_TEST_F(VariadicReduceTest, DISABLED_ON_GPU(Reduce_R1x2_to_R0x2_argmax)) {
+  absl::string_view hlo_string = R"(
+  HloModule Reduce_R1x2_to_R0x2_argmax
+
+  argmax {
+    running_max = u32[] parameter(0)
+    running_max_idx = u32[] parameter(1)
+    current_value = u32[] parameter(2)
+    current_value_idx = u32[] parameter(3)
+
+    current = (u32[], u32[]) tuple(running_max, running_max_idx)
+    potential = (u32[], u32[]) tuple(current_value, current_value_idx)
+
+    cmp_code = pred[] compare(current_value, running_max), direction=GT
+
+    new_max = u32[] select(cmp_code, current_value, running_max)
+    new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)
+
+    ROOT out = (u32[], u32[]) tuple(new_max, new_idx)
+  }
+
+  ENTRY main {
+    input = u32[100] parameter(0)
+    idxs = u32[100]{0} iota(), iota_dimension=0
+    zero = u32[] constant(0)
+    zero_idx = u32[] constant(0)
+
+    ROOT out = (u32[], u32[]) reduce(
+      input, idxs, zero, zero_idx),
+      dimensions={0},
+      to_apply=%argmax
+  }
+  )";
+
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
+}
+
 }  // namespace
 }  // namespace xla