[XLA] Add Cbrt operation (lowered to pow(1.0/3.0))

First phase of my onboarding bug. Still need to improve internal calculation of cbrt.

PiperOrigin-RevId: 309533789
Change-Id: I75058b7f319e32e51e85324ca51515ef802cc111
This commit is contained in:
A. Unique TensorFlower 2020-05-02 00:24:51 -07:00 committed by TensorFlower Gardener
parent 8cb03a2837
commit c75e73c2a5
15 changed files with 113 additions and 0 deletions

View File

@ -298,6 +298,15 @@ XLA_TEST_F(MathTest, SqrtSixValues) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
// Builds a rank-1 float constant of six values (including one negative and
// one non-integer cube), applies Cbrt to it, and compares the result against
// the listed cube roots with an ErrorSpec of 0.001.
XLA_TEST_F(MathTest, CbrtSixValues) {
  XlaBuilder builder(TestName());
  const std::vector<float> expected_roots = {2, 1, 16, -4, 1.2, 11};
  const ErrorSpec tolerance(0.001);
  auto cubes =
      ConstantR1<float>(&builder, {8.0, 1.0, 4096.0, -64.0, 1.728, 1331});
  // The returned op handle is intentionally unused.
  Cbrt(cubes);
  ComputeAndCompareR1<float>(&builder, expected_roots, {}, tolerance);
}
XLA_TEST_F(MathTest, SinhSmallValues) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11});

View File

@ -3571,6 +3571,9 @@ XlaOp Imag(const XlaOp operand) {
// Enqueues a unary kSqrt instruction over `operand` on the builder that owns
// `operand`, and returns the handle of the new op.
XlaOp Sqrt(const XlaOp operand) {
  auto owning_builder = operand.builder();
  return owning_builder->UnaryOp(HloOpcode::kSqrt, operand);
}
// Enqueues a unary kCbrt (cube root) instruction over `operand` on the
// builder that owns `operand`, and returns the handle of the new op.
XlaOp Cbrt(const XlaOp operand) {
  auto owning_builder = operand.builder();
  return owning_builder->UnaryOp(HloOpcode::kCbrt, operand);
}
// Enqueues a unary kRsqrt instruction over `operand` on the builder that owns
// `operand`, and returns the handle of the new op.
XlaOp Rsqrt(const XlaOp operand) {
  auto owning_builder = operand.builder();
  return owning_builder->UnaryOp(HloOpcode::kRsqrt, operand);
}

View File

@ -1030,6 +1030,7 @@ class XlaBuilder {
friend XlaOp Imag(XlaOp operand);
friend XlaOp Sqrt(XlaOp operand);
friend XlaOp Rsqrt(XlaOp operand);
friend XlaOp Cbrt(XlaOp operand);
friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp IsFinite(XlaOp operand);
@ -1884,6 +1885,9 @@ XlaOp Imag(XlaOp operand);
// Enqueues a unary sqrt (square root) computation of `operand` onto the
// computation and returns the resulting op.
XlaOp Sqrt(XlaOp operand);
// Enqueues a unary cbrt (cube root) computation of `operand` onto the
// computation and returns the resulting op.
XlaOp Cbrt(XlaOp operand);
// Enqueues a unary rsqrt computation of `operand` onto the computation and
// returns the resulting op.
// NOTE(review): presumably rsqrt means reciprocal square root -- confirm.
XlaOp Rsqrt(XlaOp operand);

View File

@ -109,6 +109,9 @@ class DfsHloVisitorBase {
// Default visitor hook for rsqrt instructions: forwards to the shared
// element-wise unary handler. Virtual, so subclasses can override it.
virtual Status HandleRsqrt(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
// Default visitor hook for cbrt (cube root) instructions: forwards to the
// shared element-wise unary handler. Virtual, so subclasses can override it.
virtual Status HandleCbrt(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
virtual Status HandleFft(HloInstructionPtr fft) = 0;
virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0;

View File

@ -461,6 +461,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
return EmitSqrt(op->shape().element_type(), operand_value);
case HloOpcode::kRsqrt:
return EmitRsqrt(op->shape().element_type(), operand_value);
case HloOpcode::kCbrt:
return EmitCbrt(op->shape().element_type(), operand_value);
case HloOpcode::kFloor:
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
{operand_value},
@ -787,6 +789,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
case HloOpcode::kRsqrt: {
return EmitComplexRsqrt(op, component_type, operand_value);
}
case HloOpcode::kCbrt: {
return EmitComplexCbrt(op, component_type, operand_value);
}
case HloOpcode::kNegate:
return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
FNeg(EmitExtractImag(operand_value)));
@ -1081,6 +1086,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
return EmitComposeComplex(op, real_part, imag_part);
}
// Emits the cube root of a complex operand by reusing EmitComplexPower with
// the exponent c + di where c = 1.0 / 3.0 and d = 0.
//
// `prim_type` selects the IR type used for the two exponent constants;
// `operand_value` is split into its real and imaginary parts, which form the
// base passed to EmitComplexPower.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexCbrt(
    const HloInstruction* op, PrimitiveType prim_type,
    llvm::Value* operand_value) {
  auto component_type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);

  // Base: real part first, then imaginary part.
  llvm::Value* base_real = EmitExtractReal(operand_value);
  llvm::Value* base_imag = EmitExtractImag(operand_value);

  // Exponent: 1/3 + 0i.
  auto exponent_real = llvm::ConstantFP::get(component_type, 1.0 / 3.0);
  auto exponent_imag = llvm::ConstantFP::get(component_type, 0);

  return EmitComplexPower(op, base_real, base_imag, exponent_real,
                          exponent_imag);
}
// (a+bi)^(c+di) =
// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
@ -1392,6 +1410,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
{lhs->getType()}, b_);
}
// Emits the cube root of a floating-point `value` as
//   copysign(pow(fabs(value), 1.0 / 3.0), value).
// The power is taken of the magnitude only and the sign of the input is then
// copied back onto the result, so negative inputs give negative roots.
// Any error returned by EmitPow is propagated to the caller.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
                                                    llvm::Value* value) {
  auto ir_type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one_third = llvm::ConstantFP::get(ir_type, 1.0 / 3.0);

  // |value|
  auto magnitude = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value},
                                                {ir_type}, b_);

  // |value| ^ (1/3)
  TF_ASSIGN_OR_RETURN(llvm::Value * magnitude_root,
                      EmitPow(prim_type, magnitude, one_third));

  // Re-attach the sign of the original input.
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                      {magnitude_root, value}, {ir_type}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
@ -2181,6 +2212,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {

View File

@ -116,6 +116,9 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
virtual StatusOr<llvm::Value*> EmitSqrt(PrimitiveType prim_type,
llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitCbrt(PrimitiveType prim_type,
llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitRsqrt(PrimitiveType prim_type,
llvm::Value* value);
@ -159,6 +162,10 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
PrimitiveType prim_type,
llvm::Value* operand_value);
virtual StatusOr<llvm::Value*> EmitComplexCbrt(const HloInstruction* op,
PrimitiveType prim_type,
llvm::Value* operand_value);
virtual StatusOr<llvm::Value*> EmitComplexRsqrt(const HloInstruction* op,
PrimitiveType prim_type,
llvm::Value* operand_value);

View File

@ -700,6 +700,38 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
// Evaluates a cbrt instruction when the element type is complex.
//
// Each element is mapped to std::pow(elem, 1.0 / 3.0) (with the exponent cast
// to ElementwiseT) and the resulting literal is stored in
// parent_->evaluated_[cbrt]. This uses the same 1/3 power formulation as the
// IR emitter's complex cbrt lowering. Any error from ElementWiseUnaryOp is
// returned to the caller; otherwise returns Status::OK().
template <
    typename NativeT,
    typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleCbrt(HloInstruction* cbrt) {
  TF_ASSIGN_OR_RETURN(
      parent_->evaluated_[cbrt],
      ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) -> ElementwiseT {
        return std::pow(elem_operand, static_cast<ElementwiseT>(1.0 / 3.0));
      }));
  return Status::OK();
}
// Evaluates a cbrt instruction when the element type is not complex.
//
// Each element is mapped through std::cbrt and the resulting literal is
// stored in parent_->evaluated_[cbrt]. Any error from ElementWiseUnaryOp is
// returned to the caller; otherwise returns Status::OK().
template <
    typename NativeT,
    typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleCbrt(HloInstruction* cbrt) {
  auto cube_root = [](ElementwiseT elem_operand) {
    return std::cbrt(elem_operand);
  };
  TF_ASSIGN_OR_RETURN(parent_->evaluated_[cbrt],
                      ElementWiseUnaryOp(cbrt, cube_root));
  return Status::OK();
}
// Visitor entry point for cbrt: dispatches to the HandleCbrt template
// overload selected by whether ElementwiseT is a complex type.
Status HandleCbrt(HloInstruction* cbrt) override {
return HandleCbrt<ElementwiseT>(cbrt);
}
Status HandleRsqrt(HloInstruction* rsqrt) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[rsqrt],

View File

@ -980,6 +980,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kSlice:
case HloOpcode::kSort:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
// De-emphasize scalar-shaped elementwise ops -- they're generally

View File

@ -807,6 +807,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
break;
default:
@ -1565,6 +1566,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
CHECK_EQ(new_operands.size(), 1);
clone = CreateUnary(shape, opcode_, new_operands[0]);
@ -1937,6 +1939,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kTuple:
@ -2381,6 +2384,7 @@ bool HloInstruction::IsElementwiseImpl(
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
CHECK_EQ(1, operand_count());
return true;
@ -2893,6 +2897,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSin(this);
case HloOpcode::kSqrt:
return visitor->HandleSqrt(this);
case HloOpcode::kCbrt:
return visitor->HandleCbrt(this);
case HloOpcode::kRsqrt:
return visitor->HandleRsqrt(this);
case HloOpcode::kReal:

View File

@ -138,6 +138,7 @@ namespace xla {
V(kSlice, "slice", 1) \
V(kSort, "sort", kHloOpcodeIsVariadic) \
V(kSqrt, "sqrt", 1) \
V(kCbrt, "cbrt", 1) \
V(kSubtract, "subtract", 2) \
V(kTanh, "tanh", 1) \
V(kTrace, "trace", 1) \

View File

@ -784,6 +784,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {

View File

@ -175,6 +175,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kSendDone:
case HloOpcode::kSort:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
case HloOpcode::kTrace:
case HloOpcode::kTriangularSolve:

View File

@ -2220,6 +2220,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kSlice:
case HloOpcode::kSort:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kPopulationCount:

View File

@ -257,6 +257,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
case HloOpcode::kLog1p:
case HloOpcode::kRsqrt:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
case HloOpcode::kTanh:
if (!ShapeUtil::ElementIsFloating(shape) &&
!ShapeUtil::ElementIsComplex(shape)) {

View File

@ -352,6 +352,17 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, {
Run(Sqrt, std::sqrt, error_spec_gen);
})
// Compares XLA's Cbrt against std::cbrt through the test fixture's Run
// helper. On the "Host" and "CUDA" platforms an explicit, relaxed
// ErrorSpec{0.01, 0.01} is supplied for every input; on any other platform
// Run is called without an explicit error spec.
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Cbrt, {
  const bool use_relaxed_tolerance =
      platform_ == "Host" || platform_ == "CUDA";
  if (!use_relaxed_tolerance) {
    Run(Cbrt, std::cbrt);
  } else {
    ErrorSpecGen relaxed_spec_gen = +[](NativeT x) {
      return ErrorSpec{0.01, 0.01};
    };
    Run(Cbrt, std::cbrt, relaxed_spec_gen);
  }
})
// TODO(jlebar): Test trig functions over complex inputs.
XLA_TEST_P(ExhaustiveF32UnaryTest, Acosh) {
// Error inherited from Log, which our implementation of Acosh uses.