diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 32796dd8d70..9b8156efe5b 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -298,6 +298,15 @@ XLA_TEST_F(MathTest, SqrtSixValues) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, CbrtSixValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {8.0, 1.0, 4096.0, -64.0, 1.728, 1331}); + Cbrt(x); + + std::vector expected = {2, 1, 16, -4, 1.2, 11}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.001)); +} + XLA_TEST_F(MathTest, SinhSmallValues) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 7de4cd4b3c7..a4893acb546 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -3571,6 +3571,9 @@ XlaOp Imag(const XlaOp operand) { XlaOp Sqrt(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand); } +XlaOp Cbrt(const XlaOp operand) { + return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand); +} XlaOp Rsqrt(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 64424b9dd3c..67cff09caed 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -1030,6 +1030,7 @@ class XlaBuilder { friend XlaOp Imag(XlaOp operand); friend XlaOp Sqrt(XlaOp operand); friend XlaOp Rsqrt(XlaOp operand); + friend XlaOp Cbrt(XlaOp operand); friend XlaOp Pow(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); friend XlaOp IsFinite(XlaOp operand); @@ -1884,6 +1885,9 @@ XlaOp Imag(XlaOp operand); // Enqueues a sqrt computation onto the computation. XlaOp Sqrt(XlaOp operand); +// Enqueues a cbrt computation onto the computation. +XlaOp Cbrt(XlaOp operand); + // Enqueues a rsqrt computation onto the computation. XlaOp Rsqrt(XlaOp operand); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index e4676141f65..cadea620ec6 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -109,6 +109,9 @@ class DfsHloVisitorBase { virtual Status HandleRsqrt(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleCbrt(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 3eb6dab3129..30300b8c195 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -461,6 +461,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitSqrt(op->shape().element_type(), operand_value); case HloOpcode::kRsqrt: return EmitRsqrt(op->shape().element_type(), operand_value); + case HloOpcode::kCbrt: + return EmitCbrt(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {operand_value}, @@ -787,6 +789,9 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( case HloOpcode::kRsqrt: { return EmitComplexRsqrt(op, component_type, operand_value); } + case HloOpcode::kCbrt: { + return EmitComplexCbrt(op, component_type, operand_value); + } case HloOpcode::kNegate: return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), FNeg(EmitExtractImag(operand_value))); @@ -1081,6 +1086,19 @@ StatusOr ElementalIrEmitter::EmitComplexRsqrt( return EmitComposeComplex(op, real_part, imag_part); } +// +// Using EmitComplexPower with c=1.0/3.0 and d=0 +StatusOr ElementalIrEmitter::EmitComplexCbrt( + const HloInstruction* op, PrimitiveType prim_type, + llvm::Value* operand_value) { + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto third = llvm::ConstantFP::get(type, 1.0 / 3.0); + auto zero = llvm::ConstantFP::get(type, 0); + llvm::Value* a = EmitExtractReal(operand_value); + llvm::Value* b = EmitExtractImag(operand_value); + return EmitComplexPower(op, a, b, third, zero); +} + // (a+bi)^(c+di) = // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) @@ -1392,6 +1410,19 @@ StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, {lhs->getType()}, b_); } +StatusOr ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type, + llvm::Value* value) { + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto third = llvm::ConstantFP::get(type, 1.0 / 3.0); + auto abs_value = + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); + TF_ASSIGN_OR_RETURN(llvm::Value * abs_res, + EmitPow(prim_type, abs_value, third)); + auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, + {abs_res, value}, {type}, b_); + return signed_res; +} + StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) { @@ -2181,6 +2212,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 99833a5525f..94e8f1d6400 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -116,6 +116,9 @@ class ElementalIrEmitter : public IrBuilderMixin { virtual StatusOr EmitSqrt(PrimitiveType prim_type, llvm::Value* value); + virtual StatusOr EmitCbrt(PrimitiveType prim_type, + llvm::Value* value); + virtual StatusOr EmitRsqrt(PrimitiveType prim_type, llvm::Value* value); @@ -159,6 +162,10 @@ class ElementalIrEmitter : public IrBuilderMixin { PrimitiveType prim_type, llvm::Value* operand_value); + virtual StatusOr EmitComplexCbrt(const HloInstruction* op, + PrimitiveType prim_type, + llvm::Value* operand_value); + virtual StatusOr EmitComplexRsqrt(const HloInstruction* op, PrimitiveType prim_type, llvm::Value* operand_value); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index e105ea8ce18..3dc9cc24734 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -700,6 +700,38 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCbrt(HloInstruction* cbrt) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[cbrt], + ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) -> ElementwiseT { + return std::pow(elem_operand, static_cast(1.0 / 3.0)); + return elem_operand.real() < 0 + ? -std::pow(-elem_operand, + static_cast(1.0 / 3.0)) + : std::pow(elem_operand, + static_cast(1.0 / 3.0)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCbrt(HloInstruction* cbrt) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cbrt], + ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) { + return std::cbrt(elem_operand); + })); + return Status::OK(); + } + + Status HandleCbrt(HloInstruction* cbrt) override { + return HandleCbrt(cbrt); + } + Status HandleRsqrt(HloInstruction* rsqrt) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[rsqrt], diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 78e4d39d3fe..47a455ac3f4 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -980,6 +980,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSlice: case HloOpcode::kSort: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: // De-emphasize scalar-shaped elementwise ops -- they're generally diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5fc42eb5e3c..27fac19587e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -807,6 +807,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: break; default: @@ -1565,6 +1566,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1937,6 +1939,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: @@ -2381,6 +2384,7 @@ bool HloInstruction::IsElementwiseImpl( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: CHECK_EQ(1, operand_count()); return true; @@ -2893,6 +2897,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSin(this); case HloOpcode::kSqrt: return visitor->HandleSqrt(this); + case HloOpcode::kCbrt: + return visitor->HandleCbrt(this); case HloOpcode::kRsqrt: return visitor->HandleRsqrt(this); case HloOpcode::kReal: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index dfe68d93f30..2d66237de59 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -138,6 +138,7 @@ namespace xla { V(kSlice, "slice", 1) \ V(kSort, "sort", kHloOpcodeIsVariadic) \ V(kSqrt, "sqrt", 1) \ + V(kCbrt, "cbrt", 1) \ V(kSubtract, "subtract", 2) \ V(kTanh, "tanh", 1) \ V(kTrace, "trace", 1) \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 4162c5d62d5..a9c3cacc4c4 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -784,6 +784,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 53938a489f1..99242c9ca21 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -175,6 +175,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kSendDone: case HloOpcode::kSort: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kTriangularSolve: diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 64390e77ddb..a67c677bd03 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2220,6 +2220,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kSlice: case HloOpcode::kSort: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kPopulationCount: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d2cbdddff2e..f3c8eec1751 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -257,6 +257,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kLog1p: case HloOpcode::kRsqrt: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc index 0ed79fa0ad8..44e1b7b5a6f 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc @@ -352,6 +352,17 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, { Run(Sqrt, std::sqrt, error_spec_gen); }) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Cbrt, { + if (platform_ == "Host" || platform_ == "CUDA") { + ErrorSpecGen error_spec_gen = +[](NativeT x) { + return ErrorSpec{0.01, 0.01}; + }; + Run(Cbrt, std::cbrt, error_spec_gen); + } else { + Run(Cbrt, std::cbrt); + } +}) + // TODO(jlebar): Test trig functions over complex inputs. XLA_TEST_P(ExhaustiveF32UnaryTest, Acosh) { // Error inherited from Log, which our implementation of Acosh uses.