Lower TensorFlow Atan2 op using tf2xla kernel

PiperOrigin-RevId: 299884500
Change-Id: I88e8d3a0844d38326d08a63cb6792039e0ef1c82
This commit is contained in:
Smit Hinsu 2020-03-09 11:06:52 -07:00 committed by TensorFlower Gardener
parent 2e1e647efb
commit 4118672819
4 changed files with 45 additions and 10 deletions

View File

@ -59,18 +59,36 @@ XlaOp MlirHloBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
mlir::Value value = GetValue(operand);
mlir::OperationState state(loc_, GetMlirOpName(unop));
state.addOperands(value);
TF_ASSIGN_OR_RETURN(
mlir::Type ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
state.addTypes(ty);
mlir::Operation* op = builder_.createOperation(state);
return MakeXlaOp(op->getResult(0));
return CreateOp(GetMlirOpName(unop), shape, {operand}, /*attributes=*/{});
});
}
XlaOp MlirHloBuilder::BinaryOpNoBroadcast(
    HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
    absl::optional<ComparisonDirection> direction) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Only direction-less binary ops are supported so far; comparison
    // direction lowering is not yet wired into the MLIR builder.
    if (!direction.has_value()) {
      return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs},
                      /*attributes=*/{});
    }
    return Unimplemented("direction attribute not yet supported");
  });
}
StatusOr<XlaOp> MlirHloBuilder::CreateOp(
    const std::string& op_name, const Shape& shape,
    llvm::ArrayRef<XlaOp> operands,
    llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  // Resolve each XlaOp handle to the MLIR value it wraps.
  llvm::SmallVector<mlir::Value, 4> values;
  values.reserve(operands.size());
  for (const XlaOp& operand : operands) values.push_back(GetValue(operand));

  // Derive the single MLIR result type from the XLA result shape.
  TF_ASSIGN_OR_RETURN(
      mlir::Type result_type,
      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));

  // Build the op and hand back its sole result as an XlaOp.
  mlir::OperationState state(loc_, op_name, values, {result_type}, attributes);
  mlir::Operation* new_op = builder_.createOperation(state);
  return MakeXlaOp(new_op->getResult(0));
}
StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
TF_RETURN_IF_ERROR(first_error());
TF_RETURN_IF_ERROR(CheckOpBuilder(op));

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
@ -89,6 +90,15 @@ class MlirHloBuilder : public XlaBuilder {
private:
XlaOp UnaryOp(HloOpcode unop, XlaOp operand) override;
XlaOp BinaryOpNoBroadcast(
HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs,
absl::optional<ComparisonDirection> direction) override;
// Creates HLO dialect op and returns the result as an XlaOp.
StatusOr<XlaOp> CreateOp(const std::string& op_name, const Shape& shape,
llvm::ArrayRef<XlaOp> operands,
llvm::ArrayRef<mlir::NamedAttribute> attributes);
mlir::OpBuilder builder_;
mlir::Location loc_;

View File

@ -42,6 +42,13 @@ func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
return %1 : tensor<2xf32>
}
// CHECK-LABEL: binary_op
// Verifies that tf.Atan2 is lowered through the tf2xla fallback to the
// HLO dialect atan2 op (no broadcast needed: both operands are tensor<2xf32>).
func @binary_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: xla_hlo.atan2 %arg0, %arg1 : tensor<2xf32>
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// TODO(hinsu): Add a test with variant type once one of the ops supporting
// the type is whitelisted. It should be rejected with unsupported type remark.

View File

@ -77,7 +77,7 @@ static bool IsOpWhitelisted(Operation* op) {
// building valid MLIR using MlirHloBuilder.
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
// all tf2xla kernels.
return isa<TF::AbsOp>(op);
return isa<TF::AbsOp>(op) || isa<TF::Atan2Op>(op);
}
static llvm::Optional<absl::string_view> GetJitDevice(