[MLIR][KERNEL_GEN] Extend tf_framework.alloc to support input forwarding.

This PR changes the op definition and its lowering to LLVM, including the `tf_framework_c_interface` runtime functions.
The lowering is tested with an e2e test in `tf_framework_external_calls.mlir`.
The old `tf_framework_external_calls.mlir` was renamed to `tf_framework_embed_and_call.mlir` because it tests two passes at once.
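For illustration, a minimal sketch of the extended op in its forwarding form (placeholder SSA names, mirroring one of the updated op tests in this change):

  %buf = tf_framework.alloc(%ctx, %size_0, %size_2) {
    input_indices = [0 : i32, 1 : i32],
    output_index = 0 : i32
  } : memref<?x10x?xi8>

When these attributes are present, the lowering forwards them to `_mlir_ciface_tf_alloc`, which tries `OpKernelContext::forward_input` for each candidate input and only falls back to `Allocator::AllocateRaw` if forwarding fails.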

PiperOrigin-RevId: 338256141
Change-Id: I4fd53ca30caae6064f7d6e15c36ca178f9ac8dee
Alexander Belyaev (2020-10-21 07:06:15 -07:00), committed by TensorFlower Gardener
parent b39bc5f1f0
commit 1ef4206318
10 changed files with 179 additions and 64 deletions

View File

@ -61,10 +61,10 @@ LogicalResult Verify(OpTy op) {
}
//===----------------------------------------------------------------------===//
// AllocRawOp
// TFAllocOp
//===----------------------------------------------------------------------===//
template <>
LogicalResult Verify<AllocRawOp>(AllocRawOp op) {
LogicalResult Verify<TFAllocOp>(TFAllocOp op) {
// Check that the total number of operands matches the number of dynamic
// dimensions specified in the memref type.
unsigned result_dyn_dims = op.getType().getNumDynamicDims();

View File

@ -49,21 +49,28 @@ class TFFramework_Op<string mnemonic, list<OpTrait> traits = []> :
}
//===----------------------------------------------------------------------===//
// AllocRawOp
// TFAllocOp
//===----------------------------------------------------------------------===//
def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw",
def TFFramework_TFAllocOp : TFFramework_Op<"alloc",
[MemoryEffects<[MemAlloc<DefaultResource>]>]> {
let summary = "allocation of tensors that uses TF Framework";
let description = [{
Allocation of tensors during kernel execution in the Compute method.
This should be used to allocate any temporary or output memref.
Corresponds to `Allocator::AllocateRaw` in
tensorflow/core/framework/allocator.h.
This should be used to allocate any temporary or output memref. If
`output_index` and `input_indices` are given, attempts to forward one of
the input tensors to the output by calling `OpKernelContext::forward_input`.
If the attributes are missing or the forwarding fails, calls
`Allocator::AllocateRaw` in tensorflow/core/framework/allocator.h.
}];
let arguments = (ins TFFramework_OpKernelContextType:$ctx,
Variadic<Index>:$dyn_sizes);
let arguments = (ins
TFFramework_OpKernelContextType:$ctx,
Variadic<Index>:$dyn_sizes,
OptionalAttr<I32ArrayAttr>:$input_indices,
OptionalAttr<I32Attr>:$output_index
);
let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$result);
let builders = [
@ -92,16 +99,16 @@ def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw",
}
//===----------------------------------------------------------------------===//
// DeallocRawOp
// TFDeallocOp
//===----------------------------------------------------------------------===//
def TFFramework_DeallocRawOp : TFFramework_Op<"dealloc_raw",
def TFFramework_TFDeallocOp : TFFramework_Op<"dealloc",
[MemoryEffects<[MemFree]>]> {
let summary = "deallocation of tensors that uses TF Framework";
let description = [{
Deallocation of tensors during kernel execution in the Compute method.
This should be used to deallocate any temporary memref that was allocated
with `tf_framework.alloc_raw`.
with `tf_framework.alloc`.
Corresponds to `Allocator::DeallocateRaw` in
tensorflow/core/framework/allocator.h.
}];
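A minimal sketch of the alloc/dealloc pairing described above (placeholder SSA names; the same pattern appears in the updated op tests later in this diff):

  %tmp = tf_framework.alloc(%ctx, %size) : memref<?xf32>
  // ... use %tmp as a temporary buffer ...
  tf_framework.dealloc(%ctx, %tmp) : memref<?xf32>

Without `input_indices`/`output_index`, the lowered `_mlir_ciface_tf_alloc` call receives output_index = -1 and no candidates, so it falls back to `Allocator::AllocateRaw`; `tf_framework.dealloc` maps to `Allocator::DeallocateRaw`.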

View File

@ -10,9 +10,9 @@ func @tf_entry(%size_0 : index , %size_2 : index) -> index
dealloc %buf : memref<?x10x?xf32>
std.return %size_0 : index
}
// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc_raw
// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc
// CHECK-SAME: ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref<?x10x?xf32>
// CHECK-NEXT: tf_framework.dealloc_raw([[CTX]], [[VAL_3]]) : memref<?x10x?xf32>
// CHECK-NEXT: tf_framework.dealloc([[CTX]], [[VAL_3]]) : memref<?x10x?xf32>
// CHECK-NEXT: return [[SIZE_0]] : index
// -----

View File

@ -2,6 +2,6 @@
func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) {
// expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}}
%buf = tf_framework.alloc_raw(%ctx, %size) : memref<?x10x?xi8>
%buf = tf_framework.alloc(%ctx, %size) : memref<?x10x?xi8>
return
}

View File

@ -4,17 +4,28 @@
// Verify the generic form can be parsed.
// RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s
// CHECK-LABEL: func @alloc_raw
func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
// CHECK-LABEL: func @alloc
func @alloc(%ctx: !tf_framework.op_kernel_context,
%size_0 : index , %size_2 : index) {
%buf_0 = tf_framework.alloc_raw(%ctx) : memref<10xi8>
%buf_1 = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
%buf_0 = tf_framework.alloc(%ctx) : memref<10xi8>
%buf_1 = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
return
}
// CHECK-LABEL: func @dealloc_raw
func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, %memref : memref<?x10xf32>) {
tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32>
// CHECK-LABEL: func @forwarding_alloc
func @forwarding_alloc(%ctx: !tf_framework.op_kernel_context,
%size_0 : index , %size_2 : index) {
%buf = tf_framework.alloc(%ctx, %size_0, %size_2) {
input_indices = [0 : i32, 1 : i32],
output_index = 0 : i32
} : memref<?x10x?xi8>
return
}
// CHECK-LABEL: func @dealloc
func @dealloc(%ctx: !tf_framework.op_kernel_context,
%memref : memref<?x10xf32>) {
tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32>
return
}

View File

@ -1,15 +1,15 @@
// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s
// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file --print-ir-after-all | FileCheck %s
// CHECK: llvm.func @_mlir_ciface_tf_alloc_raw
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
// CHECK: llvm.func @_mlir_ciface_tf_alloc
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32>) -> !llvm.ptr<i8>
// CHECK-LABEL: llvm.func @alloc_raw(
// CHECK-LABEL: llvm.func @alloc(
// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>,
// CHECK-SAME: [[SIZE_0:%.*]]: !llvm.i64,
// CHECK-SAME: [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] {
func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
func @alloc(%ctx: !tf_framework.op_kernel_context,
%size_0 : index , %size_2 : index) -> memref<?x10x?xf32> {
%buf = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xf32>
%buf = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xf32>
std.return %buf : memref<?x10x?xf32>
}
// Compute number of elements.
@ -25,10 +25,19 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
// CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]]
// CHECK-SAME: !llvm.ptr<float> to !llvm.i64
// Allocate memory.
// Compute total size in bytes.
// CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]]
// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]])
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
// Compute output index (-1) and candidate indices (0, NULL).
// CHECK: [[OUTPUT_INDEX:%.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
// CHECK-NEXT: [[NUM_CANDIDATES:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK-NEXT: [[CANDIDATES_PTR:%.*]] = llvm.mlir.null : !llvm.ptr<i32>
// Allocate memory.
// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]],
// CHECK-SAME: [[OUTPUT_INDEX]], [[NUM_CANDIDATES]], [[CANDIDATES_PTR]])
// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32>
// CHECK-SAME: ) -> !llvm.ptr<i8>
// Build memref descriptor.
// CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]]
@ -55,13 +64,13 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
// -----
// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr<i8>, !llvm.ptr<i8>)
// CHECK: llvm.func @_mlir_ciface_tf_dealloc(!llvm.ptr<i8>, !llvm.ptr<i8>)
// CHECK-LABEL: llvm.func @dealloc_raw(
// CHECK-LABEL: llvm.func @dealloc(
// CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr<i8>,
func @dealloc_raw(%ctx: !tf_framework.op_kernel_context,
func @dealloc(%ctx: !tf_framework.op_kernel_context,
%memref : memref<?x10xf32>) {
tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32>
tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32>
return
}
// Extract allocated ptr from the memref descriptor.
@ -71,5 +80,5 @@ func @dealloc_raw(%ctx: !tf_framework.op_kernel_context,
// CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<i8>
// Deallocate.
// CHECK: llvm.call @_mlir_ciface_tf_dealloc_raw(
// CHECK: llvm.call @_mlir_ciface_tf_dealloc(
// CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()

View File

@ -24,23 +24,49 @@ namespace tf_framework {
namespace {
using tensorflow::Allocator;
using tensorflow::AllocatorAttributes;
Allocator* GetAllocator(void* op_kernel_ctx) {
auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
// TODO(pifon): Figure out how to set AllocatorAttributes correctly.
tensorflow::AllocatorAttributes attrs;
AllocatorAttributes attrs;
return ctx->get_allocator(attrs);
}
} // namespace
extern "C" void* _mlir_ciface_tf_alloc_raw(void* op_kernel_ctx,
size_t num_bytes) {
extern "C" void* _mlir_ciface_tf_alloc(void* op_kernel_ctx, size_t num_bytes,
int32_t output_index,
int32_t num_candidates,
int32_t* candidate_input_indices) {
auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
if (output_index != -1) {
auto output_dtype = ctx->expected_output_dtype(output_index);
// Create a 1D shape, because the shapes don't have to match exactly for
// input forwarding. Only the number of elements must be the same.
tensorflow::TensorShape output_shape;
output_shape.AddDim(num_bytes / tensorflow::DataTypeSize(output_dtype));
// Iterate over indices of all inputs that can potentially be used for
// forwarding.
for (int i = 0; i < num_candidates; ++i) {
// TODO(pifon): Expose fetching AllocatorAttributes with the output_index.
AllocatorAttributes output_attr;
auto tensor = ctx->forward_input(
candidate_input_indices[i], output_index, output_dtype, output_shape,
ctx->output_memory_type(output_index), output_attr);
if (tensor != nullptr) {
return tensor->data();
}
}
}
// If no forwarding happened, allocate a chunk of memory.
return GetAllocator(op_kernel_ctx)
->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes);
}
extern "C" void _mlir_ciface_tf_dealloc_raw(void* op_kernel_ctx, void* ptr) {
extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) {
GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr);
}

View File

@ -22,10 +22,11 @@ namespace mlir {
namespace kernel_gen {
namespace tf_framework {
extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc_raw(
void* op_kernel_ctx, size_t num_bytes);
extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc(
void* op_kernel_ctx, size_t num_bytes, int32_t output_index,
int32_t num_candidates, int32_t* candidate_input_indices);
extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc_raw(
extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc(
void* op_kernel_ctx, void* ptr);
} // namespace tf_framework

View File

@ -75,11 +75,11 @@ class AllocOpConverter : public OpConversionPattern<AllocOp> {
return failure();
}
// Symbolic operands that bind to the symbols of the memref's layout map are
// not supported by AllocRawOp.
// not supported by TFAllocOp.
if (alloc.getNumSymbolicOperands() != 0) {
return failure();
}
rewriter.replaceOpWithNewOp<AllocRawOp>(alloc, alloc.getType(), ctx,
rewriter.replaceOpWithNewOp<TFAllocOp>(alloc, alloc.getType(), ctx,
operands);
return success();
}
@ -87,7 +87,7 @@ class AllocOpConverter : public OpConversionPattern<AllocOp> {
// Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType
// arg of the parent function.
class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
class TFDeallocOpConverter : public OpConversionPattern<DeallocOp> {
public:
using OpConversionPattern<DeallocOp>::OpConversionPattern;
@ -108,7 +108,7 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
return failure();
}
DeallocOp::Adaptor transformed(operands);
rewriter.replaceOpWithNewOp<DeallocRawOp>(dealloc, ctx,
rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, ctx,
transformed.memref());
return success();
}
@ -118,7 +118,7 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
void PopulateEmbedTFFrameworkConversionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<AllocOpConverter, DeallocOpConverter, FuncOpConverter>(
patterns->insert<AllocOpConverter, TFDeallocOpConverter, FuncOpConverter>(
context);
}

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
@ -30,8 +31,8 @@ namespace {
using LLVM::LLVMFuncOp;
using LLVM::LLVMType;
static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc_raw";
static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc_raw";
static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
/// Base class for patterns converting TF Framework ops to function calls.
template <typename OpTy>
@ -60,18 +61,18 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
virtual LLVMType GetFuncType() const = 0;
};
class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
public:
using ConvertToLLVMCallOpPattern<AllocRawOp>::ConvertToLLVMCallOpPattern;
using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
AllocRawOp alloc_raw_op = cast<AllocRawOp>(op);
AllocRawOp::Adaptor transformed(operands);
TFAllocOp tf_alloc_op = cast<TFAllocOp>(op);
TFAllocOp::Adaptor transformed(operands);
MemRefType memref_type = alloc_raw_op.getType();
MemRefType memref_type = tf_alloc_op.getType();
// Get memref descriptor sizes.
SmallVector<Value, 4> sizes;
@ -82,13 +83,28 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
Value num_bytes = getCumulativeSizeInBytes(
loc, memref_type.getElementType(), sizes, rewriter);
// Convert `output_index` or set it to -1 if the attribute is missing.
LLVM::LLVMType llvmInt32Type =
LLVM::LLVMType::getInt32Ty(rewriter.getContext());
Value output_index = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type,
rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue()
? tf_alloc_op.output_index().getValue()
: -1));
// Convert `candidate_input_indices`.
auto candidates_count_and_ptr = ConvertI32ArrayAttrToStackAllocatedArray(
loc, tf_alloc_op.input_indices(), &rewriter);
// Insert function call.
FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
Value allocated_byte_ptr =
rewriter
.create<LLVM::CallOp>(
loc, getVoidPtrType(), tf_func_ref,
llvm::makeArrayRef({transformed.ctx(), num_bytes}))
llvm::makeArrayRef({transformed.ctx(), num_bytes, output_index,
candidates_count_and_ptr.first,
candidates_count_and_ptr.second}))
.getResult(0);
MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
@ -103,10 +119,18 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
StringRef GetFuncName() const override { return kCInterfaceAlloc; }
LLVMType GetFuncType() const override {
LLVMType llvm_i32_type =
LLVM::LLVMType::getInt32Ty(getDialect().getContext());
LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
LLVMType llvm_void_ptr_type = getVoidPtrType();
return LLVM::LLVMType::getFunctionTy(
return LLVMType::getFunctionTy(
llvm_void_ptr_type,
llvm::makeArrayRef({llvm_void_ptr_type, getIndexType()}),
llvm::makeArrayRef(
{/*void* op_kernel_ctx*/ llvm_void_ptr_type,
/*size_t num_bytes*/ getIndexType(),
/*int32_t output_index*/ llvm_i32_type,
/*int32_t num_candidates*/ llvm_i32_type,
/*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}),
/*isVarArg=*/false);
}
@ -144,16 +168,53 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
}
return memref_desc;
}
std::pair<Value, Value> ConvertI32ArrayAttrToStackAllocatedArray(
Location loc, llvm::Optional<ArrayAttr> attr,
ConversionPatternRewriter *rewriter) const {
LLVMType llvm_i32_type =
LLVM::LLVMType::getInt32Ty(getDialect().getContext());
LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
// If the attribute is missing or empty, set the element count to 0 and
// return NULL.
if (!attr.hasValue() || attr.getValue().empty()) {
Value zero = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type, rewriter->getI32IntegerAttr(0));
Value null_ptr = rewriter->create<LLVM::NullOp>(loc, llvm_i32_ptr_type);
return std::make_pair(zero, null_ptr);
}
// Allocate array to store the elements.
auto &array_attr = attr.getValue();
Value array_size = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type, rewriter->getI32IntegerAttr(array_attr.size()));
Value array_ptr = rewriter->create<LLVM::AllocaOp>(
loc, llvm_i32_ptr_type, array_size, /*alignment=*/0);
for (auto &dim : llvm::enumerate(array_attr)) {
Value index = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type, rewriter->getI32IntegerAttr(dim.index()));
Value elem_ptr = rewriter->create<LLVM::GEPOp>(loc, llvm_i32_ptr_type,
array_ptr, index);
Value elem = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type,
rewriter->getI32IntegerAttr(
dim.value().cast<IntegerAttr>().getInt()));
rewriter->create<LLVM::StoreOp>(loc, elem, elem_ptr);
}
return std::make_pair(array_size, array_ptr);
}
};
class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern<DeallocRawOp> {
class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
public:
using ConvertToLLVMCallOpPattern<DeallocRawOp>::ConvertToLLVMCallOpPattern;
using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
DeallocRawOp::Adaptor transformed(operands);
TFDeallocOp::Adaptor transformed(operands);
MemRefDescriptor memref(transformed.memref());
Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
@ -194,7 +255,7 @@ class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
void PopulateTFFrameworkToLLVMConversionPatterns(
LLVMTypeConverter *converter, OwningRewritePatternList *patterns) {
patterns->insert<NullContextOpConverter>(*converter);
patterns->insert<AllocRawOpConverter, DeallocRawOpConverter>(*converter);
patterns->insert<TFAllocOpConverter, TFDeallocOpConverter>(*converter);
}
} // namespace tf_framework