diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 443011f3cf3..a39c3265206 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -253,9 +253,8 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface { } }; -struct TensorFlowLiteOpFolderDialectInterface - : public OpFolderDialectInterface { - using OpFolderDialectInterface::OpFolderDialectInterface; +struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; // Registered hook to check if the given region, which is attached to an // operation that is *not* isolated from above (i.e. no internal regions @@ -275,7 +274,7 @@ TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context) #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" >(); addInterfaces(); + TensorFlowLiteDialectFoldInterface>(); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc index 2be67f8e93e..d04323f1b70 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc @@ -32,7 +32,10 @@ void init_types(py::module& m) { [](mlir::FunctionType& ft) { return ft.getResults().vec(); }); py::class_(m, "FloatType") - .def("get", &mlir::FloatType::get); + .def("getBF16", &mlir::FloatType::getBF16) + .def("getF16", &mlir::FloatType::getF16) + .def("getF32", &mlir::FloatType::getF32) + .def("getF64", &mlir::FloatType::getF64); py::class_(m, "IntegerType") .def("get", py::overload_cast( diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 319de8d491a..a36f6f9b92e 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1271,7 +1271,7 @@ cc_library( name = "tf_dialect_passes", srcs = [ "transforms/constant_fold.cc", - "transforms/dialect_hooks.cc", + "transforms/decode_attributes_hook.cc", ], hdrs = [ "transforms/constant_fold.h", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 70b7724deeb..ea9ae5d9477 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -72,9 +72,8 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { } }; -struct TensorFlowExecutorOpFolderDialectInterface - : public OpFolderDialectInterface { - using OpFolderDialectInterface::OpFolderDialectInterface; +struct TensorFlowExecutorDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; // Registered hook to check if the given region, which is attached to an // operation that is *not* isolated from above (i.e. no internal regions @@ -97,7 +96,7 @@ TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context) >(); addInterfaces(); + TensorFlowExecutorDialectFoldInterface>(); addTypes(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 6cacd5105ca..6fd3bfc9ccb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -55,6 +55,8 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/DecodeAttributesInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -112,6 +114,22 @@ bool HasSingleUse(FuncOp func) { return true; } +struct TFConstantFoldInterface : public DialectFoldInterface { + TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {} + LogicalResult Fold(Operation *op, ArrayRef operands, + SmallVectorImpl &results) const final { + return TensorFlowDialect::constantFold(op, operands, results); + } +}; + +struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface { + TFDecodeAttributesInterface(Dialect *dialect) + : DialectDecodeAttributesInterface(dialect) {} + LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) const { + return TensorFlowDialect::decode(input, output); + } +}; + struct TFInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; @@ -206,6 +224,9 @@ std::vector *TensorFlowDialect::additional_operation_hooks_ = new std::vector(); +TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_; +TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_; + TensorFlowDialect::TensorFlowDialect(MLIRContext *context) : Dialect(/*name=*/"tf", context, TypeID::get()) { addOperations< @@ -217,7 +238,8 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context) #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" >(); - addInterfaces(); + addInterfaces(); addAttributes(); // Support unknown operations because not all TensorFlow operations are diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index bbcce4ee177..3169f7fba8d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -116,10 +116,35 @@ class TensorFlowDialect : public Dialect { 0, (addOperation(AbstractOperation::get(*this)), 0)...}; } + using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef, + SmallVectorImpl &); + static void RegisterConstantFoldHook(ConstantFoldHook fn) { + constant_fold_hook_ = std::move(fn); + } + + static LogicalResult constantFold(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + if (constant_fold_hook_) return constant_fold_hook_(op, operands, results); + return failure(); + } + + using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input, + ElementsAttr &output); + static void RegisterDecodeConstantHook(DecodeConstantHook fn) { + decode_constant_hook_ = std::move(fn); + } + static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) { + if (decode_constant_hook_) return decode_constant_hook_(input, output); + return failure(); + } + private: // Hook functions which may add additional operations to the dialect. // These are invoked at construction time. static std::vector *additional_operation_hooks_; + + static ConstantFoldHook constant_fold_hook_; + static DecodeConstantHook decode_constant_hook_; }; } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 1429e2b3fd4..3005c78c54f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/c/eager/c_api.h" @@ -68,7 +69,7 @@ static bool ShouldBeFolded(Operation* inst) { LogicalResult ConstantFoldFallbackHook( Operation* inst, ArrayRef operands, - SmallVectorImpl& results) { // NOLINT + SmallVectorImpl& results) { // NOLINT // Instructions with side effects should not be constant folded to preserve // the original semantics. if (inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst)) @@ -126,8 +127,16 @@ LogicalResult ConstantFoldFallbackHook( // TODO(jpienaar): Avoid using global context & mutex here. static auto* mu = new tensorflow::mutex(); tensorflow::mutex_lock l(*mu); - return tensorflow::EvaluateOperation(inst, inputs, ctx, &results); + SmallVector constants; + LogicalResult status = + tensorflow::EvaluateOperation(inst, inputs, ctx, &constants); + results.assign(constants.begin(), constants.end()); + return status; } +static bool init_hooks = ([] () { + TensorFlowDialect::RegisterConstantFoldHook(ConstantFoldFallbackHook); +}(), true); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h index 69e39080965..887eea745e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h @@ -27,7 +27,7 @@ namespace TF { LogicalResult ConstantFoldFallbackHook( Operation *inst, ArrayRef operands, - SmallVectorImpl &results); // NOLINT + SmallVectorImpl &results); // NOLINT } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc similarity index 74% rename from tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc index 109ceea47e7..d309c6d379f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc @@ -19,7 +19,6 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/DialectHooks.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -35,31 +34,22 @@ namespace { // Since this method is passed to MLIR as decode hook it has to conform // to LLVM style used by MLIR. -bool DecodeOpaqueTensorHook(const OpaqueElementsAttr input, - ElementsAttr& output) { // NOLINT +LogicalResult DecodeOpaqueTensorHook(const OpaqueElementsAttr input, + ElementsAttr& output) { // NOLINT Builder builder(input.getType().getContext()); auto decoded_attr_or = tensorflow::DecodeOpaqueTensor(input, builder); if (!decoded_attr_or.ok()) { VLOG(2) << decoded_attr_or.status().error_message(); - return true; + return failure(); } output = decoded_attr_or.ValueOrDie(); - return false; + return success(); } -// Hooks for the TensorFlow dialect. -class TensorFlowHooks : public DialectHooks { - public: - DialectConstantFoldHook getConstantFoldHook() { - return TF::ConstantFoldFallbackHook; - } - DialectConstantDecodeHook getDecodeHook() { return DecodeOpaqueTensorHook; } -}; +static bool init_hooks = ([] () { + TF::TensorFlowDialect::RegisterDecodeConstantHook(DecodeOpaqueTensorHook); +}(), true); } // anonymous namespace - -// Static initialization for TensorFlow dialect hooks registration. -static DialectHooksRegistration tf_hooks_registration("tf"); - } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 4008e8d33c6..17818302a1d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -1171,10 +1172,11 @@ LogicalResult ShapeInference::TryToFold(Operation* op) { if (!dialect) return failure(); // Only attempt TF dialect fallback if there are no unknown operands. if (some_unknown && dialect == tf_dialect_) return failure(); - SmallVector constants; - if (failed(dialect->constantFoldHook(op, constant_operands, constants))) + auto* interface = dialect->getRegisteredInterface(); + if (!interface) return failure(); + + if (failed(interface->Fold(op, constant_operands, fold_results))) return failure(); - fold_results.assign(constants.begin(), constants.end()); } for (auto result : zip(op->getResults(), fold_results)) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 242f3c6ceb7..36566d6c25f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1640,7 +1640,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( if (current_size_fragment >= vector_register_size_in_elements) { auto vector_type = llvm::VectorType::get( - element_ir_type, vector_register_size_in_elements); + element_ir_type, vector_register_size_in_elements, false); sharded_vector_type.insert( sharded_vector_type.end(), current_size_fragment / vector_register_size_in_elements, @@ -1656,7 +1656,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( // of two are all legal vector sizes (or at least can be lowered easily by // LLVM). sharded_vector_type.push_back( - llvm::VectorType::get(element_ir_type, current_size_fragment)); + llvm::VectorType::get(element_ir_type, current_size_fragment, false)); } return sharded_vector_type; } diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 8d9229c1223..3afdd9c163e 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -115,7 +115,7 @@ void RewriteCalls( // Upcast to vector type if input is a scalar. if (vector_width == 1) { - llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1); + llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1, false); input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input, uint64_t{0}); } @@ -264,8 +264,8 @@ llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, z = vsl.Add(one, z); // Convert n' to an i32. This is safe because we clamped it above. - llvm::Value* n_i32 = - b->CreateFPToSI(n, llvm::VectorType::get(b->getInt32Ty(), vector_width)); + llvm::Value* n_i32 = b->CreateFPToSI( + n, llvm::VectorType::get(b->getInt32Ty(), vector_width, false)); auto splat_i32 = [&](int32 v) { return b->CreateVectorSplat(vector_width, b->getInt32(v)); @@ -329,7 +329,7 @@ llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input, llvm::Value* vector_constant_23 = b->CreateVectorSplat(vector_width, b->getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(b->getInt32Ty(), vector_width); + llvm::VectorType::get(b->getInt32Ty(), vector_width, false); llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type), vector_constant_23); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 0d2eab9fd42..48aa32f6b8f 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -33,7 +33,7 @@ VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, scalar_type_ = llvm_ir::PrimitiveTypeToIrType( primitive_type, b_->GetInsertBlock()->getModule()); scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); - vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); + vector_type_ = llvm::VectorType::get(scalar_type_, vector_size, false); vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); } @@ -155,7 +155,7 @@ llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type()); llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits); if (vector) { - return llvm::VectorType::get(scalar_int_type, vector_size()); + return llvm::VectorType::get(scalar_int_type, vector_size(), false); } else { return scalar_int_type; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 6309d7fcdee..9d4ec358bd3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -433,7 +433,7 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, builder->CreateZExt( builder->CreateBitCast(value, builder->getIntNTy(bit_width)), builder->getIntNTy(32 * num_segments)), - llvm::VectorType::get(builder->getInt32Ty(), num_segments)); + llvm::VectorType::get(builder->getInt32Ty(), num_segments, false)); for (int i = 0; i < num_segments; ++i) { llvm::Value* insert_val; if (target_triple.isNVPTX()) { diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 7b7c449a599..11cbfba0356 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -699,8 +699,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "88bbd30736561190a6733d0ad60aec21446b914c" - LLVM_SHA256 = "501fbe2f1e7ae7e8baede12f40866b954c4062852aa53b9ef414f852cfdbca4f" + LLVM_COMMIT = "0581c0b0eeba03da590d1176a4580cf9b9e8d1e3" + LLVM_SHA256 = "9d93364e8ecd080258a2d2a113383387b91e5f6f2b662b48897cde8c47c178b6" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index f92759709a2..0ee95ed7020 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -69,6 +69,8 @@ cc_library( "include/mlir/IR/*.h", ]) + [ "include/mlir/Interfaces/CallInterfaces.h", + "include/mlir/Interfaces/DecodeAttributesInterfaces.h", + "include/mlir/Interfaces/FoldInterfaces.h", ], includes = ["include"], deps = [