Override CustomCall in MlirHloBuilder
Also, enable mlir bridge for image ops compilers test. ResizeBilinear op lowering usese CustomCall in case of TPU lowerings. PiperOrigin-RevId: 317272443 Change-Id: I134c828cdc76552a0cbfdeb7c65532aa986314e2
This commit is contained in:
parent
9d1ec55aed
commit
c662daf489
@ -132,6 +132,22 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
|
||||
return MakeXlaOp(op);
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
|
||||
if (operand_shapes_with_layout.has_value())
|
||||
return Unimplemented(
|
||||
"CustomCall doesn't support operands shapes with layout");
|
||||
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
||||
shape, builder_));
|
||||
auto op = builder_.create<mlir::xla_hlo::CustomCallOp>(
|
||||
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
|
||||
/*has_side_effect=*/builder_.getBoolAttr(false),
|
||||
builder_.getStringAttr(opaque));
|
||||
return MakeXlaOp(op);
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
const XlaComputation& computation,
|
||||
|
@ -124,6 +124,12 @@ class MlirHloBuilder : public XlaBuilder {
|
||||
FftType fft_type,
|
||||
absl::Span<const int64> fft_length) override;
|
||||
|
||||
StatusOr<XlaOp> CustomCallInternal(const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>>
|
||||
operand_shapes_with_layout) override;
|
||||
|
||||
StatusOr<XlaOp> ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
const XlaComputation& computation,
|
||||
|
@ -88,6 +88,9 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||
TypeID::get<TF::AddNOp>(),
|
||||
TypeID::get<TF::AddV2Op>(),
|
||||
TypeID::get<TF::AngleOp>(),
|
||||
TypeID::get<TF::AdjustContrastv2Op>(),
|
||||
TypeID::get<TF::AdjustHueOp>(),
|
||||
TypeID::get<TF::AdjustSaturationOp>(),
|
||||
TypeID::get<TF::ApproximateEqualOp>(),
|
||||
TypeID::get<TF::ArgMaxOp>(),
|
||||
TypeID::get<TF::ArgMinOp>(),
|
||||
@ -127,6 +130,7 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||
TypeID::get<TF::GatherNdOp>(),
|
||||
TypeID::get<TF::GreaterEqualOp>(),
|
||||
TypeID::get<TF::GreaterOp>(),
|
||||
TypeID::get<TF::HSVToRGBOp>(),
|
||||
TypeID::get<TF::IFFT2DOp>(),
|
||||
TypeID::get<TF::IFFT3DOp>(),
|
||||
TypeID::get<TF::IFFTOp>(),
|
||||
@ -157,10 +161,14 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||
TypeID::get<TF::PowOp>(),
|
||||
TypeID::get<TF::RFFT2DOp>(),
|
||||
TypeID::get<TF::RFFT3DOp>(),
|
||||
TypeID::get<TF::RGBToHSVOp>(),
|
||||
TypeID::get<TF::RealDivOp>(),
|
||||
TypeID::get<TF::ReciprocalOp>(),
|
||||
TypeID::get<TF::ReciprocalGradOp>(),
|
||||
TypeID::get<TF::Relu6GradOp>(),
|
||||
TypeID::get<TF::ResizeBilinearOp>(),
|
||||
TypeID::get<TF::ResizeBilinearGradOp>(),
|
||||
TypeID::get<TF::ResizeNearestNeighborOp>(),
|
||||
TypeID::get<TF::ReverseSequenceOp>(),
|
||||
TypeID::get<TF::RightShiftOp>(),
|
||||
TypeID::get<TF::RintOp>(),
|
||||
|
@ -770,6 +770,7 @@ tf_xla_py_test(
|
||||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["image_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 10,
|
||||
tags = [
|
||||
|
@ -1564,16 +1564,12 @@ XlaOp XlaBuilder::CustomCall(
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
if (absl::StartsWith(call_target_name, "$")) {
|
||||
return InvalidArgument(
|
||||
"Invalid custom_call_target \"%s\": Call targets that start with '$' "
|
||||
"are reserved for internal use.",
|
||||
call_target_name);
|
||||
}
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_custom_call_target(call_target_name);
|
||||
instr.set_backend_config(opaque);
|
||||
if (operand_shapes_with_layout.has_value()) {
|
||||
if (!LayoutUtil::HasLayout(shape)) {
|
||||
return InvalidArgument(
|
||||
@ -1586,7 +1582,6 @@ XlaOp XlaBuilder::CustomCall(
|
||||
"with constrained layout; given %d shapes, expected %d",
|
||||
operand_shapes_with_layout->size(), operands.size());
|
||||
}
|
||||
instr.set_constrain_layout(true);
|
||||
int64 operand_num = 0;
|
||||
for (const Shape& operand_shape : *operand_shapes_with_layout) {
|
||||
if (!LayoutUtil::HasLayout(operand_shape)) {
|
||||
@ -1595,14 +1590,31 @@ XlaOp XlaBuilder::CustomCall(
|
||||
"constrained layout.",
|
||||
operand_num);
|
||||
}
|
||||
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
|
||||
++operand_num;
|
||||
}
|
||||
}
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
||||
return CustomCallInternal(call_target_name, operands, shape, opaque,
|
||||
operand_shapes_with_layout);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_custom_call_target(call_target_name);
|
||||
instr.set_backend_config(opaque);
|
||||
if (operand_shapes_with_layout.has_value()) {
|
||||
instr.set_constrain_layout(true);
|
||||
for (const Shape& operand_shape : *operand_shapes_with_layout) {
|
||||
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
|
||||
}
|
||||
}
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::CustomCall(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation, const Shape& shape, const string& opaque,
|
||||
|
@ -527,6 +527,14 @@ class XlaBuilder {
|
||||
const Shape& shape_with_layout, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
|
||||
|
||||
// Internal version of CustomCall without computation that doesn't do op
|
||||
// specific error handling and expects arguments to be legal. CustomCall
|
||||
// method above calls this method after error handling.
|
||||
virtual StatusOr<XlaOp> CustomCallInternal(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape_with_layout, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
|
||||
|
||||
XlaOp CustomCall(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation, const Shape& shape_with_layout,
|
||||
|
Loading…
Reference in New Issue
Block a user