[XLA:CPU] Move elemental conv emitter to generic code path so it can share the multiply-add logic with dot

This fixes some edge cases with complex numbers and would also allow using it
for ints if we want that. GPU doesn't use this code path.

PiperOrigin-RevId: 351549079
Change-Id: I7a2f9e9758e270b62814c7e3e0419342c5b58196
Benjamin Kramer 2021-01-13 03:01:35 -08:00 committed by TensorFlower Gardener
parent f2a0826b7d
commit 15e0f259b1
8 changed files with 184 additions and 191 deletions
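
For reference, a minimal scalar sketch of the accumulation semantics that dot and convolution now share through the new EmitMulAdd helper, written as plain C++ rather than the LLVM IR the emitter actually generates. The overloads mirror the branches in the diff below: complex values accumulate the real and imaginary parts of the product separately, floating-point and integer values use an ordinary multiply-add, and PRED reduces with OR-of-AND. The old CPU convolution path emitted a single FMul/FAdd, which appears to be where the complex-number edge cases mentioned above came from. The MulAdd name is illustrative only, not part of the XLA API.

#include <complex>

// Complex: accumulate real and imaginary parts of lhs * rhs separately,
// matching the InsertValue/FAdd sequence in EmitMulAdd's complex branch.
std::complex<float> MulAdd(std::complex<float> lhs, std::complex<float> rhs,
                           std::complex<float> acc) {
  float product_real = lhs.real() * rhs.real() - lhs.imag() * rhs.imag();
  float product_imag = lhs.real() * rhs.imag() + lhs.imag() * rhs.real();
  return {acc.real() + product_real, acc.imag() + product_imag};
}

// Floating point (and, analogously, integers): ordinary multiply-add.
float MulAdd(float lhs, float rhs, float acc) { return acc + lhs * rhs; }

// PRED (boolean): reduce with OR of ANDs instead of a sum of products.
bool MulAdd(bool lhs, bool rhs, bool acc) { return acc || (lhs && rhs); }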


@@ -4203,12 +4203,12 @@ cc_library(
     deps = [
         ":hlo",
         ":hlo_casting_utils",
-        ":hlo_module_config",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",


@@ -105,15 +105,5 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
   return result;
 }
 
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution(
-    const HloInstruction* hlo,
-    const HloToElementGeneratorMap& operand_to_generator,
-    const llvm_ir::IrArray::Index& index) {
-  return ir_emitter_->EmitElementalConvolution(
-      Cast<HloConvolutionInstruction>(hlo),
-      operand_to_generator.at(hlo->operand(0)),
-      operand_to_generator.at(hlo->operand(1)), index);
-}
-
 }  // namespace cpu
 }  // namespace xla


@@ -40,10 +40,6 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
                                    llvm::Value* rhs) override;
   StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                   llvm::Value* value) override;
-  StatusOr<llvm::Value*> EmitConvolution(
-      const HloInstruction* hlo,
-      const HloToElementGeneratorMap& operand_to_generator,
-      const llvm_ir::IrArray::Index& index) override;
 
   StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
       const HloComputation& callee, absl::Span<llvm::Value* const> parameters,


@@ -856,152 +856,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
       hlo_module_config_, target_machine_features_);
 }
 
-StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
-    const HloConvolutionInstruction* convolution,
-    const llvm_ir::ElementGenerator& input_generator,
-    const llvm_ir::ElementGenerator& kernel_generator,
-    const llvm_ir::IrArray::Index& index) {
-  const HloInstruction* lhs = convolution->operand(0);
-  const HloInstruction* rhs = convolution->operand(1);
-  const Window& window = convolution->window();
-
-  const ConvolutionDimensionNumbers& dnums =
-      convolution->convolution_dimension_numbers();
-  int num_spatial_dims = dnums.output_spatial_dimensions_size();
-  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
-  }
-  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
-  llvm::Value* batch = index[dnums.output_batch_dimension()];
-
-  // We will accumulate the products into this sum to calculate the output entry
-  // at the given index.
-  PrimitiveType lhs_element_type = lhs->shape().element_type();
-  llvm::Type* lhs_llvm_type =
-      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
-  // Upcast the accumulator to F32 from F16 for increased precision.
-  llvm::Type* accumulator_type =
-      lhs_element_type == F16 ? b_.getFloatTy() : lhs_llvm_type;
-  llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
-      accumulator_type, "convolution_sum_address", &b_,
-      MinimumAlignmentForPrimitiveType(lhs_element_type));
-  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
-  Store(constant_zero, sum_address);
-
-  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
-  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    kernel_spatial[i] =
-        loops
-            .AddLoop(
-                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
-                absl::StrCat("k", i))
-            ->GetIndVarValue();
-  }
-  llvm::Value* input_feature =
-      loops
-          .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
-                   "iz")
-          ->GetIndVarValue();
-
-  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
-  // Calculate the spatial index in the input array, taking striding, dilation
-  // and padding into account. An index in the padding will be out of the bounds
-  // of the array.
-  const auto calculate_input_index = [this](llvm::Value* output_index,
-                                            llvm::Value* kernel_index,
-                                            const WindowDimension& window_dim) {
-    llvm::Value* strided_index =
-        NSWMul(output_index, b_.getInt64(window_dim.stride()));
-    llvm::Value* dilated_kernel_index =
-        NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
-    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
-                  b_.getInt64(window_dim.padding_low()));
-  };
-  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    input_spatial[i] = calculate_input_index(
-        output_spatial[i], kernel_spatial[i], window.dimensions(i));
-  }
-
-  // We need to check if 0 <= input dim < bound, as otherwise we are in the
-  // padding so that we can skip the computation. That is equivalent to input
-  // dim < bound as an *unsigned* comparison, since a negative value will wrap
-  // to a large positive value. The input dim is dilated, so we need to dilate
-  // the bound as well to match.
-  //
-  // Also need to check that the input coordinates are not in one of the
-  // holes created by base dilation.
-  const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
-    llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
-    return ICmpEQ(remainder, b_.getInt64(0));
-  };
-
-  llvm::Value* in_bounds_condition = b_.getInt1(true);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
-        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
-        window.dimensions(i).base_dilation()));
-    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
-    llvm::Value* dim_not_in_hole =
-        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
-    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
-    in_bounds_condition = And(in_bounds_condition, dim_ok);
-  }
-
-  // Now we need to map the dilated base coordinates back to the actual
-  // data indices on the lhs.
-  const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
-    return SDiv(input_index, b_.getInt64(base_dilation));
-  };
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    input_spatial[i] =
-        undilate(input_spatial[i], window.dimensions(i).base_dilation());
-  }
-
-  llvm_ir::LlvmIfData if_data =
-      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
-  SetToFirstInsertPoint(if_data.true_block, &b_);
-
-  // We are not in the padding, so carry out the computation.
-  int num_dims = num_spatial_dims + 2;
-  std::vector<llvm::Value*> input_multi_index(num_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
-  }
-  input_multi_index[dnums.input_feature_dimension()] = input_feature;
-  input_multi_index[dnums.input_batch_dimension()] = batch;
-
-  std::vector<llvm::Value*> kernel_multi_index(num_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
-        window.dimensions(i).window_reversal()
-            ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
-                     kernel_spatial[i])
-            : kernel_spatial[i];
-  }
-
-  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
-  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
-
-  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
-                                      b_.getInt64Ty());
-  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
-                      input_generator(input_index));
-  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
-                                       b_.getInt64Ty());
-  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
-                      kernel_generator(kernel_index));
-  llvm::Value* product = FMul(input_value, kernel_value);
-  llvm::Value* sum = FAdd(Load(sum_address), FPCast(product, accumulator_type));
-  Store(sum, sum_address);
-
-  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-  return FPCast(Load(sum_address), lhs_llvm_type);
-}
-
 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
   auto lhs = convolution->operand(0);
   auto rhs = convolution->operand(1);


@@ -121,13 +121,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   // Emit an LLVM global variable for every constant buffer allocation.
   Status EmitConstantGlobals();
 
-  // Emit code to emit the element at `index` for a convolution instruction.
-  StatusOr<llvm::Value*> EmitElementalConvolution(
-      const HloConvolutionInstruction* convolution,
-      const llvm_ir::ElementGenerator& input_generator,
-      const llvm_ir::ElementGenerator& kernel_generator,
-      const llvm_ir::IrArray::Index& index);
-
  protected:
   //
   // The following methods implement the DfsHloVisitor interface.


@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/logging.h"
@@ -2222,27 +2223,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
   llvm::Value* current_accumulator = Load(accumulator_alloca);
   TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
   TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
-  llvm::Value* next_accumulator;
-  if (primitive_util::IsComplexType(primitive_type)) {
-    llvm::Value* product_real =
-        FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
-             FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
-    llvm::Value* product_imag =
-        FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
-             FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
-    next_accumulator = InsertValue(
-        current_accumulator,
-        FAdd(EmitExtractReal(current_accumulator), product_real), {0});
-    next_accumulator = InsertValue(
-        next_accumulator,
-        FAdd(EmitExtractImag(current_accumulator), product_imag), {1});
-  } else if (primitive_util::IsFloatingPointType(primitive_type)) {
-    next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value));
-  } else if (primitive_type == PRED) {
-    next_accumulator = Or(current_accumulator, And(lhs_value, rhs_value));
-  } else {
-    next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value));
-  }
+  llvm::Value* next_accumulator =
+      EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
   Store(next_accumulator, accumulator_alloca);
 
   SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
@@ -2551,6 +2533,28 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
   return complex;
 }
 
+llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
+                                            llvm::Value* accumulator,
+                                            xla::PrimitiveType primitive_type) {
+  if (primitive_util::IsComplexType(primitive_type)) {
+    llvm::Value* product_real =
+        FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
+             FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
+    llvm::Value* product_imag =
+        FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
+             FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
+    llvm::Value* next_accumulator = InsertValue(
+        accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
+    return InsertValue(next_accumulator,
+                       FAdd(EmitExtractImag(accumulator), product_imag), {1});
+  } else if (primitive_util::IsFloatingPointType(primitive_type)) {
+    return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
+  } else if (primitive_type == PRED) {
+    return Or(accumulator, And(lhs, rhs));
+  }
+  return Add(accumulator, Mul(lhs, rhs));
+}
+
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
     const HloMapInstruction* map_instr,
     absl::Span<llvm::Value* const> elemental_operands) {
@@ -2767,10 +2771,149 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
-    const HloInstruction* hlo,
+    const HloInstruction* convolution,
     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
     const llvm_ir::IrArray::Index& index) {
-  return Unimplemented("Elemental convolution is not implemented");
+  const HloInstruction* lhs = convolution->operand(0);
+  const auto& input_generator = operand_to_generator.at(lhs);
+  const HloInstruction* rhs = convolution->operand(1);
+  const auto& kernel_generator = operand_to_generator.at(rhs);
+  const Window& window = convolution->window();
+
+  const ConvolutionDimensionNumbers& dnums =
+      convolution->convolution_dimension_numbers();
+  int num_spatial_dims = dnums.output_spatial_dimensions_size();
+  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
+  }
+  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
+  llvm::Value* batch = index[dnums.output_batch_dimension()];
+
+  // We will accumulate the products into this sum to calculate the output entry
+  // at the given index.
+  PrimitiveType lhs_element_type = lhs->shape().element_type();
+  llvm::Type* lhs_llvm_type =
+      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
+  // Upcast the accumulator to F32 from F16 for increased precision.
+  llvm::Type* accumulator_type =
+      lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
+  llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
+      accumulator_type, "convolution_sum_address", b_);
+  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
+  Store(constant_zero, sum_address);
+
+  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
+  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    kernel_spatial[i] =
+        loops
+            .AddLoop(
+                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
+                absl::StrCat("k", i))
+            ->GetIndVarValue();
+  }
+  llvm::Value* input_feature =
+      loops
+          .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
+                   "iz")
+          ->GetIndVarValue();
+
+  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
+
+  // Calculate the spatial index in the input array, taking striding, dilation
+  // and padding into account. An index in the padding will be out of the bounds
+  // of the array.
+  const auto calculate_input_index = [this](llvm::Value* output_index,
+                                            llvm::Value* kernel_index,
+                                            const WindowDimension& window_dim) {
+    llvm::Value* strided_index =
+        NSWMul(output_index, b_->getInt64(window_dim.stride()));
+    llvm::Value* dilated_kernel_index =
+        NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
+    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
+                  b_->getInt64(window_dim.padding_low()));
+  };
+  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    input_spatial[i] = calculate_input_index(
+        output_spatial[i], kernel_spatial[i], window.dimensions(i));
+  }
+
+  // We need to check if 0 <= input dim < bound, as otherwise we are in the
+  // padding so that we can skip the computation. That is equivalent to input
+  // dim < bound as an *unsigned* comparison, since a negative value will wrap
+  // to a large positive value. The input dim is dilated, so we need to dilate
+  // the bound as well to match.
+  //
+  // Also need to check that the input coordinates are not in one of the
+  // holes created by base dilation.
+  const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
+    llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
+    return ICmpEQ(remainder, b_->getInt64(0));
+  };
+
+  llvm::Value* in_bounds_condition = b_->getInt1(true);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
+        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
+        window.dimensions(i).base_dilation()));
+    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
+    llvm::Value* dim_not_in_hole =
+        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
+    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
+    in_bounds_condition = And(in_bounds_condition, dim_ok);
+  }
+
+  // Now we need to map the dilated base coordinates back to the actual
+  // data indices on the lhs.
+  const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
+    return SDiv(input_index, b_->getInt64(base_dilation));
+  };
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    input_spatial[i] =
+        undilate(input_spatial[i], window.dimensions(i).base_dilation());
+  }
+
+  llvm_ir::LlvmIfData if_data =
+      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
+  SetToFirstInsertPoint(if_data.true_block, b_);
+
+  // We are not in the padding, so carry out the computation.
+  int num_dims = num_spatial_dims + 2;
+  std::vector<llvm::Value*> input_multi_index(num_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
+  }
+  input_multi_index[dnums.input_feature_dimension()] = input_feature;
+  input_multi_index[dnums.input_batch_dimension()] = batch;
+
+  std::vector<llvm::Value*> kernel_multi_index(num_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
+        window.dimensions(i).window_reversal()
+            ? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
+                     kernel_spatial[i])
+            : kernel_spatial[i];
+  }
+
+  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
+  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
+
+  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
+                                      b_->getInt64Ty());
+  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
+                      input_generator(input_index));
+  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
+                                       b_->getInt64Ty());
+  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
+                      kernel_generator(kernel_index));
+  llvm::Value* sum = EmitMulAdd(input_value, kernel_value, Load(sum_address),
+                                convolution->shape().element_type());
+  Store(sum, sum_address);
+
+  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
+  return FPCast(Load(sum_address), lhs_llvm_type);
 }
 
 // Evaluate polynomial using Horner's method.


@@ -183,6 +183,11 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
                                   llvm::Value* imag);
 
+  // Emit `accumulator + lhs * rhs` for the given primitive type.
+  llvm::Value* EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
+                          llvm::Value* accumulator,
+                          xla::PrimitiveType primitive_type);
+
   // Identifier of the thread unique among all threads on the device
   virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }


@@ -1644,6 +1644,18 @@ ENTRY Test {
   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
 }
 
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU(ConvolveC64Forward)) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+  %arg0 = c64[3,56,56,16] parameter(0)
+  %arg1 = c64[3,3,3,64] parameter(1)
+  ROOT %conv = c64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
+})";
+  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
+}
+
 XLA_TEST_F(ConvolutionHloTest,
            DISABLED_ON_GPU_ROCM(ConvolveF32ForwardReversed)) {
   constexpr char kHlo[] = R"(