From eb377d252e30292ba620ddb61deb81496f7a0f95 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 17 Aug 2020 13:51:58 -0700 Subject: [PATCH] Override AddInstruction method in MlirHloBuilder This way all the supported ops can return an error if MLIR builder is used. PiperOrigin-RevId: 327091802 Change-Id: I2df09131f89022f21243b173501cfb8dde0573c2 --- tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc | 7 +++++++ tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h | 3 +++ tensorflow/compiler/xla/client/xla_builder.h | 9 +++++++-- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 1b272e946b6..3fa3746598e 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -351,6 +351,13 @@ StatusOr MlirHloBuilder::InDimBroadcast( return MakeXlaOp(op.getResult()); } +StatusOr MlirHloBuilder::AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands) { + return Unimplemented("MlirHloBuilder does not support op %s", + HloOpcodeString(opcode)); +} + StatusOr MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) { diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index eebdb18b6ab..3884689e48d 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -196,6 +196,9 @@ class MlirHloBuilder : public XlaBuilder { const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions) override; + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands) override; + StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) override; diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 997187785fd..d812b35f7a0 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -784,8 +784,13 @@ class XlaBuilder { XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension); - StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, - absl::Span operands = {}); + virtual StatusOr AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode, + absl::Span operands); + StatusOr AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode) { + return AddInstruction(std::move(instr), opcode, /*operands=*/{}); + } void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr);