[XLA:CPU] Move elemental conv emitter to generic code path so it can share the multiply-add logic with dot
This fixes some edge cases when it comes to complex numbers and also would allow using it for ints if we want that. GPU doesn't use this code path. PiperOrigin-RevId: 351549079 Change-Id: I7a2f9e9758e270b62814c7e3e0419342c5b58196
This commit is contained in:
parent
f2a0826b7d
commit
15e0f259b1
@ -4203,12 +4203,12 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_casting_utils",
|
":hlo_casting_utils",
|
||||||
":hlo_module_config",
|
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
|
"//tensorflow/compiler/xla:window_util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
|
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
|
"//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
|
||||||
|
@ -105,15 +105,5 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution(
|
|
||||||
const HloInstruction* hlo,
|
|
||||||
const HloToElementGeneratorMap& operand_to_generator,
|
|
||||||
const llvm_ir::IrArray::Index& index) {
|
|
||||||
return ir_emitter_->EmitElementalConvolution(
|
|
||||||
Cast<HloConvolutionInstruction>(hlo),
|
|
||||||
operand_to_generator.at(hlo->operand(0)),
|
|
||||||
operand_to_generator.at(hlo->operand(1)), index);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -40,10 +40,6 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
|
|||||||
llvm::Value* rhs) override;
|
llvm::Value* rhs) override;
|
||||||
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
|
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
|
||||||
llvm::Value* value) override;
|
llvm::Value* value) override;
|
||||||
StatusOr<llvm::Value*> EmitConvolution(
|
|
||||||
const HloInstruction* hlo,
|
|
||||||
const HloToElementGeneratorMap& operand_to_generator,
|
|
||||||
const llvm_ir::IrArray::Index& index) override;
|
|
||||||
|
|
||||||
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
|
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
|
||||||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||||
|
@ -856,152 +856,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
|
|||||||
hlo_module_config_, target_machine_features_);
|
hlo_module_config_, target_machine_features_);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
|
|
||||||
const HloConvolutionInstruction* convolution,
|
|
||||||
const llvm_ir::ElementGenerator& input_generator,
|
|
||||||
const llvm_ir::ElementGenerator& kernel_generator,
|
|
||||||
const llvm_ir::IrArray::Index& index) {
|
|
||||||
const HloInstruction* lhs = convolution->operand(0);
|
|
||||||
const HloInstruction* rhs = convolution->operand(1);
|
|
||||||
const Window& window = convolution->window();
|
|
||||||
|
|
||||||
const ConvolutionDimensionNumbers& dnums =
|
|
||||||
convolution->convolution_dimension_numbers();
|
|
||||||
int num_spatial_dims = dnums.output_spatial_dimensions_size();
|
|
||||||
std::vector<llvm::Value*> output_spatial(num_spatial_dims);
|
|
||||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
|
||||||
output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
|
|
||||||
}
|
|
||||||
llvm::Value* output_feature = index[dnums.output_feature_dimension()];
|
|
||||||
llvm::Value* batch = index[dnums.output_batch_dimension()];
|
|
||||||
|
|
||||||
// We will accumulate the products into this sum to calculate the output entry
|
|
||||||
// at the given index.
|
|
||||||
PrimitiveType lhs_element_type = lhs->shape().element_type();
|
|
||||||
llvm::Type* lhs_llvm_type =
|
|
||||||
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
|
|
||||||
// Upcast the accumulator to F32 from F16 for increased precision.
|
|
||||||
llvm::Type* accumulator_type =
|
|
||||||
lhs_element_type == F16 ? b_.getFloatTy() : lhs_llvm_type;
|
|
||||||
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
|
||||||
accumulator_type, "convolution_sum_address", &b_,
|
|
||||||
MinimumAlignmentForPrimitiveType(lhs_element_type));
|
|
||||||
llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
|
|
||||||
Store(constant_zero, sum_address);
|
|
||||||
|
|
||||||
llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
|
|
||||||
std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
|
|
||||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
|
||||||
kernel_spatial[i] =
|
|
||||||
loops
|
|
||||||
.AddLoop(
|
|
||||||
0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
|
|
||||||
absl::StrCat("k", i))
|
|
||||||
->GetIndVarValue();
|
|
||||||
}
|
|
||||||
llvm::Value* input_feature =
|
|
||||||
loops
|
|
||||||
.AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
|
|
||||||
"iz")
|
|
||||||
->GetIndVarValue();
|
|
||||||
|
|
||||||
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
|
|
||||||
|
|
||||||
// Calculate the spatial index in the input array, taking striding, dilation
|
|
||||||
// and padding into account. An index in the padding will be out of the bounds
|
|
||||||
// of the array.
|
|
||||||
const auto calculate_input_index = [this](llvm::Value* output_index,
|
|
||||||
llvm::Value* kernel_index,
|
|
||||||
const WindowDimension& window_dim) {
|
|
||||||
llvm::Value* strided_index =
|
|
||||||
NSWMul(output_index, b_.getInt64(window_dim.stride()));
|
|
||||||
llvm::Value* dilated_kernel_index =
|
|
||||||
NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
|
|
||||||
return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
|
|
||||||
b_.getInt64(window_dim.padding_low()));
|
|
||||||
};
|
|
||||||
std::vector<llvm::Value*> input_spatial(num_spatial_dims);
|
|
||||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
|
||||||
input_spatial[i] = calculate_input_index(
|
|
||||||
output_spatial[i], kernel_spatial[i], window.dimensions(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need to check if 0 <= input dim < bound, as otherwise we are in the
|
|
||||||
// padding so that we can skip the computation. That is equivalent to input
|
|
||||||
// dim < bound as an *unsigned* comparison, since a negative value will wrap
|
|
||||||
// to a large positive value. The input dim is dilated, so we need to dilate
|
|
||||||
// the bound as well to match.
|
|
||||||
|
|
||||||
// Also need to check that the input coordinates are not in one of the
|
|
||||||
// holes created by base dilation.
|
|
||||||
const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
|
|
||||||
llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
|
|
||||||
return ICmpEQ(remainder, b_.getInt64(0));
|
|
||||||
};
|
|
||||||
|
|
||||||
llvm::Value* in_bounds_condition = b_.getInt1(true);
|
|
||||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
|
||||||
llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
|
|
||||||
lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
|
|
||||||
window.dimensions(i).base_dilation()));
|
|
||||||
llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
|
|
||||||
llvm::Value* dim_not_in_hole =
|
|
||||||
not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
|
|
||||||
llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
|
|
||||||
in_bounds_condition = And(in_bounds_condition, dim_ok);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now we need to map the dilated base coordinates back to the actual
|
|
||||||
// data indices on the lhs.
|
|
||||||
const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
|
|
||||||
return SDiv(input_index, b_.getInt64(base_dilation));
|
|
||||||
};
|
|
||||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
|
||||||
input_spatial[i] =
|
|
||||||
undilate(input_spatial[i], window.dimensions(i).base_dilation());
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm_ir::LlvmIfData if_data =
|
|
||||||
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
|
|
||||||
SetToFirstInsertPoint(if_data.true_block, &b_);
|
|
||||||
|
|
||||||
// We are not in the padding, so carry out the computation.
|
|
||||||
int num_dims = num_spatial_dims + 2;
|
|
||||||
std::vector<llvm::Value*> input_multi_index(num_dims);
|
|
||||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
|
||||||
input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
|
|
||||||
}
|
|
||||||
input_multi_index[dnums.input_feature_dimension()] = input_feature;
|
|
||||||
input_multi_index[dnums.input_batch_dimension()] = batch;
|
|
||||||
|
|
||||||
std::vector<llvm::Value*> kernel_multi_index(num_dims);
|
|
||||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
|
||||||
kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
|
|
||||||
window.dimensions(i).window_reversal()
|
|
||||||
? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
|
|
||||||
kernel_spatial[i])
|
|
||||||
: kernel_spatial[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
|
|
||||||
kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
|
|
||||||
|
|
||||||
llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
|
|
||||||
b_.getInt64Ty());
|
|
||||||
TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
|
|
||||||
input_generator(input_index));
|
|
||||||
llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
|
|
||||||
b_.getInt64Ty());
|
|
||||||
TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
|
|
||||||
kernel_generator(kernel_index));
|
|
||||||
llvm::Value* product = FMul(input_value, kernel_value);
|
|
||||||
llvm::Value* sum = FAdd(Load(sum_address), FPCast(product, accumulator_type));
|
|
||||||
Store(sum, sum_address);
|
|
||||||
|
|
||||||
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
|
|
||||||
return FPCast(Load(sum_address), lhs_llvm_type);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
|
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
|
||||||
auto lhs = convolution->operand(0);
|
auto lhs = convolution->operand(0);
|
||||||
auto rhs = convolution->operand(1);
|
auto rhs = convolution->operand(1);
|
||||||
|
@ -121,13 +121,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
|||||||
// Emit an LLVM global variable for every constant buffer allocation.
|
// Emit an LLVM global variable for every constant buffer allocation.
|
||||||
Status EmitConstantGlobals();
|
Status EmitConstantGlobals();
|
||||||
|
|
||||||
// Emit code to emit the element at `index` for a convolution instruction.
|
|
||||||
StatusOr<llvm::Value*> EmitElementalConvolution(
|
|
||||||
const HloConvolutionInstruction* convolution,
|
|
||||||
const llvm_ir::ElementGenerator& input_generator,
|
|
||||||
const llvm_ir::ElementGenerator& kernel_generator,
|
|
||||||
const llvm_ir::IrArray::Index& index);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
//
|
//
|
||||||
// The following methods implement the DfsHloVisitor interface.
|
// The following methods implement the DfsHloVisitor interface.
|
||||||
|
@ -42,6 +42,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/compiler/xla/window_util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -2222,27 +2223,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
|
|||||||
llvm::Value* current_accumulator = Load(accumulator_alloca);
|
llvm::Value* current_accumulator = Load(accumulator_alloca);
|
||||||
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
|
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
|
||||||
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
|
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
|
||||||
llvm::Value* next_accumulator;
|
llvm::Value* next_accumulator =
|
||||||
if (primitive_util::IsComplexType(primitive_type)) {
|
EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
|
||||||
llvm::Value* product_real =
|
|
||||||
FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
|
|
||||||
FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
|
|
||||||
llvm::Value* product_imag =
|
|
||||||
FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
|
|
||||||
FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
|
|
||||||
next_accumulator = InsertValue(
|
|
||||||
current_accumulator,
|
|
||||||
FAdd(EmitExtractReal(current_accumulator), product_real), {0});
|
|
||||||
next_accumulator = InsertValue(
|
|
||||||
next_accumulator,
|
|
||||||
FAdd(EmitExtractImag(current_accumulator), product_imag), {1});
|
|
||||||
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
|
|
||||||
next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value));
|
|
||||||
} else if (primitive_type == PRED) {
|
|
||||||
next_accumulator = Or(current_accumulator, And(lhs_value, rhs_value));
|
|
||||||
} else {
|
|
||||||
next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value));
|
|
||||||
}
|
|
||||||
Store(next_accumulator, accumulator_alloca);
|
Store(next_accumulator, accumulator_alloca);
|
||||||
|
|
||||||
SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
|
SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
|
||||||
@ -2551,6 +2533,28 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
|
|||||||
return complex;
|
return complex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
|
||||||
|
llvm::Value* accumulator,
|
||||||
|
xla::PrimitiveType primitive_type) {
|
||||||
|
if (primitive_util::IsComplexType(primitive_type)) {
|
||||||
|
llvm::Value* product_real =
|
||||||
|
FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
|
||||||
|
FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
|
||||||
|
llvm::Value* product_imag =
|
||||||
|
FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
|
||||||
|
FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
|
||||||
|
llvm::Value* next_accumulator = InsertValue(
|
||||||
|
accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
|
||||||
|
return InsertValue(next_accumulator,
|
||||||
|
FAdd(EmitExtractImag(accumulator), product_imag), {1});
|
||||||
|
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
|
||||||
|
return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
|
||||||
|
} else if (primitive_type == PRED) {
|
||||||
|
return Or(accumulator, And(lhs, rhs));
|
||||||
|
}
|
||||||
|
return Add(accumulator, Mul(lhs, rhs));
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
|
||||||
const HloMapInstruction* map_instr,
|
const HloMapInstruction* map_instr,
|
||||||
absl::Span<llvm::Value* const> elemental_operands) {
|
absl::Span<llvm::Value* const> elemental_operands) {
|
||||||
@ -2767,10 +2771,149 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
|
||||||
const HloInstruction* hlo,
|
const HloInstruction* convolution,
|
||||||
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
|
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
|
||||||
const llvm_ir::IrArray::Index& index) {
|
const llvm_ir::IrArray::Index& index) {
|
||||||
return Unimplemented("Elemental convolution is not implemented");
|
const HloInstruction* lhs = convolution->operand(0);
|
||||||
|
const auto& input_generator = operand_to_generator.at(lhs);
|
||||||
|
const HloInstruction* rhs = convolution->operand(1);
|
||||||
|
const auto& kernel_generator = operand_to_generator.at(rhs);
|
||||||
|
const Window& window = convolution->window();
|
||||||
|
|
||||||
|
const ConvolutionDimensionNumbers& dnums =
|
||||||
|
convolution->convolution_dimension_numbers();
|
||||||
|
int num_spatial_dims = dnums.output_spatial_dimensions_size();
|
||||||
|
std::vector<llvm::Value*> output_spatial(num_spatial_dims);
|
||||||
|
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||||
|
output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
|
||||||
|
}
|
||||||
|
llvm::Value* output_feature = index[dnums.output_feature_dimension()];
|
||||||
|
llvm::Value* batch = index[dnums.output_batch_dimension()];
|
||||||
|
|
||||||
|
// We will accumulate the products into this sum to calculate the output entry
|
||||||
|
// at the given index.
|
||||||
|
PrimitiveType lhs_element_type = lhs->shape().element_type();
|
||||||
|
llvm::Type* lhs_llvm_type =
|
||||||
|
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
|
||||||
|
// Upcast the accumulator to F32 from F16 for increased precision.
|
||||||
|
llvm::Type* accumulator_type =
|
||||||
|
lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
|
||||||
|
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
|
accumulator_type, "convolution_sum_address", b_);
|
||||||
|
llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
|
||||||
|
Store(constant_zero, sum_address);
|
||||||
|
|
||||||
|
llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
|
||||||
|
std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
|
||||||
|
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||||
|
kernel_spatial[i] =
|
||||||
|
loops
|
||||||
|
.AddLoop(
|
||||||
|
0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
|
||||||
|
absl::StrCat("k", i))
|
||||||
|
->GetIndVarValue();
|
||||||
|
}
|
||||||
|
llvm::Value* input_feature =
|
||||||
|
loops
|
||||||
|
.AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
|
||||||
|
"iz")
|
||||||
|
->GetIndVarValue();
|
||||||
|
|
||||||
|
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
|
||||||
|
|
||||||
|
// Calculate the spatial index in the input array, taking striding, dilation
|
||||||
|
// and padding into account. An index in the padding will be out of the bounds
|
||||||
|
// of the array.
|
||||||
|
const auto calculate_input_index = [this](llvm::Value* output_index,
|
||||||
|
llvm::Value* kernel_index,
|
||||||
|
const WindowDimension& window_dim) {
|
||||||
|
llvm::Value* strided_index =
|
||||||
|
NSWMul(output_index, b_->getInt64(window_dim.stride()));
|
||||||
|
llvm::Value* dilated_kernel_index =
|
||||||
|
NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
|
||||||
|
return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
|
||||||
|
b_->getInt64(window_dim.padding_low()));
|
||||||
|
};
|
||||||
|
std::vector<llvm::Value*> input_spatial(num_spatial_dims);
|
||||||
|
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||||
|
input_spatial[i] = calculate_input_index(
|
||||||
|
output_spatial[i], kernel_spatial[i], window.dimensions(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to check if 0 <= input dim < bound, as otherwise we are in the
|
||||||
|
// padding so that we can skip the computation. That is equivalent to input
|
||||||
|
// dim < bound as an *unsigned* comparison, since a negative value will wrap
|
||||||
|
// to a large positive value. The input dim is dilated, so we need to dilate
|
||||||
|
// the bound as well to match.
|
||||||
|
|
||||||
|
// Also need to check that the input coordinates are not in one of the
|
||||||
|
// holes created by base dilation.
|
||||||
|
const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
|
||||||
|
llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
|
||||||
|
return ICmpEQ(remainder, b_->getInt64(0));
|
||||||
|
};
|
||||||
|
|
||||||
|
llvm::Value* in_bounds_condition = b_->getInt1(true);
|
||||||
|
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||||
|
llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
|
||||||
|
lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
|
||||||
|
window.dimensions(i).base_dilation()));
|
||||||
|
llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
|
||||||
|
llvm::Value* dim_not_in_hole =
|
||||||
|
not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
|
||||||
|
llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
|
||||||
|
in_bounds_condition = And(in_bounds_condition, dim_ok);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we need to map the dilated base coordinates back to the actual
|
||||||
|
// data indices on the lhs.
|
||||||
|
const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
|
||||||
|
return SDiv(input_index, b_->getInt64(base_dilation));
|
||||||
|
};
|
||||||
|
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||||
|
input_spatial[i] =
|
||||||
|
undilate(input_spatial[i], window.dimensions(i).base_dilation());
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm_ir::LlvmIfData if_data =
|
||||||
|
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
|
||||||
|
SetToFirstInsertPoint(if_data.true_block, b_);
|
||||||
|
|
||||||
|
// We are not in the padding, so carry out the computation.
|
||||||
|
int num_dims = num_spatial_dims + 2;
|
||||||
|
std::vector<llvm::Value*> input_multi_index(num_dims);
|
||||||
|
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||||
|
input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
|
||||||
|
}
|
||||||
|
input_multi_index[dnums.input_feature_dimension()] = input_feature;
|
||||||
|
input_multi_index[dnums.input_batch_dimension()] = batch;
|
||||||
|
|
||||||
|
std::vector<llvm::Value*> kernel_multi_index(num_dims);
|
||||||
|
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||||
|
kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
|
||||||
|
window.dimensions(i).window_reversal()
|
||||||
|
? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
|
||||||
|
kernel_spatial[i])
|
||||||
|
: kernel_spatial[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
|
||||||
|
kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
|
||||||
|
|
||||||
|
llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
|
||||||
|
b_->getInt64Ty());
|
||||||
|
TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
|
||||||
|
input_generator(input_index));
|
||||||
|
llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
|
||||||
|
b_->getInt64Ty());
|
||||||
|
TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
|
||||||
|
kernel_generator(kernel_index));
|
||||||
|
llvm::Value* sum = EmitMulAdd(input_value, kernel_value, Load(sum_address),
|
||||||
|
convolution->shape().element_type());
|
||||||
|
Store(sum, sum_address);
|
||||||
|
|
||||||
|
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
|
||||||
|
return FPCast(Load(sum_address), lhs_llvm_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evaluate polynomial using Horner's method.
|
// Evaluate polynomial using Horner's method.
|
||||||
|
@ -183,6 +183,11 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
|||||||
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
|
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
|
||||||
llvm::Value* imag);
|
llvm::Value* imag);
|
||||||
|
|
||||||
|
// Emit `accumulator + lhs * rhs` for the given primitive type.
|
||||||
|
llvm::Value* EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
|
||||||
|
llvm::Value* accumulator,
|
||||||
|
xla::PrimitiveType primitive_type);
|
||||||
|
|
||||||
// Identifier of the thread unique among all threads on the device
|
// Identifier of the thread unique among all threads on the device
|
||||||
virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
|
virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
|
||||||
|
|
||||||
|
@ -1644,6 +1644,18 @@ ENTRY Test {
|
|||||||
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
|
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU(ConvolveC64Forward)) {
|
||||||
|
constexpr char kHlo[] = R"(
|
||||||
|
HloModule TestModule
|
||||||
|
|
||||||
|
ENTRY Test {
|
||||||
|
%arg0 = c64[3,56,56,16] parameter(0)
|
||||||
|
%arg1 = c64[3,3,3,64] parameter(1)
|
||||||
|
ROOT %conv = c64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf
|
||||||
|
})";
|
||||||
|
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(ConvolutionHloTest,
|
XLA_TEST_F(ConvolutionHloTest,
|
||||||
DISABLED_ON_GPU_ROCM(ConvolveF32ForwardReversed)) {
|
DISABLED_ON_GPU_ROCM(ConvolveF32ForwardReversed)) {
|
||||||
constexpr char kHlo[] = R"(
|
constexpr char kHlo[] = R"(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user