From 1ef42063180285108d0e4c45159ed56075255aff Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon@google.com>
Date: Wed, 21 Oct 2020 07:06:15 -0700
Subject: [PATCH] [MLIR][KERNEL_GEN] Extend `tf_framework.alloc` to support
 input forwarding.

This PR changes the op definition and the lowering to LLVM, including `tf_framework_c_interface`.
The lowering is covered by an e2e test in `tf_framework_external_calls.mlir`.
The old `tf_framework_external_calls.mlir` was renamed to `tf_framework_embed_and_call.mlir`, because it tests two passes at once.

PiperOrigin-RevId: 338256141
Change-Id: I4fd53ca30caae6064f7d6e15c36ca178f9ac8dee
---
 .../tools/kernel_gen/ir/tf_framework_ops.cc   |  4 +-
 .../tools/kernel_gen/ir/tf_framework_ops.td   | 27 +++---
 .../kernel_gen/tests/embed_tf_framework.mlir  |  4 +-
 .../mlir/tools/kernel_gen/tests/invalid.mlir  |  2 +-
 .../mlir/tools/kernel_gen/tests/ops.mlir      | 25 ++++--
 .../tests/tf_framework_legalize_to_llvm.mlir  | 37 +++++---
 .../kernel_gen/tf_framework_c_interface.cc    | 34 ++++++-
 .../kernel_gen/tf_framework_c_interface.h     |  7 +-
 .../transforms/embed_tf_framework.cc          | 14 +--
 .../tf_framework_legalize_to_llvm.cc          | 89 ++++++++++++++++---
 10 files changed, 179 insertions(+), 64 deletions(-)

diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc
index b3d92773be4..676e1849318 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc
@@ -61,10 +61,10 @@ LogicalResult Verify(OpTy op) {
 }
 
 //===----------------------------------------------------------------------===//
-// AllocRawOp
+// TFAllocOp
 //===----------------------------------------------------------------------===//
 template <>
-LogicalResult Verify<AllocRawOp>(AllocRawOp op) {
+LogicalResult Verify<TFAllocOp>(TFAllocOp op) {
   // Check that the total number of operands matches the number of dynamic
   // dimensions specified in the memref type.
   unsigned result_dyn_dims = op.getType().getNumDynamicDims();
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td
index e6e29bcbdc2..2f3e0f6f5fa 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td
@@ -49,21 +49,28 @@ class TFFramework_Op<string mnemonic, list<OpTrait> traits = []> :
 }
 
 //===----------------------------------------------------------------------===//
-// AllocRawOp
+// TFAllocOp
 //===----------------------------------------------------------------------===//
-def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw",
+def TFFramework_TFAllocOp : TFFramework_Op<"alloc",
     [MemoryEffects<[MemAlloc<DefaultResource>]>]> {
   let summary = "allocation of tensors that uses TF Framework";
   let description = [{
     Allocation of tensors during kernel execution in the Compute method.
 
-    This should be used to allocate any temporary or output memref.
-    Corresponds to `Allocator::AllocateRaw` in
-    tensorflow/core/framework/allocator.h.
+    This should be used to allocate any temporary or output memref. If
+    `output_index` and `input_indices` are given, attempts to forward one of
+    the input tensors to the output by calling `OpKernelContext::forward_input`.
+
+    If the attributes are missing or the forwarding fails, calls
+    `Allocator::AllocateRaw` in tensorflow/core/framework/allocator.h.
   }];
 
-  let arguments = (ins TFFramework_OpKernelContextType:$ctx,
-                   Variadic<Index>:$dyn_sizes);
+  let arguments = (ins
+    TFFramework_OpKernelContextType:$ctx,
+    Variadic<Index>:$dyn_sizes,
+    OptionalAttr<I32ArrayAttr>:$input_indices,
+    OptionalAttr<I32Attr>:$output_index
+  );
   let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$result);
 
   let builders = [
@@ -92,16 +99,16 @@ def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw",
 }
 
 //===----------------------------------------------------------------------===//
-// DeallocRawOp
+// TFDeallocOp
 //===----------------------------------------------------------------------===//
-def TFFramework_DeallocRawOp : TFFramework_Op<"dealloc_raw",
+def TFFramework_TFDeallocOp : TFFramework_Op<"dealloc",
     [MemoryEffects<[MemFree]>]> {
   let summary = "deallocation of tensors that uses TF Framework";
   let description = [{
     Deallocation of tensors during kernel execution in the Compute method.
 
     This should be used to deallocate any temporary memref that was allocated
-    with `tf_framework.alloc_raw`.
+    with `tf_framework.alloc`.
     Corresponds to `Allocator::DeallocateRaw` in
     tensorflow/core/framework/allocator.h.
   }];
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir
index bb0f1926cda..5d0beb7c7fe 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/embed_tf_framework.mlir
@@ -10,9 +10,9 @@ func @tf_entry(%size_0 : index , %size_2 : index) -> index
   dealloc %buf : memref<?x10x?xf32>
   std.return %size_0 : index
 }
-// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc_raw
+// CHECK-NEXT: [[VAL_3:%.*]] = tf_framework.alloc
 // CHECK-SAME:   ([[CTX]], [[SIZE_0]], [[SIZE_2]]) : memref<?x10x?xf32>
-// CHECK-NEXT: tf_framework.dealloc_raw([[CTX]], [[VAL_3]]) : memref<?x10x?xf32>
+// CHECK-NEXT: tf_framework.dealloc([[CTX]], [[VAL_3]]) : memref<?x10x?xf32>
 // CHECK-NEXT: return [[SIZE_0]] : index
 
 // -----
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir
index 1d1b3319515..1d3d5e485fb 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/invalid.mlir
@@ -2,6 +2,6 @@
 
 func @alloc_raw(%ctx: !tf_framework.op_kernel_context, %size : index) {
   // expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}}
-  %buf = tf_framework.alloc_raw(%ctx, %size) : memref<?x10x?xi8>
+  %buf = tf_framework.alloc(%ctx, %size) : memref<?x10x?xi8>
   return
 }
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir
index fc8e7c97ec8..aa291c4c439 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/ops.mlir
@@ -4,17 +4,28 @@
 // Verify the generic form can be parsed.
 // RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt | FileCheck %s
 
-// CHECK-LABEL: func @alloc_raw
-func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
+// CHECK-LABEL: func @alloc
+func @alloc(%ctx: !tf_framework.op_kernel_context,
                    %size_0 : index , %size_2 : index) {
-  %buf_0 = tf_framework.alloc_raw(%ctx) : memref<10xi8>
-  %buf_1 = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
+  %buf_0 = tf_framework.alloc(%ctx) : memref<10xi8>
+  %buf_1 = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
   return
 }
 
-// CHECK-LABEL: func @dealloc_raw
-func @dealloc_raw(%ctx: !tf_framework.op_kernel_context, %memref : memref<?x10xf32>) {
-  tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32>
+// CHECK-LABEL: func @forwarding_alloc
+func @forwarding_alloc(%ctx: !tf_framework.op_kernel_context,
+                       %size_0 : index , %size_2 : index) {
+  %buf = tf_framework.alloc(%ctx, %size_0, %size_2) {
+    input_indices = [0 : i32, 1 : i32],
+    output_index = 0 : i32
+  } : memref<?x10x?xi8>
+  return
+}
+
+// CHECK-LABEL: func @dealloc
+func @dealloc(%ctx: !tf_framework.op_kernel_context,
+              %memref : memref<?x10xf32>) {
+  tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32>
   return
 }
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir
index b943321e95b..44f8297a99f 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir
@@ -1,15 +1,15 @@
-// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s
+// RUN: kernel-gen-opt %s -tf-kernel-to-llvm -split-input-file | FileCheck %s
 
-// CHECK: llvm.func @_mlir_ciface_tf_alloc_raw
-// CHECK-SAME:  (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
+// CHECK: llvm.func @_mlir_ciface_tf_alloc
+// CHECK-SAME:  (!llvm.ptr<i8>, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32>) -> !llvm.ptr<i8>
 
-// CHECK-LABEL: llvm.func @alloc_raw(
+// CHECK-LABEL: llvm.func @alloc(
 // CHECK-SAME:    [[TF_CTX:%.*]]: !llvm.ptr<i8>,
 // CHECK-SAME:    [[SIZE_0:%.*]]: !llvm.i64,
 // CHECK-SAME:    [[SIZE_2:%.*]]: !llvm.i64) -> [[DESC_TY:!.*]] {
-func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
+func @alloc(%ctx: !tf_framework.op_kernel_context,
                 %size_0 : index , %size_2 : index) -> memref<?x10x?xf32> {
-  %buf = tf_framework.alloc_raw(%ctx, %size_0, %size_2) : memref<?x10x?xf32>
+  %buf = tf_framework.alloc(%ctx, %size_0, %size_2) : memref<?x10x?xf32>
   std.return %buf : memref<?x10x?xf32>
 }
 // Compute number of elements.
@@ -25,10 +25,19 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
 // CHECK: [[SIZE_OF_FLOAT:%.*]] = llvm.ptrtoint [[GEP]]
 // CHECK-SAME:            !llvm.ptr<float> to !llvm.i64
 
-// Allocate memory.
+// Compute total size in bytes.
 // CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEM_1]], [[SIZE_OF_FLOAT]]
-// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]])
-// CHECK-SAME:                  (!llvm.ptr<i8>, !llvm.i64) -> !llvm.ptr<i8>
+
+// Compute output index (-1) and candidate indices (0, NULL).
+// CHECK: [[OUTPUT_INDEX:%.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
+// CHECK-NEXT: [[NUM_CANDIDATES:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK-NEXT: [[CANDIDATES_PTR:%.*]] = llvm.mlir.null : !llvm.ptr<i32>
+
+// Allocate memory.
+// CHECK: [[BYTES_PTR:%.*]] = llvm.call @{{.*}}([[TF_CTX]], [[NUM_BYTES]],
+// CHECK-SAME: [[OUTPUT_INDEX]], [[NUM_CANDIDATES]], [[CANDIDATES_PTR]])
+// CHECK-SAME: (!llvm.ptr<i8>, !llvm.i64, !llvm.i32, !llvm.i32, !llvm.ptr<i32>
+// CHECK-SAME: ) -> !llvm.ptr<i8>
 
 // Build memref descriptor.
 // CHECK: [[DESC_0:%.*]] = llvm.mlir.undef : [[DESC_TY]]
@@ -55,13 +64,13 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context,
 
 // -----
 
-// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr<i8>, !llvm.ptr<i8>)
+// CHECK: llvm.func @_mlir_ciface_tf_dealloc(!llvm.ptr<i8>, !llvm.ptr<i8>)
 
-// CHECK-LABEL: llvm.func @dealloc_raw(
+// CHECK-LABEL: llvm.func @dealloc(
 // CHECK-SAME:    [[TF_CTX:%.*]]: !llvm.ptr<i8>,
-func @dealloc_raw(%ctx: !tf_framework.op_kernel_context,
+func @dealloc(%ctx: !tf_framework.op_kernel_context,
                   %memref : memref<?x10xf32>) {
-  tf_framework.dealloc_raw(%ctx, %memref) : memref<?x10xf32>
+  tf_framework.dealloc(%ctx, %memref) : memref<?x10xf32>
   return
 }
 // Extract allocated ptr from the memref descriptor.
@@ -71,5 +80,5 @@ func @dealloc_raw(%ctx: !tf_framework.op_kernel_context,
 // CHECK-SAME:                   !llvm.ptr<float> to !llvm.ptr<i8>
 
 // Deallocate.
-// CHECK: llvm.call @_mlir_ciface_tf_dealloc_raw(
+// CHECK: llvm.call @_mlir_ciface_tf_dealloc(
 // CHECK-SAME: [[TF_CTX]], [[VOID_PTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
index e75db59d885..2b2625b4d59 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc
@@ -24,23 +24,49 @@ namespace tf_framework {
 namespace {
 
 using tensorflow::Allocator;
+using tensorflow::AllocatorAttributes;
 
 Allocator* GetAllocator(void* op_kernel_ctx) {
   auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
   // TODO(pifon): Figure out how to set AllocatorAttributes correctly.
-  tensorflow::AllocatorAttributes attrs;
+  AllocatorAttributes attrs;
   return ctx->get_allocator(attrs);
 }
 
 }  // namespace
 
-extern "C" void* _mlir_ciface_tf_alloc_raw(void* op_kernel_ctx,
-                                           size_t num_bytes) {
+extern "C" void* _mlir_ciface_tf_alloc(void* op_kernel_ctx, size_t num_bytes,
+                                       int32_t output_index,
+                                       int32_t num_candidates,
+                                       int32_t* candidate_input_indices) {
+  auto* ctx = static_cast<tensorflow::OpKernelContext*>(op_kernel_ctx);
+
+  if (output_index != -1) {
+    auto output_dtype = ctx->expected_output_dtype(output_index);
+    // Create a 1D shape, because the shapes don't have to match exactly for
+    // input forwarding. Only the number of elements must be the same.
+    tensorflow::TensorShape output_shape;
+    output_shape.AddDim(num_bytes / tensorflow::DataTypeSize(output_dtype));
+
+    // Iterate over indices of all inputs that can potentially be used for
+    // forwarding.
+    for (int i = 0; i < num_candidates; ++i) {
+      // TODO(pifon): Expose fetching AllocatorAttributes with the output_index.
+      AllocatorAttributes output_attr;
+      auto tensor = ctx->forward_input(
+          candidate_input_indices[i], output_index, output_dtype, output_shape,
+          ctx->output_memory_type(output_index), output_attr);
+      if (tensor != nullptr) {
+        return tensor->data();
+      }
+    }
+  }
+  // If no forwarding happened, allocate a chunk of memory.
   return GetAllocator(op_kernel_ctx)
       ->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes);
 }
 
-extern "C" void _mlir_ciface_tf_dealloc_raw(void* op_kernel_ctx, void* ptr) {
+extern "C" void _mlir_ciface_tf_dealloc(void* op_kernel_ctx, void* ptr) {
   GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr);
 }
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h
index 143ebc95932..bf45116f372 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h
@@ -22,10 +22,11 @@ namespace mlir {
 namespace kernel_gen {
 namespace tf_framework {
 
-extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc_raw(
-    void* op_kernel_ctx, size_t num_bytes);
+extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc(
+    void* op_kernel_ctx, size_t num_bytes, int32_t output_index,
+    int32_t num_candidates, int32_t* candidate_input_indices);
 
-extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc_raw(
+extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc(
     void* op_kernel_ctx, void* ptr);
 
 }  // namespace tf_framework
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc
index aa02aefa9d2..3b006c954cf 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc
@@ -75,19 +75,19 @@ class AllocOpConverter : public OpConversionPattern<AllocOp> {
       return failure();
     }
     // Symbolic operands that bind to the symbols of the memref's layout map are
-    // not supported by AllocRawOp.
+    // not supported by TFAllocOp.
     if (alloc.getNumSymbolicOperands() != 0) {
       return failure();
     }
-    rewriter.replaceOpWithNewOp<AllocRawOp>(alloc, alloc.getType(), ctx,
-                                            operands);
+    rewriter.replaceOpWithNewOp<TFAllocOp>(alloc, alloc.getType(), ctx,
+                                           operands);
     return success();
   }
 };
 
 // Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType
 // arg of the parent function.
-class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
+class TFDeallocOpConverter : public OpConversionPattern<DeallocOp> {
  public:
   using OpConversionPattern<DeallocOp>::OpConversionPattern;
 
@@ -108,8 +108,8 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
       return failure();
     }
     DeallocOp::Adaptor transformed(operands);
-    rewriter.replaceOpWithNewOp<DeallocRawOp>(dealloc, ctx,
-                                              transformed.memref());
+    rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, ctx,
+                                             transformed.memref());
     return success();
   }
 };
@@ -118,7 +118,7 @@ class DeallocOpConverter : public OpConversionPattern<DeallocOp> {
 
 void PopulateEmbedTFFrameworkConversionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<AllocOpConverter, DeallocOpConverter, FuncOpConverter>(
+  patterns->insert<AllocOpConverter, TFDeallocOpConverter, FuncOpConverter>(
       context);
 }
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc
index 431919c2de7..959f7ecf635 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"  // from @llvm-project
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
@@ -30,8 +31,8 @@ namespace {
 using LLVM::LLVMFuncOp;
 using LLVM::LLVMType;
 
-static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc_raw";
-static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc_raw";
+static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
+static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
 
 /// Base class for patterns converting TF Framework ops to function calls.
 template <typename OpTy>
@@ -60,18 +61,18 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
   virtual LLVMType GetFuncType() const = 0;
 };
 
-class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
+class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
  public:
-  using ConvertToLLVMCallOpPattern<AllocRawOp>::ConvertToLLVMCallOpPattern;
+  using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
 
   LogicalResult matchAndRewrite(
       Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    AllocRawOp alloc_raw_op = cast<AllocRawOp>(op);
-    AllocRawOp::Adaptor transformed(operands);
+    TFAllocOp tf_alloc_op = cast<TFAllocOp>(op);
+    TFAllocOp::Adaptor transformed(operands);
 
-    MemRefType memref_type = alloc_raw_op.getType();
+    MemRefType memref_type = tf_alloc_op.getType();
 
     // Get memref descriptor sizes.
     SmallVector<Value, 4> sizes;
@@ -82,13 +83,28 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
     Value num_bytes = getCumulativeSizeInBytes(
         loc, memref_type.getElementType(), sizes, rewriter);
 
+    // Convert `output_index` or set it to -1 if the attribute is missing.
+    LLVM::LLVMType llvmInt32Type =
+        LLVM::LLVMType::getInt32Ty(rewriter.getContext());
+    Value output_index = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmInt32Type,
+        rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue()
+                                       ? tf_alloc_op.output_index().getValue()
+                                       : -1));
+
+    // Convert `candidate_input_indices`.
+    auto candidates_count_and_ptr = ConvertI32ArrayAttrToStackAllocatedArray(
+        loc, tf_alloc_op.input_indices(), &rewriter);
+
     // Insert function call.
     FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
     Value allocated_byte_ptr =
         rewriter
             .create<LLVM::CallOp>(
                 loc, getVoidPtrType(), tf_func_ref,
-                llvm::makeArrayRef({transformed.ctx(), num_bytes}))
+                llvm::makeArrayRef({transformed.ctx(), num_bytes, output_index,
+                                    candidates_count_and_ptr.first,
+                                    candidates_count_and_ptr.second}))
             .getResult(0);
 
     MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
@@ -103,10 +119,18 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
   StringRef GetFuncName() const override { return kCInterfaceAlloc; }
 
   LLVMType GetFuncType() const override {
+    LLVMType llvm_i32_type =
+        LLVM::LLVMType::getInt32Ty(getDialect().getContext());
+    LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
     LLVMType llvm_void_ptr_type = getVoidPtrType();
-    return LLVM::LLVMType::getFunctionTy(
+    return LLVMType::getFunctionTy(
         llvm_void_ptr_type,
-        llvm::makeArrayRef({llvm_void_ptr_type, getIndexType()}),
+        llvm::makeArrayRef(
+            {/*void* op_kernel_ctx*/ llvm_void_ptr_type,
+             /*size_t num_bytes*/ getIndexType(),
+             /*int32_t output_index*/ llvm_i32_type,
+             /*int32_t num_candidates*/ llvm_i32_type,
+             /*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}),
         /*isVarArg=*/false);
   }
 
@@ -144,16 +168,53 @@ class AllocRawOpConverter : public ConvertToLLVMCallOpPattern<AllocRawOp> {
     }
     return memref_desc;
   }
+
+  std::pair<Value, Value> ConvertI32ArrayAttrToStackAllocatedArray(
+      Location loc, llvm::Optional<ArrayAttr> attr,
+      ConversionPatternRewriter *rewriter) const {
+    LLVMType llvm_i32_type =
+        LLVM::LLVMType::getInt32Ty(getDialect().getContext());
+    LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
+
+    // If the attribute is missing or empty, set the element count to 0 and
+    // return NULL.
+    if (!attr.hasValue() || attr.getValue().empty()) {
+      Value zero = rewriter->create<LLVM::ConstantOp>(
+          loc, llvm_i32_type, rewriter->getI32IntegerAttr(0));
+      Value null_ptr = rewriter->create<LLVM::NullOp>(loc, llvm_i32_ptr_type);
+      return std::make_pair(zero, null_ptr);
+    }
+
+    // Allocate array to store the elements.
+    auto &array_attr = attr.getValue();
+    Value array_size = rewriter->create<LLVM::ConstantOp>(
+        loc, llvm_i32_type, rewriter->getI32IntegerAttr(array_attr.size()));
+    Value array_ptr = rewriter->create<LLVM::AllocaOp>(
+        loc, llvm_i32_ptr_type, array_size, /*alignment=*/0);
+
+    for (auto &dim : llvm::enumerate(array_attr)) {
+      Value index = rewriter->create<LLVM::ConstantOp>(
+          loc, llvm_i32_type, rewriter->getI32IntegerAttr(dim.index()));
+      Value elem_ptr = rewriter->create<LLVM::GEPOp>(loc, llvm_i32_ptr_type,
+                                                     array_ptr, index);
+      Value elem = rewriter->create<LLVM::ConstantOp>(
+          loc, llvm_i32_type,
+          rewriter->getI32IntegerAttr(
+              dim.value().cast<IntegerAttr>().getInt()));
+      rewriter->create<LLVM::StoreOp>(loc, elem, elem_ptr);
+    }
+    return std::make_pair(array_size, array_ptr);
+  }
 };
 
-class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern<DeallocRawOp> {
+class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
  public:
-  using ConvertToLLVMCallOpPattern<DeallocRawOp>::ConvertToLLVMCallOpPattern;
+  using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
 
   LogicalResult matchAndRewrite(
       Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    DeallocRawOp::Adaptor transformed(operands);
+    TFDeallocOp::Adaptor transformed(operands);
     MemRefDescriptor memref(transformed.memref());
 
     Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
@@ -194,7 +255,7 @@ class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
 void PopulateTFFrameworkToLLVMConversionPatterns(
     LLVMTypeConverter *converter, OwningRewritePatternList *patterns) {
   patterns->insert<NullContextOpConverter>(*converter);
-  patterns->insert<AllocRawOpConverter, DeallocRawOpConverter>(*converter);
+  patterns->insert<TFAllocOpConverter, TFDeallocOpConverter>(*converter);
 }
 
 }  // namespace tf_framework