From 15e0f259b119f53b90fc1698f7abd9b739df3b6c Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Wed, 13 Jan 2021 03:01:35 -0800
Subject: [PATCH] [XLA:CPU] Move elemental conv emitter to generic code path so
 it can share the multiply-add logic with dot

This fixes some edge cases when it comes to complex numbers and also would
allow using it for ints if we want that. GPU doesn't use this code path.

PiperOrigin-RevId: 351549079
Change-Id: I7a2f9e9758e270b62814c7e3e0419342c5b58196
---
 tensorflow/compiler/xla/service/BUILD         |   2 +-
 .../xla/service/cpu/elemental_ir_emitter.cc   |  10 -
 .../xla/service/cpu/elemental_ir_emitter.h    |   4 -
 .../compiler/xla/service/cpu/ir_emitter.cc    | 146 --------------
 .../compiler/xla/service/cpu/ir_emitter.h     |   7 -
 .../xla/service/elemental_ir_emitter.cc       | 189 +++++++++++++++---
 .../xla/service/elemental_ir_emitter.h        |   5 +
 .../compiler/xla/tests/convolution_test.cc    |  12 ++
 8 files changed, 184 insertions(+), 191 deletions(-)
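Reviewer note on the shared accumulation step: EmitMulAdd, which EmitElementalDot
and the relocated EmitConvolution both call after this change, picks the
accumulation appropriate to the element type, whereas the removed CPU-only
convolution emitter always accumulated with a plain FMul/FAdd. The sketch below
is illustrative only and is not code from this patch (the MulAdd helper names
are made up for the example); it shows the scalar semantics the emitted IR
corresponds to for the complex and floating-point cases.

    // Illustrative sketch, not code from this patch: scalar semantics of the
    // per-element update that EmitMulAdd emits.
    #include <complex>
    #include <cstdio>

    // Complex types: a full complex multiply-accumulate, mirroring the
    // EmitExtractReal/EmitExtractImag arithmetic in EmitMulAdd.
    std::complex<float> MulAdd(std::complex<float> lhs, std::complex<float> rhs,
                               std::complex<float> acc) {
      float re = acc.real() + (lhs.real() * rhs.real() - lhs.imag() * rhs.imag());
      float im = acc.imag() + (lhs.real() * rhs.imag() + lhs.imag() * rhs.real());
      return {re, im};
    }

    // Floating-point types: acc + lhs * rhs. (In the emitter, PRED uses
    // acc | (lhs & rhs) and integer types use acc + lhs * rhs.)
    float MulAdd(float lhs, float rhs, float acc) { return acc + lhs * rhs; }

    int main() {
      std::complex<float> acc = MulAdd({1, 2}, {3, 4}, {0, 0});
      std::printf("%g %g\n", acc.real(), acc.imag());  // prints -5 10
      return 0;
    }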
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index eade412ee24..2baa44c35f0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4203,12 +4203,12 @@ cc_library(
     deps = [
         ":hlo",
         ":hlo_casting_utils",
-        ":hlo_module_config",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index b15aa3689b7..a4566b11a78 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -105,15 +105,5 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
   return result;
 }
 
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution(
-    const HloInstruction* hlo,
-    const HloToElementGeneratorMap& operand_to_generator,
-    const llvm_ir::IrArray::Index& index) {
-  return ir_emitter_->EmitElementalConvolution(
-      Cast<HloConvolutionInstruction>(hlo),
-      operand_to_generator.at(hlo->operand(0)),
-      operand_to_generator.at(hlo->operand(1)), index);
-}
-
 }  // namespace cpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index fbf582d3a8b..a002df25493 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -40,10 +40,6 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
                                    llvm::Value* rhs) override;
   StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                   llvm::Value* value) override;
-  StatusOr<llvm::Value*> EmitConvolution(
-      const HloInstruction* hlo,
-      const HloToElementGeneratorMap& operand_to_generator,
-      const llvm_ir::IrArray::Index& index) override;
 
   StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
       const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index e1765f41518..7179c5a00ba 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -856,152 +856,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
                               hlo_module_config_, target_machine_features_);
 }
 
-StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
-    const HloConvolutionInstruction* convolution,
-    const llvm_ir::ElementGenerator& input_generator,
-    const llvm_ir::ElementGenerator& kernel_generator,
-    const llvm_ir::IrArray::Index& index) {
-  const HloInstruction* lhs = convolution->operand(0);
-  const HloInstruction* rhs = convolution->operand(1);
-  const Window& window = convolution->window();
-
-  const ConvolutionDimensionNumbers& dnums =
-      convolution->convolution_dimension_numbers();
-  int num_spatial_dims = dnums.output_spatial_dimensions_size();
-  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
-  }
-  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
-  llvm::Value* batch = index[dnums.output_batch_dimension()];
-
-  // We will accumulate the products into this sum to calculate the output entry
-  // at the given index.
-  PrimitiveType lhs_element_type = lhs->shape().element_type();
-  llvm::Type* lhs_llvm_type =
-      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
-  // Upcast the accumulator to F32 from F16 for increased precision.
-  llvm::Type* accumulator_type =
-      lhs_element_type == F16 ? b_.getFloatTy() : lhs_llvm_type;
-  llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
-      accumulator_type, "convolution_sum_address", &b_,
-      MinimumAlignmentForPrimitiveType(lhs_element_type));
-  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
-  Store(constant_zero, sum_address);
-
-  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
-  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    kernel_spatial[i] =
-        loops
-            .AddLoop(
-                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
-                absl::StrCat("k", i))
-            ->GetIndVarValue();
-  }
-  llvm::Value* input_feature =
-      loops
-          .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
-                   "iz")
-          ->GetIndVarValue();
-
-  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
-  // Calculate the spatial index in the input array, taking striding, dilation
-  // and padding into account. An index in the padding will be out of the bounds
-  // of the array.
-  const auto calculate_input_index = [this](llvm::Value* output_index,
-                                            llvm::Value* kernel_index,
-                                            const WindowDimension& window_dim) {
-    llvm::Value* strided_index =
-        NSWMul(output_index, b_.getInt64(window_dim.stride()));
-    llvm::Value* dilated_kernel_index =
-        NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
-    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
-                  b_.getInt64(window_dim.padding_low()));
-  };
-  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    input_spatial[i] = calculate_input_index(
-        output_spatial[i], kernel_spatial[i], window.dimensions(i));
-  }
-
-  // We need to check if 0 <= input dim < bound, as otherwise we are in the
-  // padding so that we can skip the computation. That is equivalent to input
-  // dim < bound as an *unsigned* comparison, since a negative value will wrap
-  // to a large positive value. The input dim is dilated, so we need to dilate
-  // the bound as well to match.
-
-  // Also need to check that the input coordinates are not in one of the
-  // holes created by base dilation.
-  const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
-    llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
-    return ICmpEQ(remainder, b_.getInt64(0));
-  };
-
-  llvm::Value* in_bounds_condition = b_.getInt1(true);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
-        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
-        window.dimensions(i).base_dilation()));
-    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
-    llvm::Value* dim_not_in_hole =
-        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
-    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
-    in_bounds_condition = And(in_bounds_condition, dim_ok);
-  }
-
-  // Now we need to map the dilated base coordinates back to the actual
-  // data indices on the lhs.
-  const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
-    return SDiv(input_index, b_.getInt64(base_dilation));
-  };
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    input_spatial[i] =
-        undilate(input_spatial[i], window.dimensions(i).base_dilation());
-  }
-
-  llvm_ir::LlvmIfData if_data =
-      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
-  SetToFirstInsertPoint(if_data.true_block, &b_);
-
-  // We are not in the padding, so carry out the computation.
-  int num_dims = num_spatial_dims + 2;
-  std::vector<llvm::Value*> input_multi_index(num_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
-  }
-  input_multi_index[dnums.input_feature_dimension()] = input_feature;
-  input_multi_index[dnums.input_batch_dimension()] = batch;
-
-  std::vector<llvm::Value*> kernel_multi_index(num_dims);
-  for (int i = 0; i < num_spatial_dims; ++i) {
-    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
-        window.dimensions(i).window_reversal()
-            ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
-                     kernel_spatial[i])
-            : kernel_spatial[i];
-  }
-
-  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
-  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
-
-  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
-                                      b_.getInt64Ty());
-  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
-                      input_generator(input_index));
-  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
-                                       b_.getInt64Ty());
-  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
-                      kernel_generator(kernel_index));
-  llvm::Value* product = FMul(input_value, kernel_value);
-  llvm::Value* sum = FAdd(Load(sum_address), FPCast(product, accumulator_type));
-  Store(sum, sum_address);
-
-  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-  return FPCast(Load(sum_address), lhs_llvm_type);
-}
-
 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
   auto lhs = convolution->operand(0);
   auto rhs = convolution->operand(1);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 891d53c889d..49490ef0fe9 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -121,13 +121,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   // Emit an LLVM global variable for every constant buffer allocation.
   Status EmitConstantGlobals();
 
-  // Emit code to emit the element at `index` for a convolution instruction.
-  StatusOr<llvm::Value*> EmitElementalConvolution(
-      const HloConvolutionInstruction* convolution,
-      const llvm_ir::ElementGenerator& input_generator,
-      const llvm_ir::ElementGenerator& kernel_generator,
-      const llvm_ir::IrArray::Index& index);
-
  protected:
   //
   // The following methods implement the DfsHloVisitor interface.
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 735b0b71818..817d3e6de6f 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/logging.h"
@@ -2222,27 +2223,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
     llvm::Value* current_accumulator = Load(accumulator_alloca);
     TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
     TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
-    llvm::Value* next_accumulator;
-    if (primitive_util::IsComplexType(primitive_type)) {
-      llvm::Value* product_real =
-          FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
-               FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
-      llvm::Value* product_imag =
-          FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
-               FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
-      next_accumulator = InsertValue(
-          current_accumulator,
-          FAdd(EmitExtractReal(current_accumulator), product_real), {0});
-      next_accumulator = InsertValue(
-          next_accumulator,
-          FAdd(EmitExtractImag(current_accumulator), product_imag), {1});
-    } else if (primitive_util::IsFloatingPointType(primitive_type)) {
-      next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value));
-    } else if (primitive_type == PRED) {
-      next_accumulator = Or(current_accumulator, And(lhs_value, rhs_value));
-    } else {
-      next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value));
-    }
+    llvm::Value* next_accumulator =
+        EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
     Store(next_accumulator, accumulator_alloca);
 
     SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
@@ -2551,6 +2533,28 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
   return complex;
 }
 
+llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
+                                            llvm::Value* accumulator,
+                                            xla::PrimitiveType primitive_type) {
+  if (primitive_util::IsComplexType(primitive_type)) {
+    llvm::Value* product_real =
+        FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
+             FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
+    llvm::Value* product_imag =
+        FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
+             FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
+    llvm::Value* next_accumulator = InsertValue(
+        accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
+    return InsertValue(next_accumulator,
+                       FAdd(EmitExtractImag(accumulator), product_imag), {1});
+  } else if (primitive_util::IsFloatingPointType(primitive_type)) {
+    return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
+  } else if (primitive_type == PRED) {
+    return Or(accumulator, And(lhs, rhs));
+  }
+  return Add(accumulator, Mul(lhs, rhs));
+}
+
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
     const HloMapInstruction* map_instr,
     absl::Span<llvm::Value* const> elemental_operands) {
@@ -2767,10 +2771,149 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
-    const HloInstruction* hlo,
+    const HloInstruction* convolution,
     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
     const llvm_ir::IrArray::Index& index) {
-  return Unimplemented("Elemental convolution is not implemented");
+  const HloInstruction* lhs = convolution->operand(0);
+  const auto& input_generator = operand_to_generator.at(lhs);
+  const HloInstruction* rhs = convolution->operand(1);
+  const auto& kernel_generator = operand_to_generator.at(rhs);
+  const Window& window = convolution->window();
+
+  const ConvolutionDimensionNumbers& dnums =
+      convolution->convolution_dimension_numbers();
+  int num_spatial_dims = dnums.output_spatial_dimensions_size();
+  std::vector<llvm::Value*> output_spatial(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
+  }
+  llvm::Value* output_feature = index[dnums.output_feature_dimension()];
+  llvm::Value* batch = index[dnums.output_batch_dimension()];
+
+  // We will accumulate the products into this sum to calculate the output entry
+  // at the given index.
+  PrimitiveType lhs_element_type = lhs->shape().element_type();
+  llvm::Type* lhs_llvm_type =
+      llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
+  // Upcast the accumulator to F32 from F16 for increased precision.
+  llvm::Type* accumulator_type =
+      lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
+  llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
+      accumulator_type, "convolution_sum_address", b_);
+  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
+  Store(constant_zero, sum_address);
+
+  llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
+  std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    kernel_spatial[i] =
+        loops
+            .AddLoop(
+                0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
+                absl::StrCat("k", i))
+            ->GetIndVarValue();
+  }
+  llvm::Value* input_feature =
+      loops
+          .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
+                   "iz")
+          ->GetIndVarValue();
+
+  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
+
+  // Calculate the spatial index in the input array, taking striding, dilation
+  // and padding into account. An index in the padding will be out of the bounds
+  // of the array.
+  const auto calculate_input_index = [this](llvm::Value* output_index,
+                                            llvm::Value* kernel_index,
+                                            const WindowDimension& window_dim) {
+    llvm::Value* strided_index =
+        NSWMul(output_index, b_->getInt64(window_dim.stride()));
+    llvm::Value* dilated_kernel_index =
+        NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
+    return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
+                  b_->getInt64(window_dim.padding_low()));
+  };
+  std::vector<llvm::Value*> input_spatial(num_spatial_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    input_spatial[i] = calculate_input_index(
+        output_spatial[i], kernel_spatial[i], window.dimensions(i));
+  }
+
+  // We need to check if 0 <= input dim < bound, as otherwise we are in the
+  // padding so that we can skip the computation. That is equivalent to input
+  // dim < bound as an *unsigned* comparison, since a negative value will wrap
+  // to a large positive value. The input dim is dilated, so we need to dilate
+  // the bound as well to match.
+
+  // Also need to check that the input coordinates are not in one of the
+  // holes created by base dilation.
+  const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
+    llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
+    return ICmpEQ(remainder, b_->getInt64(0));
+  };
+
+  llvm::Value* in_bounds_condition = b_->getInt1(true);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
+        lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
+        window.dimensions(i).base_dilation()));
+    llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
+    llvm::Value* dim_not_in_hole =
+        not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
+    llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
+    in_bounds_condition = And(in_bounds_condition, dim_ok);
+  }
+
+  // Now we need to map the dilated base coordinates back to the actual
+  // data indices on the lhs.
+  const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
+    return SDiv(input_index, b_->getInt64(base_dilation));
+  };
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    input_spatial[i] =
+        undilate(input_spatial[i], window.dimensions(i).base_dilation());
+  }
+
+  llvm_ir::LlvmIfData if_data =
+      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
+  SetToFirstInsertPoint(if_data.true_block, b_);
+
+  // We are not in the padding, so carry out the computation.
+  int num_dims = num_spatial_dims + 2;
+  std::vector<llvm::Value*> input_multi_index(num_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
+  }
+  input_multi_index[dnums.input_feature_dimension()] = input_feature;
+  input_multi_index[dnums.input_batch_dimension()] = batch;
+
+  std::vector<llvm::Value*> kernel_multi_index(num_dims);
+  for (int i = 0; i < num_spatial_dims; ++i) {
+    kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
+        window.dimensions(i).window_reversal()
+            ? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
+                     kernel_spatial[i])
+            : kernel_spatial[i];
+  }
+
+  kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
+  kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
+
+  llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
+                                      b_->getInt64Ty());
+  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
+                      input_generator(input_index));
+  llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
+                                       b_->getInt64Ty());
+  TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
+                      kernel_generator(kernel_index));
+  llvm::Value* sum = EmitMulAdd(input_value, kernel_value, Load(sum_address),
+                                convolution->shape().element_type());
+  Store(sum, sum_address);
+
+  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
+  return FPCast(Load(sum_address), lhs_llvm_type);
 }
 
 // Evaluate polynomial using Horner's method.
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 60e25c7d8bf..5cf368f4812 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -183,6 +183,11 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
                                   llvm::Value* imag);
 
+  // Emit `accumulator + lhs * rhs` for the given primitive type.
+  llvm::Value* EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
+                          llvm::Value* accumulator,
+                          xla::PrimitiveType primitive_type);
+
   // Identifier of the thread unique among all threads on the device
   virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
 
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 8337f93c3b4..2802c669688 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -1644,6 +1644,18 @@ ENTRY Test {
   EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
 }
 
+XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU(ConvolveC64Forward)) {
+  constexpr char kHlo[] = R"(
+HloModule TestModule
+
+ENTRY Test {
+  %arg0 = c64[3,56,56,16] parameter(0)
+  %arg1 = c64[3,3,3,64] parameter(1)
+  ROOT %conv = c64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
+})";
+  EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
+}
+
 XLA_TEST_F(ConvolutionHloTest,
            DISABLED_ON_GPU_ROCM(ConvolveF32ForwardReversed)) {
   constexpr char kHlo[] = R"(