Lower TensorFlow Atan2 op using tf2xla kernel
PiperOrigin-RevId: 299884500 Change-Id: I88e8d3a0844d38326d08a63cb6792039e0ef1c82
This commit is contained in:
parent
2e1e647efb
commit
4118672819
@ -59,18 +59,36 @@ XlaOp MlirHloBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
|
|||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
|
Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
|
||||||
|
|
||||||
mlir::Value value = GetValue(operand);
|
return CreateOp(GetMlirOpName(unop), shape, {operand}, /*attributes=*/{});
|
||||||
mlir::OperationState state(loc_, GetMlirOpName(unop));
|
|
||||||
state.addOperands(value);
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
|
||||||
mlir::Type ty,
|
|
||||||
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
|
|
||||||
state.addTypes(ty);
|
|
||||||
mlir::Operation* op = builder_.createOperation(state);
|
|
||||||
return MakeXlaOp(op->getResult(0));
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp MlirHloBuilder::BinaryOpNoBroadcast(
|
||||||
|
HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::optional<ComparisonDirection> direction) {
|
||||||
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
|
if (direction.has_value())
|
||||||
|
return Unimplemented("direction attribute not yet supported");
|
||||||
|
return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs}, /*attributes=*/{});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<XlaOp> MlirHloBuilder::CreateOp(
|
||||||
|
const std::string& op_name, const Shape& shape,
|
||||||
|
llvm::ArrayRef<XlaOp> operands,
|
||||||
|
llvm::ArrayRef<mlir::NamedAttribute> attributes) {
|
||||||
|
llvm::SmallVector<mlir::Value, 4> operand_values;
|
||||||
|
operand_values.reserve(operands.size());
|
||||||
|
for (XlaOp xla_op : operands) {
|
||||||
|
operand_values.push_back(GetValue(xla_op));
|
||||||
|
}
|
||||||
|
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
||||||
|
shape, builder_));
|
||||||
|
mlir::OperationState state(loc_, op_name, operand_values, {ty}, attributes);
|
||||||
|
mlir::Operation* op = builder_.createOperation(state);
|
||||||
|
return MakeXlaOp(op->getResult(0));
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
|
StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
|
||||||
TF_RETURN_IF_ERROR(first_error());
|
TF_RETURN_IF_ERROR(first_error());
|
||||||
TF_RETURN_IF_ERROR(CheckOpBuilder(op));
|
TF_RETURN_IF_ERROR(CheckOpBuilder(op));
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
@ -89,6 +90,15 @@ class MlirHloBuilder : public XlaBuilder {
|
|||||||
private:
|
private:
|
||||||
XlaOp UnaryOp(HloOpcode unop, XlaOp operand) override;
|
XlaOp UnaryOp(HloOpcode unop, XlaOp operand) override;
|
||||||
|
|
||||||
|
XlaOp BinaryOpNoBroadcast(
|
||||||
|
HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::optional<ComparisonDirection> direction) override;
|
||||||
|
|
||||||
|
// Creates HLO dialect op and returns the result as an XlaOp.
|
||||||
|
StatusOr<XlaOp> CreateOp(const std::string& op_name, const Shape& shape,
|
||||||
|
llvm::ArrayRef<XlaOp> operands,
|
||||||
|
llvm::ArrayRef<mlir::NamedAttribute> attributes);
|
||||||
|
|
||||||
mlir::OpBuilder builder_;
|
mlir::OpBuilder builder_;
|
||||||
mlir::Location loc_;
|
mlir::Location loc_;
|
||||||
|
|
||||||
|
@ -42,6 +42,13 @@ func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
|||||||
return %1 : tensor<2xf32>
|
return %1 : tensor<2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: binary_op
|
||||||
|
func @binary_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
// CHECK: xla_hlo.atan2 %arg0, %arg1 : tensor<2xf32>
|
||||||
|
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(hinsu): Add a test with variant type once one of the ops supporting
|
// TODO(hinsu): Add a test with variant type once one of the ops supporting
|
||||||
// the type is whitelisted. It should be rejected with unsupported type remark.
|
// the type is whitelisted. It should be rejected with unsupported type remark.
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ static bool IsOpWhitelisted(Operation* op) {
|
|||||||
// building valid MLIR using MlirHloBuilder.
|
// building valid MLIR using MlirHloBuilder.
|
||||||
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
|
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
|
||||||
// all tf2xla kernels.
|
// all tf2xla kernels.
|
||||||
return isa<TF::AbsOp>(op);
|
return isa<TF::AbsOp>(op) || isa<TF::Atan2Op>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static llvm::Optional<absl::string_view> GetJitDevice(
|
static llvm::Optional<absl::string_view> GetJitDevice(
|
||||||
|
Loading…
Reference in New Issue
Block a user