Override CustomCall in MlirHloBuilder

Also, enable the MLIR bridge for the image ops compiler test. ResizeBilinear op
lowering uses CustomCall in the case of TPU lowerings.

PiperOrigin-RevId: 317272443
Change-Id: I134c828cdc76552a0cbfdeb7c65532aa986314e2
Smit Hinsu 2020-06-19 02:22:43 -07:00 committed by TensorFlower Gardener
parent 9d1ec55aed
commit c662daf489
6 changed files with 58 additions and 7 deletions
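
The change follows the builder's existing pattern: the public CustomCall entry point keeps the argument validation, and instruction construction moves into a virtual CustomCallInternal hook that MlirHloBuilder overrides. As a rough orientation before the diffs, here is a self-contained C++ analogy of that pattern; it is not TensorFlow code, and every name in it is made up for illustration.

#include <iostream>
#include <string>

// Base builder: the public method performs the shared validation, then
// delegates to a virtual Internal hook (the shape of the XlaBuilder change).
class Builder {
 public:
  virtual ~Builder() = default;

  std::string CustomCall(const std::string& target) {
    if (target.empty()) return "error: empty call target";  // shared check
    return CustomCallInternal(target);  // overridable construction step
  }

 protected:
  // Default path, analogous to building an HloInstructionProto.
  virtual std::string CustomCallInternal(const std::string& target) {
    return "proto custom-call: " + target;
  }
};

// Analogous to MlirHloBuilder: only the Internal hook is overridden, so the
// shared validation above still runs before the MLIR-specific construction.
class MlirLikeBuilder : public Builder {
 protected:
  std::string CustomCallInternal(const std::string& target) override {
    return "mlir custom-call: " + target;
  }
};

int main() {
  MlirLikeBuilder b;
  std::cout << b.CustomCall("ResizeBilinear") << "\n";  // takes the MLIR path
  return 0;
}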


@@ -132,6 +132,22 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
+    const string& call_target_name, absl::Span<const XlaOp> operands,
+    const Shape& shape, const string& opaque,
+    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
+  if (operand_shapes_with_layout.has_value())
+    return Unimplemented(
+        "CustomCall doesn't support operands shapes with layout");
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::xla_hlo::CustomCallOp>(
+      loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
+      /*has_side_effect=*/builder_.getBoolAttr(false),
+      builder_.getStringAttr(opaque));
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
     const Shape& shape, absl::Span<const XlaOp> all_operands,
     const XlaComputation& computation,
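
Note that the new override still rejects layout-constrained operands and returns Unimplemented in that case. A hedged sketch of the kind of client code that would hit that path, assuming the existing CustomCallWithLayout helper declared in xla_builder.h; the "my_target" name is hypothetical:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// When the underlying builder is an MlirHloBuilder, a layout-constrained
// custom call like this should surface the Unimplemented error added above;
// with the plain XlaBuilder it still becomes a kCustomCall instruction.
xla::XlaOp EmitLayoutConstrainedCall(xla::XlaBuilder* b, xla::XlaOp arg) {
  xla::Shape result = xla::ShapeUtil::MakeShapeWithLayout(
      xla::F32, /*dimensions=*/{4, 4}, /*minor_to_major=*/{1, 0});
  return xla::CustomCallWithLayout(b, "my_target", {arg}, result,
                                   /*operand_shapes_with_layout=*/{result},
                                   /*opaque=*/"");
}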


@@ -124,6 +124,12 @@ class MlirHloBuilder : public XlaBuilder {
                               FftType fft_type,
                               absl::Span<const int64> fft_length) override;
 
+  StatusOr<XlaOp> CustomCallInternal(const string& call_target_name,
+                                     absl::Span<const XlaOp> operands,
+                                     const Shape& shape, const string& opaque,
+                                     absl::optional<absl::Span<const Shape>>
+                                         operand_shapes_with_layout) override;
+
   StatusOr<XlaOp> ReduceInternal(
       const Shape& shape, absl::Span<const XlaOp> all_operands,
       const XlaComputation& computation,


@@ -88,6 +88,9 @@ static bool IsOpWhitelisted(Operation* op) {
       TypeID::get<TF::AddNOp>(),
       TypeID::get<TF::AddV2Op>(),
       TypeID::get<TF::AngleOp>(),
+      TypeID::get<TF::AdjustContrastv2Op>(),
+      TypeID::get<TF::AdjustHueOp>(),
+      TypeID::get<TF::AdjustSaturationOp>(),
       TypeID::get<TF::ApproximateEqualOp>(),
       TypeID::get<TF::ArgMaxOp>(),
       TypeID::get<TF::ArgMinOp>(),
@@ -127,6 +130,7 @@ static bool IsOpWhitelisted(Operation* op) {
       TypeID::get<TF::GatherNdOp>(),
       TypeID::get<TF::GreaterEqualOp>(),
       TypeID::get<TF::GreaterOp>(),
+      TypeID::get<TF::HSVToRGBOp>(),
       TypeID::get<TF::IFFT2DOp>(),
       TypeID::get<TF::IFFT3DOp>(),
       TypeID::get<TF::IFFTOp>(),
@@ -157,10 +161,14 @@ static bool IsOpWhitelisted(Operation* op) {
       TypeID::get<TF::PowOp>(),
       TypeID::get<TF::RFFT2DOp>(),
       TypeID::get<TF::RFFT3DOp>(),
+      TypeID::get<TF::RGBToHSVOp>(),
       TypeID::get<TF::RealDivOp>(),
       TypeID::get<TF::ReciprocalOp>(),
       TypeID::get<TF::ReciprocalGradOp>(),
       TypeID::get<TF::Relu6GradOp>(),
+      TypeID::get<TF::ResizeBilinearOp>(),
+      TypeID::get<TF::ResizeBilinearGradOp>(),
+      TypeID::get<TF::ResizeNearestNeighborOp>(),
       TypeID::get<TF::ReverseSequenceOp>(),
       TypeID::get<TF::RightShiftOp>(),
       TypeID::get<TF::RintOp>(),


@@ -770,6 +770,7 @@ tf_xla_py_test(
     size = "small",
     timeout = "long",
     srcs = ["image_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     shard_count = 10,
     tags = [


@@ -1564,16 +1564,12 @@ XlaOp XlaBuilder::CustomCall(
     const Shape& shape, const string& opaque,
     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     if (absl::StartsWith(call_target_name, "$")) {
       return InvalidArgument(
           "Invalid custom_call_target \"%s\": Call targets that start with '$' "
           "are reserved for internal use.",
           call_target_name);
     }
-    *instr.mutable_shape() = shape.ToProto();
-    instr.set_custom_call_target(call_target_name);
-    instr.set_backend_config(opaque);
     if (operand_shapes_with_layout.has_value()) {
       if (!LayoutUtil::HasLayout(shape)) {
         return InvalidArgument(
@@ -1586,7 +1582,6 @@ XlaOp XlaBuilder::CustomCall(
             "with constrained layout; given %d shapes, expected %d",
             operand_shapes_with_layout->size(), operands.size());
       }
-      instr.set_constrain_layout(true);
       int64 operand_num = 0;
       for (const Shape& operand_shape : *operand_shapes_with_layout) {
         if (!LayoutUtil::HasLayout(operand_shape)) {
@@ -1595,14 +1590,31 @@ XlaOp XlaBuilder::CustomCall(
               "constrained layout.",
               operand_num);
         }
-        *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
         ++operand_num;
       }
     }
-    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
+    return CustomCallInternal(call_target_name, operands, shape, opaque,
+                              operand_shapes_with_layout);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
+    const string& call_target_name, absl::Span<const XlaOp> operands,
+    const Shape& shape, const string& opaque,
+    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  instr.set_custom_call_target(call_target_name);
+  instr.set_backend_config(opaque);
+  if (operand_shapes_with_layout.has_value()) {
+    instr.set_constrain_layout(true);
+    for (const Shape& operand_shape : *operand_shapes_with_layout) {
+      *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
+    }
+  }
+  return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
+}
+
 XlaOp XlaBuilder::CustomCall(
     const string& call_target_name, absl::Span<const XlaOp> operands,
     const XlaComputation& computation, const Shape& shape, const string& opaque,
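
With the refactor above, an ordinary (layout-free) custom call built through the client API is validated in CustomCall and then constructed in CustomCallInternal, which MlirHloBuilder can now intercept. A minimal usage sketch, assuming the free xla::CustomCall helper from xla_builder.h and a hypothetical "do_something" call target:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds a custom call with no layout constraints. Validation happens in
// XlaBuilder::CustomCall; the instruction itself is produced by
// CustomCallInternal (or by MlirHloBuilder's override under the MLIR bridge).
xla::XlaOp EmitCustomCall(xla::XlaBuilder* b, xla::XlaOp input) {
  xla::Shape out_shape = xla::ShapeUtil::MakeShape(xla::F32, {1, 8, 8, 1});
  return xla::CustomCall(b, "do_something", {input}, out_shape, /*opaque=*/"");
}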


@@ -527,6 +527,14 @@ class XlaBuilder {
       const Shape& shape_with_layout, const string& opaque,
       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
 
+  // Internal version of CustomCall without computation that doesn't do op
+  // specific error handling and expects arguments to be legal. CustomCall
+  // method above calls this method after error handling.
+  virtual StatusOr<XlaOp> CustomCallInternal(
+      const string& call_target_name, absl::Span<const XlaOp> operands,
+      const Shape& shape_with_layout, const string& opaque,
+      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
+
   XlaOp CustomCall(
       const string& call_target_name, absl::Span<const XlaOp> operands,
       const XlaComputation& computation, const Shape& shape_with_layout,