[XLA:CPU] When emitting an elemental F16 conv, do the accumulation in F32

This matches what cuBLAS and Eigen do, and gives better precision for F16
convolutions.
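
For context on why the accumulator width matters: F16 has an 11-bit significand, so a long running sum silently stops absorbing small addends once it grows past 2^11. A minimal standalone sketch of the effect (not part of this change; assumes a compiler with the _Float16 extension, e.g. recent Clang):

#include <cstdio>

int main() {
  _Float16 sum_f16 = 0;
  float sum_f32 = 0;
  // F16 has an 11-bit significand, so integers are exact only up to 2048.
  // Once the running sum reaches 2048, adding 1.0 rounds back to 2048 and
  // the F16 accumulator stops growing; the F32 accumulator does not.
  for (int i = 0; i < 4096; ++i) {
    sum_f16 = sum_f16 + (_Float16)1.0f;
    sum_f32 = sum_f32 + (float)(_Float16)1.0f;
  }
  printf("F16 accumulator: %g\n", (float)sum_f16);  // prints 2048
  printf("F32 accumulator: %g\n", sum_f32);         // prints 4096
  return 0;
}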

PiperOrigin-RevId: 259403856
Author: Benjamin Kramer
Date: 2019-07-22 14:14:08 -07:00
Committed by: TensorFlower Gardener
Commit: e547d262a5 (parent e5b12c6ce3)
2 changed files with 7 additions and 8 deletions

tensorflow/compiler/xla/service/cpu/ir_emitter.cc

@@ -1027,10 +1027,13 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
   PrimitiveType lhs_element_type = lhs->shape().element_type();
   llvm::Type* lhs_llvm_type =
       llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
+  // Upcast the accumulator to F32 from F16 for increased precision.
+  llvm::Type* accumulator_type =
+      lhs_element_type == F16 ? b_.getFloatTy() : lhs_llvm_type;
   llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
-      lhs_llvm_type, "convolution_sum_address", &b_,
+      accumulator_type, "convolution_sum_address", &b_,
       MinimumAlignmentForPrimitiveType(lhs_element_type));
-  llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
+  llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
   Store(constant_zero, sum_address);

   llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);

@@ -1139,11 +1142,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitElementalConvolution(
   TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
                       kernel_generator(kernel_index));
   llvm::Value* product = FMul(input_value, kernel_value);
-  llvm::Value* sum = FAdd(Load(sum_address), product);
+  llvm::Value* sum = FAdd(Load(sum_address), FPCast(product, accumulator_type));
   Store(sum, sum_address);

   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-  return Load(sum_address);
+  return FPCast(Load(sum_address), lhs_llvm_type);
 }

 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
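
The net effect of the emitter change, sketched as a scalar C++ analogue (illustrative only: the function name is hypothetical and _Float16 support is assumed, neither comes from the repository): each product is still formed in the input type, widened to F32 before the add, and the finished sum is narrowed back to the LHS element type.

#include <cstddef>

// Illustrative-only scalar analogue of the emitted IR.
_Float16 ConvolveElement(const _Float16* input, const _Float16* kernel,
                         size_t n) {
  float sum = 0.0f;  // accumulator_type: F32 even though the inputs are F16
  for (size_t i = 0; i < n; ++i) {
    _Float16 product = input[i] * kernel[i];  // FMul in the input type
    sum += (float)product;                    // FPCast, then FAdd in F32
  }
  return (_Float16)sum;  // final FPCast back to the LHS element type
}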

tensorflow/compiler/xla/tests/convolution_test.cc

@@ -1842,15 +1842,11 @@ INSTANTIATE_TEST_CASE_P(
     Convolve1DTestParam{130, 1, 1, 1, 3},
     Convolve1DTestParam{64, 1, 1, 1, 1},
     Convolve1DTestParam{128, 1, 1, 1, 1},
-// TODO(b/72566306): The following five tests failed on CPU with unreasonable
-// relative errors. Last ran on 2018-02-22.
-#if XLA_TEST_BACKEND_GPU
     Convolve1DTestParam{139, 1, 1, 128, 1},
     Convolve1DTestParam{640, 3, 3, 128, 1},
     Convolve1DTestParam{900, 1, 1, 10, 1},
     Convolve1DTestParam{1, 10, 10, 1, 10},
     Convolve1DTestParam{1, 10, 130, 1, 1},
-#endif
     Convolve1DTestParam{1, 10, 130, 1, 2},
     Convolve1DTestParam{1, 64, 64, 1, 10},
     Convolve1DTestParam{1, 65, 65, 1, 1},