[MLIR][KERNEL_GEN] Extend tf_framework.alloc to support input forwarding.

This PR changes the op definition and the lowering to LLVM, including `tf_framework_c_interface`.
The lowering is tested with an e2e test in `tf_framework_external_calls.mlir`.
The old `tf_framework_external_calls.mlir` was renamed to `tf_framework_embed_and_call.mlir` because it tests two passes at once.

PiperOrigin-RevId: 338256141
Change-Id: I4fd53ca30caae6064f7d6e15c36ca178f9ac8dee
This commit is contained in:
Alexander Belyaev 2020-10-21 07:06:15 -07:00 committed by TensorFlower Gardener
parent b39bc5f1f0
commit 1ef4206318
10 changed files with 179 additions and 64 deletions

View File

@ -61,10 +61,10 @@ LogicalResult Verify(OpTy op) {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AllocRawOp // TFAllocOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
template <> template <>
LogicalResult Verify<AllocRawOp>(AllocRawOp op) { LogicalResult Verify<TFAllocOp>(TFAllocOp op) {
// Check that the total number of operands matches the number of dynamic // Check that the total number of operands matches the number of dynamic
// dimensions specified in the memref type. // dimensions specified in the memref type.
unsigned result_dyn_dims = op.getType().getNumDynamicDims(); unsigned result_dyn_dims = op.getType().getNumDynamicDims();

View File

@ -49,21 +49,28 @@ class TFFramework_Op<string mnemonic, list<OpTrait> traits = []> :
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AllocRawOp // TFAllocOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw", def TFFramework_TFAllocOp : TFFramework_Op<"alloc",
[MemoryEffects<[MemAlloc<DefaultResource>]>]> { [MemoryEffects<[MemAlloc<DefaultResource>]>]> {
let summary = "allocation of tensors that uses TF Framework"; let summary = "allocation of tensors that uses TF Framework";
let description = [{ let description = [{
Allocation of tensors during kernel execution in the Compute method. Allocation of tensors during kernel execution in the Compute method.
This should be used to allocate any temporary or output memref. This should be used to allocate any temporary or output memref. If
Corresponds to `Allocator::AllocateRaw` in `output_index` and `input_indices` are given, attempts to forward one of
tensorflow/core/framework/allocator.h. the input tensors to the output by calling `OpKernelContext::forward_input`.
If the attributes are missing or the forwarding fails, calls
`Allocator::AllocateRaw` in tensorflow/core/framework/allocator.h.
}]; }];
let arguments = (ins TFFramework_OpKernelContextType:$ctx, let arguments = (ins
Variadic<Index>:$dyn_sizes); TFFramework_OpKernelContextType:$ctx,
Variadic<Index>:$dyn_sizes,
OptionalAttr<I32ArrayAttr>:$input_indices,
OptionalAttr<I32Attr>:$output_index
);
let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$result); let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$result);
let builders = [ let builders = [
@ -92,16 +99,16 @@ def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw",
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// DeallocRawOp // TFDeallocOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def TFFramework_DeallocRawOp : TFFramework_Op<"dealloc_raw", def TFFramework_TFDeallocOp : TFFramework_Op<"dealloc",
[MemoryEffects<[MemFree]>]> { [MemoryEffects<[MemFree]>]> {
let summary = "deallocation of tensors that uses TF Framework"; let summary = "deallocation of tensors that uses TF Framework";
let description = [{ let description = [{
Deallocation of tensors during kernel execution in the Compute method. Deallocation of tensors during kernel execution in the Compute method.
This should be used to deallocate any temporary memref that was allocated This should be used to deallocate any temporary memref that was allocated
with `tf_framework.alloc_raw`. with `tf_framework.alloc`.
Corresponds to `Allocator::DeallocateRaw` in Corresponds to `Allocator::DeallocateRaw` in
tensorflow/core/framework/allocator.h. tensorflow/core/framework/allocator.h.
}]; }];

View File

@ -10,9 +10,9 @@ func @tf_entry(%size_0 : index , %size_2 : index) -> index
dealloc %buf : memref<?x10x?xf32> dealloc %buf : memref<?x10x?xf32>
std.return %size_0 : index std.return %size_0 : index
} }
// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc_raw // CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc
// CHECK-SAME: ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref<?x10x?xf32> // CHECK-SAME: ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref<?x10x?xf32>
// CHECK-NEXT: tf_framework.dealloc_raw([[CTX]], [[VAL_3]]) : memref<?x10x?xf32> // CHECK-NEXT: tf_framework.dealloc([[CTX]], [[VAL_3]]) : memref<?x10x?xf32>
// CHECK-NEXT: return [[SIZE_0]] : index // CHECK-NEXT: return [[SIZE_0]] : index
// ----- // -----

View File

@ -2,6 +2,6 @@
func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) { func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) {
// expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}} // expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}}
%buf = tf_framework.alloc_raw(%ctx, %size) : memref<?x10x?xi8> %buf = tf_framework.alloc(%ctx, %size) : memref<?x10x?xi8>
return return
} }

View File

@ -4,17 +4,28 @@
// Verify the generic form can be parsed. // Verify the generic form can be parsed.
// RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s // RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s
// CHECK-LABEL: func @alloc_raw // CHECK-LABEL: func @alloc
func @alloc_raw(%ctx: !tf_framework.op_kernel_context, func @alloc(%ctx: !tf_framework.op_kernel_context,
%size_0 : index , %size_2 : index) { %size_0 : index , %size_2 : index) {
%buf_0 = tf_framework.alloc_raw(%ctx) : memref<10xi8> %buf_0 = tf_framework.alloc(%ctx) : memref<10xi8>
%buf_1 = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xi8> %buf_1 = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
return return
} }
// CHECK-LABEL: func @dealloc_raw // CHECK-LABEL: func @forwarding_alloc
func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, %memref : memref<?x10xf32>) { func @forwarding_alloc(%ctx: !tf_framework.op_kernel_context,
tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32> %size_0 : index , %size_2 : index) {
%buf = tf_framework.alloc(%ctx, %size_0, %size_2) {
input_indices = [0 : i32, 1 : i32],
output_index = 0 : i32
} : memref<?x10x?xi8>
return
}
// CHECK-LABEL: func @dealloc
func @dealloc(%ctx: !tf_framework.op_kernel_context,
%memref : memref<?x10xf32>) {
tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32>
return return
} }

View File

@ -1,15 +1,15 @@
// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s // RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file --print-ir-after-all | FileCheck %s
// CHECK: llvm.func @_mlir_ciface_tf_alloc_raw // CHECK: llvm.func @_mlir_ciface_tf_alloc
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8> // CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32>) -> !llvm.ptr<i8>
// CHECK-LABEL: llvm.func @alloc_raw( // CHECK-LABEL: llvm.func @alloc(
// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>, // CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>,
// CHECK-SAME: [[SIZE_0:%.*]]: !llvm.i64, // CHECK-SAME: [[SIZE_0:%.*]]: !llvm.i64,
// CHECK-SAME: [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] { // CHECK-SAME: [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] {
func @alloc_raw(%ctx: !tf_framework.op_kernel_context, func @alloc(%ctx: !tf_framework.op_kernel_context,
%size_0 : index , %size_2 : index) -> memref<?x10x?xf32> { %size_0 : index , %size_2 : index) -> memref<?x10x?xf32> {
%buf = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xf32> %buf = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xf32>
std.return %buf : memref<?x10x?xf32> std.return %buf : memref<?x10x?xf32>
} }
// Compute number of elements. // Compute number of elements.
@ -25,10 +25,19 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
// CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]] // CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]]
// CHECK-SAME: !llvm.ptr<float> to !llvm.i64 // CHECK-SAME: !llvm.ptr<float> to !llvm.i64
// Allocate memory. // Compute total size in bytes.
// CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]] // CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]]
// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]])
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8> // Compute output index (-1) and candidate indices (0, NULL).
// CHECK: [[OUTPUT_INDEX:%.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
// CHECK-NEXT: [[NUM_CANDIDATES:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK-NEXT: [[CANDIDATES_PTR:%.*]] = llvm.mlir.null : !llvm.ptr<i32>
// Allocate memory.
// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]],
// CHECK-SAME: [[OUTPUT_INDEX]], [[NUM_CANDIDATES]], [[CANDIDATES_PTR]])
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32>
// CHECK-SAME: ) -> !llvm.ptr<i8>
// Build memref descriptor. // Build memref descriptor.
// CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]] // CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]]
@ -55,13 +64,13 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
// ----- // -----
// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr<i8>, !llvm.ptr<i8>) // CHECK: llvm.func @_mlir_ciface_tf_dealloc(!llvm.ptr<i8>, !llvm.ptr<i8>)
// CHECK-LABEL: llvm.func @dealloc_raw( // CHECK-LABEL: llvm.func @dealloc(
// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>, // CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>,
func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, func @dealloc(%ctx: !tf_framework.op_kernel_context,
%memref : memref<?x10xf32>) { %memref : memref<?x10xf32>) {
tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32> tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32>
return return
} }
// Extract allocated ptr from the memref descriptor. // Extract allocated ptr from the memref descriptor.
@ -71,5 +80,5 @@ func @dealloc_raw(%ctx: !tf_framework.op_kernel_context,
// CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<i8> // CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<i8>
// Deallocate. // Deallocate.
// CHECK: llvm.call @_mlir_ciface_tf_dealloc_raw( // CHECK: llvm.call @_mlir_ciface_tf_dealloc(
// CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> () // CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()

View File

@ -24,23 +24,49 @@ namespace tf_framework {
namespace { namespace {
using tensorflow::Allocator; using tensorflow::Allocator;
using tensorflow::AllocatorAttributes;
Allocator* GetAllocator(void* op_kernel_ctx) { Allocator* GetAllocator(void* op_kernel_ctx) {
auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx); auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
// TODO(pifon): Figure out how to set AllocatorAttributes correctly. // TODO(pifon): Figure out how to set AllocatorAttributes correctly.
tensorflow::AllocatorAttributes attrs; AllocatorAttributes attrs;
return ctx->get_allocator(attrs); return ctx->get_allocator(attrs);
} }
} // namespace } // namespace
extern "C" void* _mlir_ciface_tf_alloc_raw(void* op_kernel_ctx, extern "C" void* _mlir_ciface_tf_alloc(void* op_kernel_ctx, size_t num_bytes,
size_t num_bytes) { int32_t output_index,
int32_t num_candidates,
int32_t* candidate_input_indices) {
auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
if (output_index != -1) {
auto element_size = ctx->expected_output_dtype(output_index);
// Create a 1D shape, because the shapes don't have to match exactly for
// input forwarding. Only the number of elements must be the same.
tensorflow::TensorShape output_shape;
output_shape.AddDim(num_bytes / element_size);
// Iterate over indices of all inputs that can potentially be used for
// forwarding.
for (int i = 0; i < num_candidates; ++i) {
// TODO(pifon): Expose fetching AllocatorAttributes with the output_index.
AllocatorAttributes output_attr;
auto tensor = ctx->forward_input(
candidate_input_indices[i], output_index, element_size, output_shape,
ctx->output_memory_type(output_index), output_attr);
if (tensor != nullptr) {
return tensor->data();
}
}
}
// If no forwarding happened, allocate a chunk of memory.
return GetAllocator(op_kernel_ctx) return GetAllocator(op_kernel_ctx)
->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes); ->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes);
} }
extern "C" void _mlir_ciface_tf_dealloc_raw(void* op_kernel_ctx, void* ptr) { extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) {
GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr); GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr);
} }

View File

@ -22,10 +22,11 @@ namespace mlir {
namespace kernel_gen { namespace kernel_gen {
namespace tf_framework { namespace tf_framework {
extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc_raw( extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc(
void* op_kernel_ctx, size_t num_bytes); void* op_kernel_ctx, size_t num_bytes, int32_t output_index,
int32_t num_candidates, int32_t* candidate_input_indices);
extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc_raw( extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc(
void* op_kernel_ctx, void* ptr); void* op_kernel_ctx, void* ptr);
} // namespace tf_framework } // namespace tf_framework

View File

@ -75,11 +75,11 @@ class AllocOpConverter : public OpConversionPattern<AllocOp> {
return failure(); return failure();
} }
// Symbolic operands that bind to the symbols of the memref's layout map are // Symbolic operands that bind to the symbols of the memref's layout map are
// not supported by AllocRawOp. // not supported by TFAllocOp.
if (alloc.getNumSymbolicOperands() != 0) { if (alloc.getNumSymbolicOperands() != 0) {
return failure(); return failure();
} }
rewriter.replaceOpWithNewOp<AllocRawOp>(alloc, alloc.getType(), ctx, rewriter.replaceOpWithNewOp<TFAllocOp>(alloc, alloc.getType(), ctx,
operands); operands);
return success(); return success();
} }
@ -87,7 +87,7 @@ class AllocOpConverter : public OpConversionPattern<AllocOp> {
// Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType // Converts std.dealloc to tf_framework.dealloc using OpKernelContextType
// arg of the parent function. // arg of the parent function.
class DeallocOpConverter : public OpConversionPattern<DeallocOp> { class TFDeallocOpConverter : public OpConversionPattern<DeallocOp> {
public: public:
using OpConversionPattern<DeallocOp>::OpConversionPattern; using OpConversionPattern<DeallocOp>::OpConversionPattern;
@ -108,7 +108,7 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
return failure(); return failure();
} }
DeallocOp::Adaptor transformed(operands); DeallocOp::Adaptor transformed(operands);
rewriter.replaceOpWithNewOp<DeallocRawOp>(dealloc, ctx, rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, ctx,
transformed.memref()); transformed.memref());
return success(); return success();
} }
@ -118,7 +118,7 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
void PopulateEmbedTFFrameworkConversionPatterns( void PopulateEmbedTFFrameworkConversionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) { MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<AllocOpConverter, DeallocOpConverter, FuncOpConverter>( patterns->insert<AllocOpConverter, TFDeallocOpConverter, FuncOpConverter>(
context); context);
} }

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
@ -30,8 +31,8 @@ namespace {
using LLVM::LLVMFuncOp; using LLVM::LLVMFuncOp;
using LLVM::LLVMType; using LLVM::LLVMType;
static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc_raw"; static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc_raw"; static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
/// Base class for patterns converting TF Framework ops to function calls. /// Base class for patterns converting TF Framework ops to function calls.
template <typename OpTy> template <typename OpTy>
@ -60,18 +61,18 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
virtual LLVMType GetFuncType() const = 0; virtual LLVMType GetFuncType() const = 0;
}; };
class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> { class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
public: public:
using ConvertToLLVMCallOpPattern<AllocRawOp>::ConvertToLLVMCallOpPattern; using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands, Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc(); Location loc = op->getLoc();
AllocRawOp alloc_raw_op = cast<AllocRawOp>(op); TFAllocOp tf_alloc_op = cast<TFAllocOp>(op);
AllocRawOp::Adaptor transformed(operands); TFAllocOp::Adaptor transformed(operands);
MemRefType memref_type = alloc_raw_op.getType(); MemRefType memref_type = tf_alloc_op.getType();
// Get memref descriptor sizes. // Get memref descriptor sizes.
SmallVector<Value, 4> sizes; SmallVector<Value, 4> sizes;
@ -82,13 +83,28 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
Value num_bytes = getCumulativeSizeInBytes( Value num_bytes = getCumulativeSizeInBytes(
loc, memref_type.getElementType(), sizes, rewriter); loc, memref_type.getElementType(), sizes, rewriter);
// Convert `output_index` or set it to -1 if the attribute is missing.
LLVM::LLVMType llvmInt32Type =
LLVM::LLVMType::getInt32Ty(rewriter.getContext());
Value output_index = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type,
rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue()
? tf_alloc_op.output_index().getValue()
: -1));
// Convert the `input_indices` attribute into an element count and a pointer to a stack-allocated array of candidate input indices.
auto candidates_count_and_ptr = ConvertI32ArrayAttrToStackAllocatedArray(
loc, tf_alloc_op.input_indices(), &rewriter);
// Insert function call. // Insert function call.
FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op); FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
Value allocated_byte_ptr = Value allocated_byte_ptr =
rewriter rewriter
.create<LLVM::CallOp>( .create<LLVM::CallOp>(
loc, getVoidPtrType(), tf_func_ref, loc, getVoidPtrType(), tf_func_ref,
llvm::makeArrayRef({transformed.ctx(), num_bytes})) llvm::makeArrayRef({transformed.ctx(), num_bytes, output_index,
candidates_count_and_ptr.first,
candidates_count_and_ptr.second}))
.getResult(0); .getResult(0);
MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor( MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
@ -103,10 +119,18 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
StringRef GetFuncName() const override { return kCInterfaceAlloc; } StringRef GetFuncName() const override { return kCInterfaceAlloc; }
LLVMType GetFuncType() const override { LLVMType GetFuncType() const override {
LLVMType llvm_i32_type =
LLVM::LLVMType::getInt32Ty(getDialect().getContext());
LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
LLVMType llvm_void_ptr_type = getVoidPtrType(); LLVMType llvm_void_ptr_type = getVoidPtrType();
return LLVM::LLVMType::getFunctionTy( return LLVMType::getFunctionTy(
llvm_void_ptr_type, llvm_void_ptr_type,
llvm::makeArrayRef({llvm_void_ptr_type, getIndexType()}), llvm::makeArrayRef(
{/*void* op_kernel_ctx*/ llvm_void_ptr_type,
/*size_t num_bytes*/ getIndexType(),
/*int32_t output_index*/ llvm_i32_type,
/*int32_t num_candidates*/ llvm_i32_type,
/*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}),
/*isVarArg=*/false); /*isVarArg=*/false);
} }
@ -144,16 +168,53 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
} }
return memref_desc; return memref_desc;
} }
std::pair<Value, Value> ConvertI32ArrayAttrToStackAllocatedArray(
Location loc, llvm::Optional<ArrayAttr> attr,
ConversionPatternRewriter *rewriter) const {
LLVMType llvm_i32_type =
LLVM::LLVMType::getInt32Ty(getDialect().getContext());
LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
// If the attribute is missing or empty, set the element count to 0 and
// return NULL.
if (!attr.hasValue() || attr.getValue().empty()) {
Value zero = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type, rewriter->getI32IntegerAttr(0));
Value null_ptr = rewriter->create<LLVM::NullOp>(loc, llvm_i32_ptr_type);
return std::make_pair(zero, null_ptr);
}
// Allocate array to store the elements.
auto &array_attr = attr.getValue();
Value array_size = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type, rewriter->getI32IntegerAttr(array_attr.size()));
Value array_ptr = rewriter->create<LLVM::AllocaOp>(
loc, llvm_i32_ptr_type, array_size, /*alignment=*/0);
for (auto &dim : llvm::enumerate(array_attr)) {
Value index = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type, rewriter->getI32IntegerAttr(dim.index()));
Value elem_ptr = rewriter->create<LLVM::GEPOp>(loc, llvm_i32_ptr_type,
array_ptr, index);
Value elem = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type,
rewriter->getI32IntegerAttr(
dim.value().cast<IntegerAttr>().getInt()));
rewriter->create<LLVM::StoreOp>(loc, elem, elem_ptr);
}
return std::make_pair(array_size, array_ptr);
}
}; };
class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern<DeallocRawOp> { class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
public: public:
using ConvertToLLVMCallOpPattern<DeallocRawOp>::ConvertToLLVMCallOpPattern; using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands, Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
DeallocRawOp::Adaptor transformed(operands); TFDeallocOp::Adaptor transformed(operands);
MemRefDescriptor memref(transformed.memref()); MemRefDescriptor memref(transformed.memref());
Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>( Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
@ -194,7 +255,7 @@ class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
void PopulateTFFrameworkToLLVMConversionPatterns( void PopulateTFFrameworkToLLVMConversionPatterns(
LLVMTypeConverter *converter, OwningRewritePatternList *patterns) { LLVMTypeConverter *converter, OwningRewritePatternList *patterns) {
patterns->insert<NullContextOpConverter>(*converter); patterns->insert<NullContextOpConverter>(*converter);
patterns->insert<AllocRawOpConverter, DeallocRawOpConverter>(*converter); patterns->insert<TFAllocOpConverter, TFDeallocOpConverter>(*converter);
} }
} // namespace tf_framework } // namespace tf_framework