[XLA] Add Cbrt operation (lowered to pow(1.0/3.0))
First phase of my onboarding bug. Still need to improve internal calculation of cbrt. PiperOrigin-RevId: 309533789 Change-Id: I75058b7f319e32e51e85324ca51515ef802cc111
This commit is contained in:
parent
8cb03a2837
commit
c75e73c2a5
@ -298,6 +298,15 @@ XLA_TEST_F(MathTest, SqrtSixValues) {
|
||||
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(MathTest, CbrtSixValues) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR1<float>(&builder, {8.0, 1.0, 4096.0, -64.0, 1.728, 1331});
|
||||
Cbrt(x);
|
||||
|
||||
std::vector<float> expected = {2, 1, 16, -4, 1.2, 11};
|
||||
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.001));
|
||||
}
|
||||
|
||||
XLA_TEST_F(MathTest, SinhSmallValues) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR1<float>(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11});
|
||||
|
||||
@ -3571,6 +3571,9 @@ XlaOp Imag(const XlaOp operand) {
|
||||
XlaOp Sqrt(const XlaOp operand) {
|
||||
return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand);
|
||||
}
|
||||
XlaOp Cbrt(const XlaOp operand) {
|
||||
return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand);
|
||||
}
|
||||
XlaOp Rsqrt(const XlaOp operand) {
|
||||
return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand);
|
||||
}
|
||||
|
||||
@ -1030,6 +1030,7 @@ class XlaBuilder {
|
||||
friend XlaOp Imag(XlaOp operand);
|
||||
friend XlaOp Sqrt(XlaOp operand);
|
||||
friend XlaOp Rsqrt(XlaOp operand);
|
||||
friend XlaOp Cbrt(XlaOp operand);
|
||||
friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp IsFinite(XlaOp operand);
|
||||
@ -1884,6 +1885,9 @@ XlaOp Imag(XlaOp operand);
|
||||
// Enqueues a sqrt computation onto the computation.
|
||||
XlaOp Sqrt(XlaOp operand);
|
||||
|
||||
// Enqueues a cbrt computation onto the computation.
|
||||
XlaOp Cbrt(XlaOp operand);
|
||||
|
||||
// Enqueues a rsqrt computation onto the computation.
|
||||
XlaOp Rsqrt(XlaOp operand);
|
||||
|
||||
|
||||
@ -109,6 +109,9 @@ class DfsHloVisitorBase {
|
||||
virtual Status HandleRsqrt(HloInstructionPtr hlo) {
|
||||
return HandleElementwiseUnary(hlo);
|
||||
}
|
||||
virtual Status HandleCbrt(HloInstructionPtr hlo) {
|
||||
return HandleElementwiseUnary(hlo);
|
||||
}
|
||||
virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleFft(HloInstructionPtr fft) = 0;
|
||||
virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0;
|
||||
|
||||
@ -461,6 +461,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
|
||||
return EmitSqrt(op->shape().element_type(), operand_value);
|
||||
case HloOpcode::kRsqrt:
|
||||
return EmitRsqrt(op->shape().element_type(), operand_value);
|
||||
case HloOpcode::kCbrt:
|
||||
return EmitCbrt(op->shape().element_type(), operand_value);
|
||||
case HloOpcode::kFloor:
|
||||
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
|
||||
{operand_value},
|
||||
@ -787,6 +789,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
|
||||
case HloOpcode::kRsqrt: {
|
||||
return EmitComplexRsqrt(op, component_type, operand_value);
|
||||
}
|
||||
case HloOpcode::kCbrt: {
|
||||
return EmitComplexCbrt(op, component_type, operand_value);
|
||||
}
|
||||
case HloOpcode::kNegate:
|
||||
return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
|
||||
FNeg(EmitExtractImag(operand_value)));
|
||||
@ -1081,6 +1086,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
|
||||
return EmitComposeComplex(op, real_part, imag_part);
|
||||
}
|
||||
|
||||
//
|
||||
// Using EmitComplexPower with c=1.0/3.0 and d=0
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexCbrt(
|
||||
const HloInstruction* op, PrimitiveType prim_type,
|
||||
llvm::Value* operand_value) {
|
||||
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
|
||||
auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
|
||||
auto zero = llvm::ConstantFP::get(type, 0);
|
||||
llvm::Value* a = EmitExtractReal(operand_value);
|
||||
llvm::Value* b = EmitExtractImag(operand_value);
|
||||
return EmitComplexPower(op, a, b, third, zero);
|
||||
}
|
||||
|
||||
// (a+bi)^(c+di) =
|
||||
// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
|
||||
// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
|
||||
@ -1392,6 +1410,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
|
||||
{lhs->getType()}, b_);
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
|
||||
llvm::Value* value) {
|
||||
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
|
||||
auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
|
||||
auto abs_value =
|
||||
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * abs_res,
|
||||
EmitPow(prim_type, abs_value, third));
|
||||
auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
|
||||
{abs_res, value}, {type}, b_);
|
||||
return signed_res;
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
|
||||
llvm::Value* lhs,
|
||||
llvm::Value* rhs) {
|
||||
@ -2181,6 +2212,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kTanh:
|
||||
return [this, hlo, &operand_to_generator](
|
||||
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||
|
||||
@ -116,6 +116,9 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
virtual StatusOr<llvm::Value*> EmitSqrt(PrimitiveType prim_type,
|
||||
llvm::Value* value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitCbrt(PrimitiveType prim_type,
|
||||
llvm::Value* value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitRsqrt(PrimitiveType prim_type,
|
||||
llvm::Value* value);
|
||||
|
||||
@ -159,6 +162,10 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
PrimitiveType prim_type,
|
||||
llvm::Value* operand_value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitComplexCbrt(const HloInstruction* op,
|
||||
PrimitiveType prim_type,
|
||||
llvm::Value* operand_value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitComplexRsqrt(const HloInstruction* op,
|
||||
PrimitiveType prim_type,
|
||||
llvm::Value* operand_value);
|
||||
|
||||
@ -700,6 +700,38 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleCbrt(HloInstruction* cbrt) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[cbrt],
|
||||
ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) -> ElementwiseT {
|
||||
return std::pow(elem_operand, static_cast<ElementwiseT>(1.0 / 3.0));
|
||||
return elem_operand.real() < 0
|
||||
? -std::pow(-elem_operand,
|
||||
static_cast<ElementwiseT>(1.0 / 3.0))
|
||||
: std::pow(elem_operand,
|
||||
static_cast<ElementwiseT>(1.0 / 3.0));
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleCbrt(HloInstruction* cbrt) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[cbrt],
|
||||
ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) {
|
||||
return std::cbrt(elem_operand);
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HandleCbrt(HloInstruction* cbrt) override {
|
||||
return HandleCbrt<ElementwiseT>(cbrt);
|
||||
}
|
||||
|
||||
Status HandleRsqrt(HloInstruction* rsqrt) override {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[rsqrt],
|
||||
|
||||
@ -980,6 +980,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kSlice:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kTanh:
|
||||
// De-emphasize scalar-shaped elementwise ops -- they're generally
|
||||
|
||||
@ -807,6 +807,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kTanh:
|
||||
break;
|
||||
default:
|
||||
@ -1565,6 +1566,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kTanh:
|
||||
CHECK_EQ(new_operands.size(), 1);
|
||||
clone = CreateUnary(shape, opcode_, new_operands[0]);
|
||||
@ -1937,6 +1939,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kTanh:
|
||||
case HloOpcode::kTuple:
|
||||
@ -2381,6 +2384,7 @@ bool HloInstruction::IsElementwiseImpl(
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kTanh:
|
||||
CHECK_EQ(1, operand_count());
|
||||
return true;
|
||||
@ -2893,6 +2897,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
||||
return visitor->HandleSin(this);
|
||||
case HloOpcode::kSqrt:
|
||||
return visitor->HandleSqrt(this);
|
||||
case HloOpcode::kCbrt:
|
||||
return visitor->HandleCbrt(this);
|
||||
case HloOpcode::kRsqrt:
|
||||
return visitor->HandleRsqrt(this);
|
||||
case HloOpcode::kReal:
|
||||
|
||||
@ -138,6 +138,7 @@ namespace xla {
|
||||
V(kSlice, "slice", 1) \
|
||||
V(kSort, "sort", kHloOpcodeIsVariadic) \
|
||||
V(kSqrt, "sqrt", 1) \
|
||||
V(kCbrt, "cbrt", 1) \
|
||||
V(kSubtract, "subtract", 2) \
|
||||
V(kTanh, "tanh", 1) \
|
||||
V(kTrace, "trace", 1) \
|
||||
|
||||
@ -784,6 +784,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kTanh: {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
|
||||
@ -175,6 +175,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
||||
case HloOpcode::kSendDone:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kTanh:
|
||||
case HloOpcode::kTrace:
|
||||
case HloOpcode::kTriangularSolve:
|
||||
|
||||
@ -2220,6 +2220,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
||||
case HloOpcode::kSlice:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kTanh:
|
||||
case HloOpcode::kPopulationCount:
|
||||
|
||||
@ -257,6 +257,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
case HloOpcode::kTanh:
|
||||
if (!ShapeUtil::ElementIsFloating(shape) &&
|
||||
!ShapeUtil::ElementIsComplex(shape)) {
|
||||
|
||||
@ -352,6 +352,17 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, {
|
||||
Run(Sqrt, std::sqrt, error_spec_gen);
|
||||
})
|
||||
|
||||
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Cbrt, {
|
||||
if (platform_ == "Host" || platform_ == "CUDA") {
|
||||
ErrorSpecGen error_spec_gen = +[](NativeT x) {
|
||||
return ErrorSpec{0.01, 0.01};
|
||||
};
|
||||
Run(Cbrt, std::cbrt, error_spec_gen);
|
||||
} else {
|
||||
Run(Cbrt, std::cbrt);
|
||||
}
|
||||
})
|
||||
|
||||
// TODO(jlebar): Test trig functions over complex inputs.
|
||||
XLA_TEST_P(ExhaustiveF32UnaryTest, Acosh) {
|
||||
// Error inherited from Log, which our implementation of Acosh uses.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user