[MLIR][KERNEL_GEN] Extend tf_framework.alloc
to support input forwarding.
This PR changes the op definition and the lowering to LLVM, including `tf_framework_c_interface`. The lowering is tested with an e2e test in `tf_framework_external_calls.mlir`. The old `tf_framework_external_calls.mlir` was renamed to `tf_framework_embed_and_call.mlir` because it tests two passes at once. PiperOrigin-RevId: 338256141 Change-Id: I4fd53ca30caae6064f7d6e15c36ca178f9ac8dee
This commit is contained in:
parent
b39bc5f1f0
commit
1ef4206318
@ -61,10 +61,10 @@ LogicalResult Verify(OpTy op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AllocRawOp
|
// TFAllocOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
template <>
|
template <>
|
||||||
LogicalResult Verify<AllocRawOp>(AllocRawOp op) {
|
LogicalResult Verify<TFAllocOp>(TFAllocOp op) {
|
||||||
// Check that the total number of operands matches the number of dynamic
|
// Check that the total number of operands matches the number of dynamic
|
||||||
// dimensions specified in the memref type.
|
// dimensions specified in the memref type.
|
||||||
unsigned result_dyn_dims = op.getType().getNumDynamicDims();
|
unsigned result_dyn_dims = op.getType().getNumDynamicDims();
|
||||||
|
@ -49,21 +49,28 @@ class TFFramework_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AllocRawOp
|
// TFAllocOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw",
|
def TFFramework_TFAllocOp : TFFramework_Op<"alloc",
|
||||||
[MemoryEffects<[MemAlloc<DefaultResource>]>]> {
|
[MemoryEffects<[MemAlloc<DefaultResource>]>]> {
|
||||||
let summary = "allocation of tensors that uses TF Framework";
|
let summary = "allocation of tensors that uses TF Framework";
|
||||||
let description = [{
|
let description = [{
|
||||||
Allocation of tensors during kernel execution in the Compute method.
|
Allocation of tensors during kernel execution in the Compute method.
|
||||||
|
|
||||||
This should be used to allocate any temporary or output memref.
|
This should be used to allocate any temporary or output memref. If
|
||||||
Corresponds to `Allocator::AllocateRaw` in
|
`output_index` and `input_indices` are given, attempts to forward one of
|
||||||
tensorflow/core/framework/allocator.h.
|
the input tensors to the output by calling `OpKernelContext::forward_input`.
|
||||||
|
|
||||||
|
If the attributes are missing or the forwarding fails, calls
|
||||||
|
`Allocator::AllocateRaw` in tensorflow/core/framework/allocator.h.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins TFFramework_OpKernelContextType:$ctx,
|
let arguments = (ins
|
||||||
Variadic<Index>:$dyn_sizes);
|
TFFramework_OpKernelContextType:$ctx,
|
||||||
|
Variadic<Index>:$dyn_sizes,
|
||||||
|
OptionalAttr<I32ArrayAttr>:$input_indices,
|
||||||
|
OptionalAttr<I32Attr>:$output_index
|
||||||
|
);
|
||||||
let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$result);
|
let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$result);
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
@ -92,16 +99,16 @@ def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw",
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DeallocRawOp
|
// TFDeallocOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
def TFFramework_DeallocRawOp : TFFramework_Op<"dealloc_raw",
|
def TFFramework_TFDeallocOp : TFFramework_Op<"dealloc",
|
||||||
[MemoryEffects<[MemFree]>]> {
|
[MemoryEffects<[MemFree]>]> {
|
||||||
let summary = "deallocation of tensors that uses TF Framework";
|
let summary = "deallocation of tensors that uses TF Framework";
|
||||||
let description = [{
|
let description = [{
|
||||||
Deallocation of tensors during kernel execution in the Compute method.
|
Deallocation of tensors during kernel execution in the Compute method.
|
||||||
|
|
||||||
This should be used to deallocate any temporary memref that was allocated
|
This should be used to deallocate any temporary memref that was allocated
|
||||||
with `tf_framework.alloc_raw`.
|
with `tf_framework.alloc`.
|
||||||
Corresponds to `Allocator::DeallocateRaw` in
|
Corresponds to `Allocator::DeallocateRaw` in
|
||||||
tensorflow/core/framework/allocator.h.
|
tensorflow/core/framework/allocator.h.
|
||||||
}];
|
}];
|
||||||
|
@ -10,9 +10,9 @@ func @tf_entry(%size_0 : index , %size_2 : index) -> index
|
|||||||
dealloc %buf : memref<?x10x?xf32>
|
dealloc %buf : memref<?x10x?xf32>
|
||||||
std.return %size_0 : index
|
std.return %size_0 : index
|
||||||
}
|
}
|
||||||
// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc_raw
|
// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc
|
||||||
// CHECK-SAME: ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref<?x10x?xf32>
|
// CHECK-SAME: ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref<?x10x?xf32>
|
||||||
// CHECK-NEXT: tf_framework.dealloc_raw([[CTX]], [[VAL_3]]) : memref<?x10x?xf32>
|
// CHECK-NEXT: tf_framework.dealloc([[CTX]], [[VAL_3]]) : memref<?x10x?xf32>
|
||||||
// CHECK-NEXT: return [[SIZE_0]] : index
|
// CHECK-NEXT: return [[SIZE_0]] : index
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -2,6 +2,6 @@
|
|||||||
|
|
||||||
func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) {
|
func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) {
|
||||||
// expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}}
|
// expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}}
|
||||||
%buf = tf_framework.alloc_raw(%ctx, %size) : memref<?x10x?xi8>
|
%buf = tf_framework.alloc(%ctx, %size) : memref<?x10x?xi8>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -4,17 +4,28 @@
|
|||||||
// Verify the generic form can be parsed.
|
// Verify the generic form can be parsed.
|
||||||
// RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s
|
// RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @alloc_raw
|
// CHECK-LABEL: func @alloc
|
||||||
func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
|
func @alloc(%ctx: !tf_framework.op_kernel_context,
|
||||||
%size_0 : index , %size_2 : index) {
|
%size_0 : index , %size_2 : index) {
|
||||||
%buf_0 = tf_framework.alloc_raw(%ctx) : memref<10xi8>
|
%buf_0 = tf_framework.alloc(%ctx) : memref<10xi8>
|
||||||
%buf_1 = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
|
%buf_1 = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @dealloc_raw
|
// CHECK-LABEL: func @forwarding_alloc
|
||||||
func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, %memref : memref<?x10xf32>) {
|
func @forwarding_alloc(%ctx: !tf_framework.op_kernel_context,
|
||||||
tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32>
|
%size_0 : index , %size_2 : index) {
|
||||||
|
%buf = tf_framework.alloc(%ctx, %size_0, %size_2) {
|
||||||
|
input_indices = [0 : i32, 1 : i32],
|
||||||
|
output_index = 0 : i32
|
||||||
|
} : memref<?x10x?xi8>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dealloc
|
||||||
|
func @dealloc(%ctx: !tf_framework.op_kernel_context,
|
||||||
|
%memref : memref<?x10xf32>) {
|
||||||
|
tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s
|
// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file --print-ir-after-all | FileCheck %s
|
||||||
|
|
||||||
// CHECK: llvm.func @_mlir_ciface_tf_alloc_raw
|
// CHECK: llvm.func @_mlir_ciface_tf_alloc
|
||||||
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
|
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32>) -> !llvm.ptr<i8>
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @alloc_raw(
|
// CHECK-LABEL: llvm.func @alloc(
|
||||||
// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>,
|
// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>,
|
||||||
// CHECK-SAME: [[SIZE_0:%.*]]: !llvm.i64,
|
// CHECK-SAME: [[SIZE_0:%.*]]: !llvm.i64,
|
||||||
// CHECK-SAME: [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] {
|
// CHECK-SAME: [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] {
|
||||||
func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
|
func @alloc(%ctx: !tf_framework.op_kernel_context,
|
||||||
%size_0 : index , %size_2 : index) -> memref<?x10x?xf32> {
|
%size_0 : index , %size_2 : index) -> memref<?x10x?xf32> {
|
||||||
%buf = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xf32>
|
%buf = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xf32>
|
||||||
std.return %buf : memref<?x10x?xf32>
|
std.return %buf : memref<?x10x?xf32>
|
||||||
}
|
}
|
||||||
// Compute number of elements.
|
// Compute number of elements.
|
||||||
@ -25,10 +25,19 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
|
|||||||
// CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]]
|
// CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]]
|
||||||
// CHECK-SAME: !llvm.ptr<float> to !llvm.i64
|
// CHECK-SAME: !llvm.ptr<float> to !llvm.i64
|
||||||
|
|
||||||
// Allocate memory.
|
// Compute total size in bytes.
|
||||||
// CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]]
|
// CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]]
|
||||||
// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]])
|
|
||||||
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
|
// Compute output index (-1) and candidate indices (0, NULL).
|
||||||
|
// CHECK: [[OUTPUT_INDEX:%.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
|
||||||
|
// CHECK-NEXT: [[NUM_CANDIDATES:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
|
// CHECK-NEXT: [[CANDIDATES_PTR:%.*]] = llvm.mlir.null : !llvm.ptr<i32>
|
||||||
|
|
||||||
|
// Allocate memory.
|
||||||
|
// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]],
|
||||||
|
// CHECK-SAME: [[OUTPUT_INDEX]], [[NUM_CANDIDATES]], [[CANDIDATES_PTR]])
|
||||||
|
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32>
|
||||||
|
// CHECK-SAME: ) -> !llvm.ptr<i8>
|
||||||
|
|
||||||
// Build memref descriptor.
|
// Build memref descriptor.
|
||||||
// CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]]
|
// CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]]
|
||||||
@ -55,13 +64,13 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr<i8>, !llvm.ptr<i8>)
|
// CHECK: llvm.func @_mlir_ciface_tf_dealloc(!llvm.ptr<i8>, !llvm.ptr<i8>)
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @dealloc_raw(
|
// CHECK-LABEL: llvm.func @dealloc(
|
||||||
// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>,
|
// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>,
|
||||||
func @dealloc_raw(%ctx: !tf_framework.op_kernel_context,
|
func @dealloc(%ctx: !tf_framework.op_kernel_context,
|
||||||
%memref : memref<?x10xf32>) {
|
%memref : memref<?x10xf32>) {
|
||||||
tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32>
|
tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Extract allocated ptr from the memref descriptor.
|
// Extract allocated ptr from the memref descriptor.
|
||||||
@ -71,5 +80,5 @@ func @dealloc_raw(%ctx: !tf_framework.op_kernel_context,
|
|||||||
// CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<i8>
|
// CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<i8>
|
||||||
|
|
||||||
// Deallocate.
|
// Deallocate.
|
||||||
// CHECK: llvm.call @_mlir_ciface_tf_dealloc_raw(
|
// CHECK: llvm.call @_mlir_ciface_tf_dealloc(
|
||||||
// CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
|
// CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
|
||||||
|
@ -24,23 +24,49 @@ namespace tf_framework {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using tensorflow::Allocator;
|
using tensorflow::Allocator;
|
||||||
|
using tensorflow::AllocatorAttributes;
|
||||||
|
|
||||||
Allocator* GetAllocator(void* op_kernel_ctx) {
|
Allocator* GetAllocator(void* op_kernel_ctx) {
|
||||||
auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
|
auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
|
||||||
// TODO(pifon): Figure out how to set AllocatorAttributes correctly.
|
// TODO(pifon): Figure out how to set AllocatorAttributes correctly.
|
||||||
tensorflow::AllocatorAttributes attrs;
|
AllocatorAttributes attrs;
|
||||||
return ctx->get_allocator(attrs);
|
return ctx->get_allocator(attrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
extern "C" void* _mlir_ciface_tf_alloc_raw(void* op_kernel_ctx,
|
extern "C" void* _mlir_ciface_tf_alloc(void* op_kernel_ctx, size_t num_bytes,
|
||||||
size_t num_bytes) {
|
int32_t output_index,
|
||||||
|
int32_t num_candidates,
|
||||||
|
int32_t* candidate_input_indices) {
|
||||||
|
auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
|
||||||
|
|
||||||
|
if (output_index != -1) {
|
||||||
|
auto element_size = ctx->expected_output_dtype(output_index);
|
||||||
|
// Create a 1D shape, because the shapes don't have to match exactly for
|
||||||
|
// input forwarding. Only the number of elements must be the same.
|
||||||
|
tensorflow::TensorShape output_shape;
|
||||||
|
output_shape.AddDim(num_bytes / element_size);
|
||||||
|
|
||||||
|
// Iterate over indices of all inputs that can potentially be used for
|
||||||
|
// forwarding.
|
||||||
|
for (int i = 0; i < num_candidates; ++i) {
|
||||||
|
// TODO(pifon): Expose fetching AllocatorAttributes with the output_index.
|
||||||
|
AllocatorAttributes output_attr;
|
||||||
|
auto tensor = ctx->forward_input(
|
||||||
|
candidate_input_indices[i], output_index, element_size, output_shape,
|
||||||
|
ctx->output_memory_type(output_index), output_attr);
|
||||||
|
if (tensor != nullptr) {
|
||||||
|
return tensor->data();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If no forwarding happened, allocate a chunk of memory.
|
||||||
return GetAllocator(op_kernel_ctx)
|
return GetAllocator(op_kernel_ctx)
|
||||||
->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes);
|
->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void _mlir_ciface_tf_dealloc_raw(void* op_kernel_ctx, void* ptr) {
|
extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) {
|
||||||
GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr);
|
GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,10 +22,11 @@ namespace mlir {
|
|||||||
namespace kernel_gen {
|
namespace kernel_gen {
|
||||||
namespace tf_framework {
|
namespace tf_framework {
|
||||||
|
|
||||||
extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc_raw(
|
extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc(
|
||||||
void* op_kernel_ctx, size_t num_bytes);
|
void* op_kernel_ctx, size_t num_bytes, int32_t output_index,
|
||||||
|
int32_t num_candidates, int32_t* candidate_input_indices);
|
||||||
|
|
||||||
extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc_raw(
|
extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc(
|
||||||
void* op_kernel_ctx, void* ptr);
|
void* op_kernel_ctx, void* ptr);
|
||||||
|
|
||||||
} // namespace tf_framework
|
} // namespace tf_framework
|
||||||
|
@ -75,11 +75,11 @@ class AllocOpConverter : public OpConversionPattern<AllocOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
// Symbolic operands that bind to the symbols of the memref's layout map are
|
// Symbolic operands that bind to the symbols of the memref's layout map are
|
||||||
// not supported by AllocRawOp.
|
// not supported by TFAllocOp.
|
||||||
if (alloc.getNumSymbolicOperands() != 0) {
|
if (alloc.getNumSymbolicOperands() != 0) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
rewriter.replaceOpWithNewOp<AllocRawOp>(alloc, alloc.getType(), ctx,
|
rewriter.replaceOpWithNewOp<TFAllocOp>(alloc, alloc.getType(), ctx,
|
||||||
operands);
|
operands);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -87,7 +87,7 @@ class AllocOpConverter : public OpConversionPattern<AllocOp> {
|
|||||||
|
|
||||||
// Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType
|
// Converts std.dealloc to tf_framework.dealloc using OpKernelContextType
|
||||||
// arg of the parent function.
|
// arg of the parent function.
|
||||||
class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
|
class TFDeallocOpConverter : public OpConversionPattern<DeallocOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<DeallocOp>::OpConversionPattern;
|
using OpConversionPattern<DeallocOp>::OpConversionPattern;
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
DeallocOp::Adaptor transformed(operands);
|
DeallocOp::Adaptor transformed(operands);
|
||||||
rewriter.replaceOpWithNewOp<DeallocRawOp>(dealloc, ctx,
|
rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, ctx,
|
||||||
transformed.memref());
|
transformed.memref());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -118,7 +118,7 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
|
|||||||
|
|
||||||
void PopulateEmbedTFFrameworkConversionPatterns(
|
void PopulateEmbedTFFrameworkConversionPatterns(
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||||
patterns->insert<AllocOpConverter, DeallocOpConverter, FuncOpConverter>(
|
patterns->insert<AllocOpConverter, TFDeallocOpConverter, FuncOpConverter>(
|
||||||
context);
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||||
@ -30,8 +31,8 @@ namespace {
|
|||||||
using LLVM::LLVMFuncOp;
|
using LLVM::LLVMFuncOp;
|
||||||
using LLVM::LLVMType;
|
using LLVM::LLVMType;
|
||||||
|
|
||||||
static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc_raw";
|
static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
|
||||||
static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc_raw";
|
static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
|
||||||
|
|
||||||
/// Base class for patterns converting TF Framework ops to function calls.
|
/// Base class for patterns converting TF Framework ops to function calls.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
@ -60,18 +61,18 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
|
|||||||
virtual LLVMType GetFuncType() const = 0;
|
virtual LLVMType GetFuncType() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
|
class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
|
||||||
public:
|
public:
|
||||||
using ConvertToLLVMCallOpPattern<AllocRawOp>::ConvertToLLVMCallOpPattern;
|
using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
Operation *op, ArrayRef<Value> operands,
|
Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
AllocRawOp alloc_raw_op = cast<AllocRawOp>(op);
|
TFAllocOp tf_alloc_op = cast<TFAllocOp>(op);
|
||||||
AllocRawOp::Adaptor transformed(operands);
|
TFAllocOp::Adaptor transformed(operands);
|
||||||
|
|
||||||
MemRefType memref_type = alloc_raw_op.getType();
|
MemRefType memref_type = tf_alloc_op.getType();
|
||||||
|
|
||||||
// Get memref descriptor sizes.
|
// Get memref descriptor sizes.
|
||||||
SmallVector<Value, 4> sizes;
|
SmallVector<Value, 4> sizes;
|
||||||
@ -82,13 +83,28 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
|
|||||||
Value num_bytes = getCumulativeSizeInBytes(
|
Value num_bytes = getCumulativeSizeInBytes(
|
||||||
loc, memref_type.getElementType(), sizes, rewriter);
|
loc, memref_type.getElementType(), sizes, rewriter);
|
||||||
|
|
||||||
|
// Convert `output_index` or set it to -1 if the attribute is missing.
|
||||||
|
LLVM::LLVMType llvmInt32Type =
|
||||||
|
LLVM::LLVMType::getInt32Ty(rewriter.getContext());
|
||||||
|
Value output_index = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, llvmInt32Type,
|
||||||
|
rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue()
|
||||||
|
? tf_alloc_op.output_index().getValue()
|
||||||
|
: -1));
|
||||||
|
|
||||||
|
// Convert `candidate_input_indices`.
|
||||||
|
auto candidates_count_and_ptr = ConvertI32ArrayAttrToStackAllocatedArray(
|
||||||
|
loc, tf_alloc_op.input_indices(), &rewriter);
|
||||||
|
|
||||||
// Insert function call.
|
// Insert function call.
|
||||||
FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
|
FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
|
||||||
Value allocated_byte_ptr =
|
Value allocated_byte_ptr =
|
||||||
rewriter
|
rewriter
|
||||||
.create<LLVM::CallOp>(
|
.create<LLVM::CallOp>(
|
||||||
loc, getVoidPtrType(), tf_func_ref,
|
loc, getVoidPtrType(), tf_func_ref,
|
||||||
llvm::makeArrayRef({transformed.ctx(), num_bytes}))
|
llvm::makeArrayRef({transformed.ctx(), num_bytes, output_index,
|
||||||
|
candidates_count_and_ptr.first,
|
||||||
|
candidates_count_and_ptr.second}))
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
|
|
||||||
MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
|
MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
|
||||||
@ -103,10 +119,18 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
|
|||||||
StringRef GetFuncName() const override { return kCInterfaceAlloc; }
|
StringRef GetFuncName() const override { return kCInterfaceAlloc; }
|
||||||
|
|
||||||
LLVMType GetFuncType() const override {
|
LLVMType GetFuncType() const override {
|
||||||
|
LLVMType llvm_i32_type =
|
||||||
|
LLVM::LLVMType::getInt32Ty(getDialect().getContext());
|
||||||
|
LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
|
||||||
LLVMType llvm_void_ptr_type = getVoidPtrType();
|
LLVMType llvm_void_ptr_type = getVoidPtrType();
|
||||||
return LLVM::LLVMType::getFunctionTy(
|
return LLVMType::getFunctionTy(
|
||||||
llvm_void_ptr_type,
|
llvm_void_ptr_type,
|
||||||
llvm::makeArrayRef({llvm_void_ptr_type, getIndexType()}),
|
llvm::makeArrayRef(
|
||||||
|
{/*void* op_kernel_ctx*/ llvm_void_ptr_type,
|
||||||
|
/*size_t num_bytes*/ getIndexType(),
|
||||||
|
/*int32_t output_index*/ llvm_i32_type,
|
||||||
|
/*int32_t num_candidates*/ llvm_i32_type,
|
||||||
|
/*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}),
|
||||||
/*isVarArg=*/false);
|
/*isVarArg=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,16 +168,53 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
|
|||||||
}
|
}
|
||||||
return memref_desc;
|
return memref_desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<Value, Value> ConvertI32ArrayAttrToStackAllocatedArray(
|
||||||
|
Location loc, llvm::Optional<ArrayAttr> attr,
|
||||||
|
ConversionPatternRewriter *rewriter) const {
|
||||||
|
LLVMType llvm_i32_type =
|
||||||
|
LLVM::LLVMType::getInt32Ty(getDialect().getContext());
|
||||||
|
LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
|
||||||
|
|
||||||
|
// If the attribute is missing or empty, set the element count to 0 and
|
||||||
|
// return NULL.
|
||||||
|
if (!attr.hasValue() || attr.getValue().empty()) {
|
||||||
|
Value zero = rewriter->create<LLVM::ConstantOp>(
|
||||||
|
loc, llvm_i32_type, rewriter->getI32IntegerAttr(0));
|
||||||
|
Value null_ptr = rewriter->create<LLVM::NullOp>(loc, llvm_i32_ptr_type);
|
||||||
|
return std::make_pair(zero, null_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate array to store the elements.
|
||||||
|
auto &array_attr = attr.getValue();
|
||||||
|
Value array_size = rewriter->create<LLVM::ConstantOp>(
|
||||||
|
loc, llvm_i32_type, rewriter->getI32IntegerAttr(array_attr.size()));
|
||||||
|
Value array_ptr = rewriter->create<LLVM::AllocaOp>(
|
||||||
|
loc, llvm_i32_ptr_type, array_size, /*alignment=*/0);
|
||||||
|
|
||||||
|
for (auto &dim : llvm::enumerate(array_attr)) {
|
||||||
|
Value index = rewriter->create<LLVM::ConstantOp>(
|
||||||
|
loc, llvm_i32_type, rewriter->getI32IntegerAttr(dim.index()));
|
||||||
|
Value elem_ptr = rewriter->create<LLVM::GEPOp>(loc, llvm_i32_ptr_type,
|
||||||
|
array_ptr, index);
|
||||||
|
Value elem = rewriter->create<LLVM::ConstantOp>(
|
||||||
|
loc, llvm_i32_type,
|
||||||
|
rewriter->getI32IntegerAttr(
|
||||||
|
dim.value().cast<IntegerAttr>().getInt()));
|
||||||
|
rewriter->create<LLVM::StoreOp>(loc, elem, elem_ptr);
|
||||||
|
}
|
||||||
|
return std::make_pair(array_size, array_ptr);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern<DeallocRawOp> {
|
class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
|
||||||
public:
|
public:
|
||||||
using ConvertToLLVMCallOpPattern<DeallocRawOp>::ConvertToLLVMCallOpPattern;
|
using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
Operation *op, ArrayRef<Value> operands,
|
Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
DeallocRawOp::Adaptor transformed(operands);
|
TFDeallocOp::Adaptor transformed(operands);
|
||||||
MemRefDescriptor memref(transformed.memref());
|
MemRefDescriptor memref(transformed.memref());
|
||||||
|
|
||||||
Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
|
Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
|
||||||
@ -194,7 +255,7 @@ class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
|
|||||||
void PopulateTFFrameworkToLLVMConversionPatterns(
|
void PopulateTFFrameworkToLLVMConversionPatterns(
|
||||||
LLVMTypeConverter *converter, OwningRewritePatternList *patterns) {
|
LLVMTypeConverter *converter, OwningRewritePatternList *patterns) {
|
||||||
patterns->insert<NullContextOpConverter>(*converter);
|
patterns->insert<NullContextOpConverter>(*converter);
|
||||||
patterns->insert<AllocRawOpConverter, DeallocRawOpConverter>(*converter);
|
patterns->insert<TFAllocOpConverter, TFDeallocOpConverter>(*converter);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tf_framework
|
} // namespace tf_framework
|
||||||
|
Loading…
Reference in New Issue
Block a user