From 1ef42063180285108d0e4c45159ed56075255aff Mon Sep 17 00:00:00 2001 From: Alexander Belyaev <pifon@google.com> Date: Wed, 21 Oct 2020 07:06:15 -0700 Subject: [PATCH] [MLIR][KERNEL_GEN] Extend `tf_framework.alloc` to support input forwarding. This PR changes the op definition and the lowering to LLVM including `tf_framework_c_interface`. The lowering is tested with e2e test in `tf_framework_external_calls.mlir`. The old `tf_framework_external_calls.mlir` was renamed to `tf_framework_embed_and_call.mlir` because it tests two passes at once. PiperOrigin-RevId: 338256141 Change-Id: I4fd53ca30caae6064f7d6e15c36ca178f9ac8dee --- .../tools/kernel_gen/ir/tf_framework_ops.cc | 4 +- .../tools/kernel_gen/ir/tf_framework_ops.td | 27 +++--- .../kernel_gen/tests/embed_tf_framework.mlir | 4 +- .../mlir/tools/kernel_gen/tests/invalid.mlir | 2 +- .../mlir/tools/kernel_gen/tests/ops.mlir | 25 ++++-- .../tests/tf_framework_legalize_to_llvm.mlir | 37 +++++--- .../kernel_gen/tf_framework_c_interface.cc | 34 ++++++- .../kernel_gen/tf_framework_c_interface.h | 7 +- .../transforms/embed_tf_framework.cc | 14 +-- .../tf_framework_legalize_to_llvm.cc | 89 ++++++++++++++++--- 10 files changed, 179 insertions(+), 64 deletions(-) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index b3d92773be4..676e1849318 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -61,10 +61,10 @@ LogicalResult Verify(OpTy op) { } //===----------------------------------------------------------------------===// -// AllocRawOp +// TFAllocOp //===----------------------------------------------------------------------===// template <> -LogicalResult Verify<AllocRawOp>(AllocRawOp op) { +LogicalResult Verify<TFAllocOp>(TFAllocOp op) { // Check that the total number of operands matches the number of dynamic // dimensions 
specified in the memref type. unsigned result_dyn_dims = op.getType().getNumDynamicDims(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index e6e29bcbdc2..2f3e0f6f5fa 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -49,21 +49,28 @@ class TFFramework_Op<string mnemonic, list<OpTrait> traits = []> : } //===----------------------------------------------------------------------===// -// AllocRawOp +// TFAllocOp //===----------------------------------------------------------------------===// -def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw", +def TFFramework_TFAllocOp : TFFramework_Op<"alloc", [MemoryEffects<[MemAlloc<DefaultResource>]>]> { let summary = "allocation of tensors that uses TF Framework"; let description = [{ Allocation of tensors during kernel execution in the Compute method. - This should be used to allocate any temporary or output memref. - Corresponds to `Allocator::AllocateRaw` in - tensorflow/core/framework/allocator.h. + This should be used to allocate any temporary or output memref. If + `output_index` and `input_indices` are given, attempts to forward one of + the input tensors to the output by calling `OpKernelContext::forward_input`. + + If the attributes are missing or the forwarding fails, calls + `Allocator::AllocateRaw` in tensorflow/core/framework/allocator.h. 
}]; - let arguments = (ins TFFramework_OpKernelContextType:$ctx, - Variadic<Index>:$dyn_sizes); + let arguments = (ins + TFFramework_OpKernelContextType:$ctx, + Variadic<Index>:$dyn_sizes, + OptionalAttr<I32ArrayAttr>:$input_indices, + OptionalAttr<I32Attr>:$output_index + ); let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$result); let builders = [ @@ -92,16 +99,16 @@ def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw", } //===----------------------------------------------------------------------===// -// DeallocRawOp +// TFDeallocOp //===----------------------------------------------------------------------===// -def TFFramework_DeallocRawOp : TFFramework_Op<"dealloc_raw", +def TFFramework_TFDeallocOp : TFFramework_Op<"dealloc", [MemoryEffects<[MemFree]>]> { let summary = "deallocation of tensors that uses TF Framework"; let description = [{ Deallocation of tensors during kernel execution in the Compute method. This should be used to deallocate any temporary memref that was allocated - with `tf_framework.alloc_raw`. + with `tf_framework.alloc`. Corresponds to `Allocator::DeallocateRaw` in tensorflow/core/framework/allocator.h. 
}]; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir index bb0f1926cda..5d0beb7c7fe 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir @@ -10,9 +10,9 @@ func @tf_entry(%size_0 : index , %size_2 : index) -> index dealloc %buf : memref<?x10x?xf32> std.return %size_0 : index } -// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc_raw +// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc // CHECK-SAME: ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref<?x10x?xf32> -// CHECK-NEXT: tf_framework.dealloc_raw([[CTX]], [[VAL_3]]) : memref<?x10x?xf32> +// CHECK-NEXT: tf_framework.dealloc([[CTX]], [[VAL_3]]) : memref<?x10x?xf32> // CHECK-NEXT: return [[SIZE_0]] : index // ----- diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir index 1d1b3319515..1d3d5e485fb 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir @@ -2,6 +2,6 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) { // expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}} - %buf = tf_framework.alloc_raw(%ctx, %size) : memref<?x10x?xi8> + %buf = tf_framework.alloc(%ctx, %size) : memref<?x10x?xi8> return } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir index fc8e7c97ec8..aa291c4c439 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir @@ -4,17 +4,28 @@ // Verify the generic form can be parsed. 
// RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s -// CHECK-LABEL: func @alloc_raw -func @alloc_raw(%ctx: !tf_framework.op_kernel_context, +// CHECK-LABEL: func @alloc +func @alloc(%ctx: !tf_framework.op_kernel_context, %size_0 : index , %size_2 : index) { - %buf_0 = tf_framework.alloc_raw(%ctx) : memref<10xi8> - %buf_1 = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xi8> + %buf_0 = tf_framework.alloc(%ctx) : memref<10xi8> + %buf_1 = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xi8> return } -// CHECK-LABEL: func @dealloc_raw -func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, %memref : memref<?x10xf32>) { - tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32> +// CHECK-LABEL: func @forwarding_alloc +func @forwarding_alloc(%ctx: !tf_framework.op_kernel_context, + %size_0 : index , %size_2 : index) { + %buf = tf_framework.alloc(%ctx, %size_0, %size_2) { + input_indices = [0 : i32, 1 : i32], + output_index = 0 : i32 + } : memref<?x10x?xi8> + return +} + +// CHECK-LABEL: func @dealloc +func @dealloc(%ctx: !tf_framework.op_kernel_context, + %memref : memref<?x10xf32>) { + tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32> return } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir index b943321e95b..44f8297a99f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir @@ -1,15 +1,15 @@ -// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s +// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file --print-ir-after-all | FileCheck %s -// CHECK: llvm.func @_mlir_ciface_tf_alloc_raw -// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8> +// CHECK: llvm.func @_mlir_ciface_tf_alloc +// CHECK-SAME: (!llvm.ptr<i8>, 
!llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32>) -> !llvm.ptr<i8> -// CHECK-LABEL: llvm.func @alloc_raw( +// CHECK-LABEL: llvm.func @alloc( // CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>, // CHECK-SAME: [[SIZE_0:%.*]]: !llvm.i64, // CHECK-SAME: [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] { -func @alloc_raw(%ctx: !tf_framework.op_kernel_context, +func @alloc(%ctx: !tf_framework.op_kernel_context, %size_0 : index , %size_2 : index) -> memref<?x10x?xf32> { - %buf = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xf32> + %buf = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xf32> std.return %buf : memref<?x10x?xf32> } // Compute number of elements. @@ -25,10 +25,19 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context, // CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]] // CHECK-SAME: !llvm.ptr<float> to !llvm.i64 -// Allocate memory. +// Compute total size in bytes. // CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]] -// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]]) -// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8> + +// Compute output index (-1) and candidate indices (0, NULL). +// CHECK: [[OUTPUT_INDEX:%.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32 +// CHECK-NEXT: [[NUM_CANDIDATES:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK-NEXT: [[CANDIDATES_PTR:%.*]] = llvm.mlir.null : !llvm.ptr<i32> + +// Allocate memory. +// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]], +// CHECK-SAME: [[OUTPUT_INDEX]], [[NUM_CANDIDATES]], [[CANDIDATES_PTR]]) +// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32> +// CHECK-SAME: ) -> !llvm.ptr<i8> // Build memref descriptor. 
// CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]] @@ -55,13 +64,13 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context, // ----- -// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr<i8>, !llvm.ptr<i8>) +// CHECK: llvm.func @_mlir_ciface_tf_dealloc(!llvm.ptr<i8>, !llvm.ptr<i8>) -// CHECK-LABEL: llvm.func @dealloc_raw( +// CHECK-LABEL: llvm.func @dealloc( // CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>, -func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, +func @dealloc(%ctx: !tf_framework.op_kernel_context, %memref : memref<?x10xf32>) { - tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32> + tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32> return } // Extract allocated ptr from the memref descriptor. @@ -71,5 +80,5 @@ func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, // CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<i8> // Deallocate. -// CHECK: llvm.call @_mlir_ciface_tf_dealloc_raw( +// CHECK: llvm.call @_mlir_ciface_tf_dealloc( // CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> () diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc index e75db59d885..2b2625b4d59 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -24,23 +24,49 @@ namespace tf_framework { namespace { using tensorflow::Allocator; +using tensorflow::AllocatorAttributes; Allocator* GetAllocator(void* op_kernel_ctx) { auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx); // TODO(pifon): Figure out how to set AllocatorAttributes correctly. 
- tensorflow::AllocatorAttributes attrs; + AllocatorAttributes attrs; return ctx->get_allocator(attrs); } } // namespace -extern "C" void* _mlir_ciface_tf_alloc_raw(void* op_kernel_ctx, - size_t num_bytes) { +extern "C" void* _mlir_ciface_tf_alloc(void* op_kernel_ctx, size_t num_bytes, + int32_t output_index, + int32_t num_candidates, + int32_t* candidate_input_indices) { + auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx); + + if (output_index != -1) { + auto dtype = ctx->expected_output_dtype(output_index); + // Forwarding needs only matching element counts, so use a 1D shape with + // num_bytes / DataTypeSize(dtype) elements (dtype is assumed fixed-size). + tensorflow::TensorShape output_shape; + output_shape.AddDim(num_bytes / tensorflow::DataTypeSize(dtype)); + + // Iterate over indices of all inputs that can potentially be used for + // forwarding. + for (int i = 0; i < num_candidates; ++i) { + // TODO(pifon): Expose fetching AllocatorAttributes with the output_index. + AllocatorAttributes output_attr; + auto tensor = ctx->forward_input( + candidate_input_indices[i], output_index, dtype, output_shape, + ctx->output_memory_type(output_index), output_attr); + if (tensor != nullptr) { + return tensor->data(); + } + } + } + // If no forwarding happened, allocate a chunk of memory. 
return GetAllocator(op_kernel_ctx) ->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes); } -extern "C" void _mlir_ciface_tf_dealloc_raw(void* op_kernel_ctx, void* ptr) { +extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) { GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h index 143ebc95932..bf45116f372 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h @@ -22,10 +22,11 @@ namespace mlir { namespace kernel_gen { namespace tf_framework { -extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc_raw( - void* op_kernel_ctx, size_t num_bytes); +extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc( + void* op_kernel_ctx, size_t num_bytes, int32_t output_index, + int32_t num_candidates, int32_t* candidate_input_indices); -extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc_raw( +extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc( void* op_kernel_ctx, void* ptr); } // namespace tf_framework diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc index aa02aefa9d2..3b006c954cf 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc @@ -75,19 +75,19 @@ class AllocOpConverter : public OpConversionPattern<AllocOp> { return failure(); } // Symbolic operands that bind to the symbols of the memref's layout map are - // not supported by AllocRawOp. + // not supported by TFAllocOp. 
if (alloc.getNumSymbolicOperands() != 0) { return failure(); } - rewriter.replaceOpWithNewOp<AllocRawOp>(alloc, alloc.getType(), ctx, - operands); + rewriter.replaceOpWithNewOp<TFAllocOp>(alloc, alloc.getType(), ctx, + operands); return success(); } }; // Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType // arg of the parent function. -class DeallocOpConverter : public OpConversionPattern<DeallocOp> { +class TFDeallocOpConverter : public OpConversionPattern<DeallocOp> { public: using OpConversionPattern<DeallocOp>::OpConversionPattern; @@ -108,8 +108,8 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> { return failure(); } DeallocOp::Adaptor transformed(operands); - rewriter.replaceOpWithNewOp<DeallocRawOp>(dealloc, ctx, - transformed.memref()); + rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, ctx, + transformed.memref()); return success(); } }; @@ -118,7 +118,7 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> { void PopulateEmbedTFFrameworkConversionPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert<AllocOpConverter, DeallocOpConverter, FuncOpConverter>( + patterns->insert<AllocOpConverter, TFDeallocOpConverter, FuncOpConverter>( context); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index 431919c2de7..959f7ecf635 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project @@ -30,8 +31,8 @@ namespace { using LLVM::LLVMFuncOp; using LLVM::LLVMType; -static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc_raw"; -static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc_raw"; +static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc"; +static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc"; /// Base class for patterns converting TF Framework ops to function calls. template <typename OpTy> @@ -60,18 +61,18 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> { virtual LLVMType GetFuncType() const = 0; }; -class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> { +class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> { public: - using ConvertToLLVMCallOpPattern<AllocRawOp>::ConvertToLLVMCallOpPattern; + using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern; LogicalResult matchAndRewrite( Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - AllocRawOp alloc_raw_op = cast<AllocRawOp>(op); - AllocRawOp::Adaptor transformed(operands); + TFAllocOp tf_alloc_op = cast<TFAllocOp>(op); + TFAllocOp::Adaptor transformed(operands); - MemRefType memref_type = alloc_raw_op.getType(); + MemRefType memref_type = tf_alloc_op.getType(); // Get memref descriptor sizes. 
SmallVector<Value, 4> sizes; @@ -82,13 +83,28 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> { Value num_bytes = getCumulativeSizeInBytes( loc, memref_type.getElementType(), sizes, rewriter); + // Convert `output_index` or set it to -1 if the attribute is missing. + LLVM::LLVMType llvmInt32Type = + LLVM::LLVMType::getInt32Ty(rewriter.getContext()); + Value output_index = rewriter.create<LLVM::ConstantOp>( + loc, llvmInt32Type, + rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue() + ? tf_alloc_op.output_index().getValue() + : -1)); + + // Convert `candidate_input_indices`. + auto candidates_count_and_ptr = ConvertI32ArrayAttrToStackAllocatedArray( + loc, tf_alloc_op.input_indices(), &rewriter); + // Insert function call. FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op); Value allocated_byte_ptr = rewriter .create<LLVM::CallOp>( loc, getVoidPtrType(), tf_func_ref, - llvm::makeArrayRef({transformed.ctx(), num_bytes})) + llvm::makeArrayRef({transformed.ctx(), num_bytes, output_index, + candidates_count_and_ptr.first, + candidates_count_and_ptr.second})) .getResult(0); MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor( @@ -103,10 +119,18 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> { StringRef GetFuncName() const override { return kCInterfaceAlloc; } LLVMType GetFuncType() const override { + LLVMType llvm_i32_type = + LLVM::LLVMType::getInt32Ty(getDialect().getContext()); + LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo(); LLVMType llvm_void_ptr_type = getVoidPtrType(); - return LLVM::LLVMType::getFunctionTy( + return LLVMType::getFunctionTy( llvm_void_ptr_type, - llvm::makeArrayRef({llvm_void_ptr_type, getIndexType()}), + llvm::makeArrayRef( + {/*void* op_kernel_ctx*/ llvm_void_ptr_type, + /*size_t num_bytes*/ getIndexType(), + /*int32_t output_index*/ llvm_i32_type, + /*int32_t num_candidates*/ llvm_i32_type, + /*int32_t* candidate_input_indices*/ 
llvm_i32_ptr_type}), /*isVarArg=*/false); } @@ -144,16 +168,53 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> { } return memref_desc; } + + std::pair<Value, Value> ConvertI32ArrayAttrToStackAllocatedArray( + Location loc, llvm::Optional<ArrayAttr> attr, + ConversionPatternRewriter *rewriter) const { + LLVMType llvm_i32_type = + LLVM::LLVMType::getInt32Ty(getDialect().getContext()); + LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo(); + + // If the attribute is missing or empty, set the element count to 0 and + // return NULL. + if (!attr.hasValue() || attr.getValue().empty()) { + Value zero = rewriter->create<LLVM::ConstantOp>( + loc, llvm_i32_type, rewriter->getI32IntegerAttr(0)); + Value null_ptr = rewriter->create<LLVM::NullOp>(loc, llvm_i32_ptr_type); + return std::make_pair(zero, null_ptr); + } + + // Allocate array to store the elements. + auto &array_attr = attr.getValue(); + Value array_size = rewriter->create<LLVM::ConstantOp>( + loc, llvm_i32_type, rewriter->getI32IntegerAttr(array_attr.size())); + Value array_ptr = rewriter->create<LLVM::AllocaOp>( + loc, llvm_i32_ptr_type, array_size, /*alignment=*/0); + + for (auto &dim : llvm::enumerate(array_attr)) { + Value index = rewriter->create<LLVM::ConstantOp>( + loc, llvm_i32_type, rewriter->getI32IntegerAttr(dim.index())); + Value elem_ptr = rewriter->create<LLVM::GEPOp>(loc, llvm_i32_ptr_type, + array_ptr, index); + Value elem = rewriter->create<LLVM::ConstantOp>( + loc, llvm_i32_type, + rewriter->getI32IntegerAttr( + dim.value().cast<IntegerAttr>().getInt())); + rewriter->create<LLVM::StoreOp>(loc, elem, elem_ptr); + } + return std::make_pair(array_size, array_ptr); + } }; -class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern<DeallocRawOp> { +class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> { public: - using ConvertToLLVMCallOpPattern<DeallocRawOp>::ConvertToLLVMCallOpPattern; + using 
ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern; LogicalResult matchAndRewrite( Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { - DeallocRawOp::Adaptor transformed(operands); + TFDeallocOp::Adaptor transformed(operands); MemRefDescriptor memref(transformed.memref()); Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>( @@ -194,7 +255,7 @@ class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> { void PopulateTFFrameworkToLLVMConversionPatterns( LLVMTypeConverter *converter, OwningRewritePatternList *patterns) { patterns->insert<NullContextOpConverter>(*converter); - patterns->insert<AllocRawOpConverter, DeallocRawOpConverter>(*converter); + patterns->insert<TFAllocOpConverter, TFDeallocOpConverter>(*converter); } } // namespace tf_framework