From c662daf4891a1e6efe64797615c3bd2bebedc5f5 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Fri, 19 Jun 2020 02:22:43 -0700 Subject: [PATCH] Override CustomCall in MlirHloBuilder Also, enable mlir bridge for image ops compilers test. ResizeBilinear op lowering usese CustomCall in case of TPU lowerings. PiperOrigin-RevId: 317272443 Change-Id: I134c828cdc76552a0cbfdeb7c65532aa986314e2 --- .../compiler/mlir/xla/ir/mlir_hlo_builder.cc | 16 ++++++++++++ .../compiler/mlir/xla/ir/mlir_hlo_builder.h | 6 +++++ .../xla/transforms/legalize_tf_with_tf2xla.cc | 8 ++++++ tensorflow/compiler/tests/BUILD | 1 + tensorflow/compiler/xla/client/xla_builder.cc | 26 ++++++++++++++----- tensorflow/compiler/xla/client/xla_builder.h | 8 ++++++ 6 files changed, 58 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 21b1ac5f0ea..3c11d8e590d 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -132,6 +132,22 @@ StatusOr MlirHloBuilder::FftInternal( return MakeXlaOp(op); } +StatusOr MlirHloBuilder::CustomCallInternal( + const string& call_target_name, absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout) { + if (operand_shapes_with_layout.has_value()) + return Unimplemented( + "CustomCall doesn't support operands shapes with layout"); + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + auto op = builder_.create( + loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name), + /*has_side_effect=*/builder_.getBoolAttr(false), + builder_.getStringAttr(opaque)); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::ReduceInternal( const Shape& shape, absl::Span all_operands, const XlaComputation& computation, diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 4b28c32db99..4d7d93af7a7 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -124,6 +124,12 @@ class MlirHloBuilder : public XlaBuilder { FftType fft_type, absl::Span fft_length) override; + StatusOr CustomCallInternal(const string& call_target_name, + absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> + operand_shapes_with_layout) override; + StatusOr ReduceInternal( const Shape& shape, absl::Span all_operands, const XlaComputation& computation, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index ef79c8868bb..8f96f4d1305 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -88,6 +88,9 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -127,6 +130,7 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -157,10 +161,14 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index b574622efce..034ec82de10 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -770,6 +770,7 @@ tf_xla_py_test( size = "small", timeout = "long", srcs = ["image_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index c7b6a7f9491..03ae23ea18b 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1564,16 +1564,12 @@ XlaOp XlaBuilder::CustomCall( const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " "are reserved for internal use.", call_target_name); } - *instr.mutable_shape() = shape.ToProto(); - instr.set_custom_call_target(call_target_name); - instr.set_backend_config(opaque); if (operand_shapes_with_layout.has_value()) { if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument( @@ -1586,7 +1582,6 @@ XlaOp XlaBuilder::CustomCall( "with constrained layout; given %d shapes, expected %d", operand_shapes_with_layout->size(), operands.size()); } - instr.set_constrain_layout(true); int64 operand_num = 0; for (const Shape& operand_shape : *operand_shapes_with_layout) { if (!LayoutUtil::HasLayout(operand_shape)) { @@ -1595,14 +1590,31 @@ XlaOp XlaBuilder::CustomCall( "constrained layout.", operand_num); } - *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); ++operand_num; } } - return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); + return CustomCallInternal(call_target_name, operands, shape, opaque, + operand_shapes_with_layout); }); } +StatusOr XlaBuilder::CustomCallInternal( + const string& call_target_name, absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + instr.set_custom_call_target(call_target_name); + instr.set_backend_config(opaque); + if (operand_shapes_with_layout.has_value()) { + instr.set_constrain_layout(true); + for (const Shape& operand_shape : *operand_shapes_with_layout) { + *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); + } + } + return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); +} + XlaOp XlaBuilder::CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape, const string& opaque, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index b8af180b83e..3fc26747468 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -527,6 +527,14 @@ class XlaBuilder { const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout); + // Internal version of CustomCall without computation that doesn't do op + // specific error handling and expects arguments to be legal. CustomCall + // method above calls this method after error handling. + virtual StatusOr CustomCallInternal( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, const string& opaque, + absl::optional> operand_shapes_with_layout); + XlaOp CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape_with_layout,