[XLA:CPU] When emitting an elemental F16 conv, do the accumulation in F32
This matches what cuBLAS and Eigen do, and gives better precision for F16 convolutions. PiperOrigin-RevId: 259403856
This commit is contained in:
parent
e5b12c6ce3
commit
e547d262a5
@ -1027,10 +1027,13 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
|
|||||||
PrimitiveType lhs_element_type = lhs->shape().element_type();
|
PrimitiveType lhs_element_type = lhs->shape().element_type();
|
||||||
llvm::Type* lhs_llvm_type =
|
llvm::Type* lhs_llvm_type =
|
||||||
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
|
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
|
||||||
|
// Upcast the accumulator to F32 from F16 for increased precision.
|
||||||
|
llvm::Type* accumulator_type =
|
||||||
|
lhs_element_type == F16 ? b_.getFloatTy() : lhs_llvm_type;
|
||||||
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
lhs_llvm_type, "convolution_sum_address", &b_,
|
accumulator_type, "convolution_sum_address", &b_,
|
||||||
MinimumAlignmentForPrimitiveType(lhs_element_type));
|
MinimumAlignmentForPrimitiveType(lhs_element_type));
|
||||||
llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
|
llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
|
||||||
Store(constant_zero, sum_address);
|
Store(constant_zero, sum_address);
|
||||||
|
|
||||||
llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
|
llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
|
||||||
@ -1139,11 +1142,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
|
|||||||
TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
|
TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
|
||||||
kernel_generator(kernel_index));
|
kernel_generator(kernel_index));
|
||||||
llvm::Value* product = FMul(input_value, kernel_value);
|
llvm::Value* product = FMul(input_value, kernel_value);
|
||||||
llvm::Value* sum = FAdd(Load(sum_address), product);
|
llvm::Value* sum = FAdd(Load(sum_address), FPCast(product, accumulator_type));
|
||||||
Store(sum, sum_address);
|
Store(sum, sum_address);
|
||||||
|
|
||||||
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
|
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
|
||||||
return Load(sum_address);
|
return FPCast(Load(sum_address), lhs_llvm_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
|
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
|
||||||
|
@ -1842,15 +1842,11 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
Convolve1DTestParam{130, 1, 1, 1, 3},
|
Convolve1DTestParam{130, 1, 1, 1, 3},
|
||||||
Convolve1DTestParam{64, 1, 1, 1, 1},
|
Convolve1DTestParam{64, 1, 1, 1, 1},
|
||||||
Convolve1DTestParam{128, 1, 1, 1, 1},
|
Convolve1DTestParam{128, 1, 1, 1, 1},
|
||||||
// TODO(b/72566306): The following five tests failed on CPU with unreasonable
|
|
||||||
// relative errors. Last ran on 2018-02-22.
|
|
||||||
#if XLA_TEST_BACKEND_GPU
|
|
||||||
Convolve1DTestParam{139, 1, 1, 128, 1},
|
Convolve1DTestParam{139, 1, 1, 128, 1},
|
||||||
Convolve1DTestParam{640, 3, 3, 128, 1},
|
Convolve1DTestParam{640, 3, 3, 128, 1},
|
||||||
Convolve1DTestParam{900, 1, 1, 10, 1},
|
Convolve1DTestParam{900, 1, 1, 10, 1},
|
||||||
Convolve1DTestParam{1, 10, 10, 1, 10},
|
Convolve1DTestParam{1, 10, 10, 1, 10},
|
||||||
Convolve1DTestParam{1, 10, 130, 1, 1},
|
Convolve1DTestParam{1, 10, 130, 1, 1},
|
||||||
#endif
|
|
||||||
Convolve1DTestParam{1, 10, 130, 1, 2},
|
Convolve1DTestParam{1, 10, 130, 1, 2},
|
||||||
Convolve1DTestParam{1, 64, 64, 1, 10},
|
Convolve1DTestParam{1, 64, 64, 1, 10},
|
||||||
Convolve1DTestParam{1, 65, 65, 1, 1},
|
Convolve1DTestParam{1, 65, 65, 1, 1},
|
||||||
|
Loading…
Reference in New Issue
Block a user