[MLIR][KERNEL_GEN] Add TFFramework MLIR dialect.

PiperOrigin-RevId: 323347540
Change-Id: I247e2b4eaa7f072dd03ba2462ec981dc9d31fd1a
Alexander Belyaev 2020-07-27 06:13:00 -07:00 committed by TensorFlower Gardener
parent 7c0ed7acaa
commit 8fb9691376
11 changed files with 453 additions and 1 deletion


@@ -74,7 +74,7 @@ tool_names = [
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir',
'xla-thunks-opt', 'kernel-gen-opt'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)


@@ -47,6 +47,7 @@ mlir_tf_tools_dirs = [
'tensorflow/compiler/mlir/tensorflow',
'tensorflow/compiler/mlir/tfjs',
'tensorflow/compiler/mlir/xla',
'tensorflow/compiler/mlir/tools/kernel_gen',
'tensorflow/compiler/aot',
'tensorflow/compiler/xla/service/mlir_gpu',
'tensorflow/compiler/xla/service/gpu/tests',


@@ -50,3 +50,15 @@ tf_cc_binary(
"@llvm-project//llvm:Support",
],
)
tf_cc_binary(
name = "kernel-gen-opt",
visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__pkg__"],
deps = [
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_dialect_registration",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:MlirOptMain",
],
)


@@ -0,0 +1,55 @@
load("//third_party/mlir:tblgen.bzl", "gentbl")
package(
default_visibility = [":friends"],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = ["//tensorflow/compiler/mlir/..."],
)
gentbl(
name = "tf_framework_ops_inc_gen",
tbl_outs = [
("-gen-op-decls", "tf_framework_ops.h.inc"),
("-gen-op-defs", "tf_framework_ops.cc.inc"),
("-gen-struct-attr-decls", "tf_framework_structs.h.inc"),
("-gen-struct-attr-defs", "tf_framework_structs.cc.inc"),
("-gen-dialect-decls", "tf_framework_dialect.h.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "tf_framework_ops.td",
td_srcs = [
"tf_framework_ops.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
],
)
cc_library(
name = "tf_framework_ops",
srcs = [
"tf_framework_ops.cc",
"tf_framework_ops.cc.inc",
"tf_framework_ops.h.inc",
],
hdrs = ["tf_framework_ops.h"],
deps = [
":tf_framework_ops_inc_gen",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects",
],
)
cc_library(
name = "tf_framework_dialect_registration",
srcs = ["dialect_registration.cc"],
deps = [
":tf_framework_ops",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
)


@@ -0,0 +1,21 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
// Static initialization for TF Framework dialect registration.
static mlir::DialectRegistration<
mlir::kernel_gen::tf_framework::TFFrameworkDialect>
tf_framework_ops;
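
Not part of the diff above, but a minimal sketch of how this static registration is typically consumed once the tf_framework_dialect_registration target is linked in (alwayslink = 1); the API names follow the MLIR revision this commit targets and may differ in later versions:

#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"

// Returns true if the statically registered TF Framework dialect is visible
// to a freshly constructed MLIRContext.
bool IsTFFrameworkDialectRegistered() {
  mlir::MLIRContext context;
  return context.getRegisteredDialect<
             mlir::kernel_gen::tf_framework::TFFrameworkDialect>() != nullptr;
}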


@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the tf_framework dialect.
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
namespace mlir {
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_structs.cc.inc"
namespace kernel_gen {
namespace tf_framework {
TFFrameworkDialect::TFFrameworkDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
>();
addTypes<OpKernelContextType>();
}
/// Parse a type registered to this dialect.
Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword)) return Type();
if (keyword == "op_kernel_context") {
return OpKernelContextType::get(getContext());
}
parser.emitError(parser.getNameLoc(), "unknown TF Framework type: ")
<< keyword;
return Type();
}
/// Print a type registered to this dialect.
void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case TFFrameworkTypes::OpKernelContextType:
os << "op_kernel_context";
return;
default:
llvm_unreachable("unexpected TF Framework type kind");
}
}
//===----------------------------------------------------------------------===//
// AllocLikeOp
//===----------------------------------------------------------------------===//
template <typename AllocLikeOp>
static LogicalResult Verify(AllocLikeOp op) {
static_assert(llvm::is_one_of<AllocLikeOp, AllocOutputOp, AllocTempOp>::value,
"applies to only alloc_output or alloc_temp");
// Check that the total number of operands matches the number of dynamic
// dimensions specified in the memref type.
unsigned result_dyn_dims = op.getType().getNumDynamicDims();
unsigned dyn_sizes_count = op.dyn_sizes().size();
if (dyn_sizes_count != result_dyn_dims)
return op.emitOpError()
<< "`dyn_sizes` count " << dyn_sizes_count
<< " does not match dynamic dimensions count in the result type"
<< op.getType();
return success();
}
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
} // namespace tf_framework
} // namespace kernel_gen
} // namespace mlir
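
A minimal sketch of the textual spelling the parseType/printType hooks above produce, assuming the dialect registration library is linked into the binary:

#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"

void DumpOpKernelContextType() {
  mlir::MLIRContext context;
  auto type =
      mlir::kernel_gen::tf_framework::OpKernelContextType::get(&context);
  // printType above renders this as `!tf_framework.op_kernel_context`, and
  // parseType accepts the same spelling back.
  type.dump();
}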


@@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the TFFramework dialect.
//
#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
namespace mlir {
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_structs.h.inc"
namespace kernel_gen {
namespace tf_framework {
namespace TFFrameworkTypes {
enum Kind {
// TODO(pifon): Replace enum value with
// OpKernelContextType = Type::FIRST_TF_FRAMEWORK_TYPE,
// after DialectSymbolRegistry.def is updated.
OpKernelContextType = Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
};
} // namespace TFFrameworkTypes
/// OpKernelContextType corresponds to C++ class OpKernelContext defined in
/// tensorflow/core/framework/op_kernel.h
class OpKernelContextType
: public Type::TypeBase<OpKernelContextType, Type, TypeStorage> {
public:
using Base::Base;
static OpKernelContextType get(MLIRContext *context) {
return Base::get(context, TFFrameworkTypes::Kind::OpKernelContextType);
}
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) {
return kind == TFFrameworkTypes::Kind::OpKernelContextType;
}
};
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.h.inc"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h.inc"
} // namespace tf_framework
} // namespace kernel_gen
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_
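
The kindof hook declared above is what backs LLVM-style casting for the new type; a minimal sketch (the helper name is illustrative):

#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"

// True iff `type` is the TF Framework op_kernel_context type; isa<> consults
// OpKernelContextType::kindof through the TypeBase machinery.
bool IsOpKernelContext(mlir::Type type) {
  return type.isa<mlir::kernel_gen::tf_framework::OpKernelContextType>();
}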


@@ -0,0 +1,151 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for TF Framework ops.
#ifndef TF_FRAMEWORK_OPS
#define TF_FRAMEWORK_OPS
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def TFFramework_Dialect : Dialect {
let name = "tf_framework";
let summary = "Types and operations for tf_framework dialect";
let description = [{
This dialect contains operations and types that correspond to
the TensorFlow C++ framework.
}];
let cppNamespace = "kernel_gen::tf_framework";
}
def TFFramework_OpKernelContextType : DialectType<TFFramework_Dialect,
CPred<"$_self.isa<::mlir::kernel_gen::tf_framework::OpKernelContextType>()">,
"op_kernel_construction">,
BuildableType<"$_builder.getType<::mlir::kernel_gen::tf_framework::OpKernelContextType>()"> {
let typeDescription = [{
OpKernelContextType corresponds to C++ class OpKernelContext defined in
tensorflow/core/framework/op_kernel.h
}];
}
def AllocatorAttributes : StructAttr<"AllocatorAttributes",
TFFramework_Dialect, [
StructFieldAttr<"on_host", BoolAttr>,
StructFieldAttr<"nic_compatible", BoolAttr>,
StructFieldAttr<"gpu_compatible", BoolAttr>]> {
let description = "Equivalent to `tensorflow::AllocatorAttributes` in C++";
}
def AllocationAttributes : StructAttr<"AllocationAttributes",
TFFramework_Dialect, [
StructFieldAttr<"no_retry_on_failure",
DefaultValuedAttr<BoolAttr, "false">>,
StructFieldAttr<"allocation_will_be_logged",
DefaultValuedAttr<BoolAttr, "false">>]> {
let description = "Equivalent to `tensorflow::AllocationAttributes` in C++";
}
// Base class for TF Framework dialect ops.
class TFFramework_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TFFramework_Dialect, mnemonic, traits> {
let verifier = "return Verify(*this);";
}
// Base class for TF Framework alloc ops.
class TFFramework_AllocLikeOp<string mnemonic,
Resource resource,
list<OpTrait> traits = []> :
TFFramework_Op<mnemonic,
!listconcat([MemoryEffects<[MemAlloc<resource>]>], traits)> {
let arguments = (ins TFFramework_OpKernelContextType:$op_kernel_ctx,
Variadic<Index>:$dyn_sizes,
OptionalAttr<AllocatorAttributes>:$allocator_attrs,
OptionalAttr<AllocationAttributes>:$allocation_attrs);
let results = (outs Res<AnyMemRef, "", [MemAlloc<resource>]>:$result);
let builders = [
OpBuilder<[{
OpBuilder &builder, OperationState &result, MemRefType memref_type,
Value op_kernel_ctx,
AllocatorAttributes allocator_attrs = AllocatorAttributes(),
AllocationAttributes allocation_attrs = AllocationAttributes()
}], [{
result.addOperands(op_kernel_ctx);
result.types.push_back(memref_type);
if (allocator_attrs)
result.addAttribute("allocator_attrs", allocator_attrs);
if (allocation_attrs)
result.addAttribute("allocation_attrs", allocation_attrs);
}]>,
OpBuilder<[{
OpBuilder &builder, OperationState &result, MemRefType memref_type,
Value op_kernel_ctx, ValueRange dyn_sizes,
AllocatorAttributes allocator_attrs = AllocatorAttributes(),
AllocationAttributes allocation_attrs = AllocationAttributes()
}], [{
build(builder, result, memref_type, op_kernel_ctx, allocator_attrs,
allocation_attrs);
result.addOperands(dyn_sizes);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];
let assemblyFormat = [{
`(` $op_kernel_ctx (`,` $dyn_sizes^ )? `)` attr-dict `:` type($result)
}];
}
//===----------------------------------------------------------------------===//
// AllocOutputOp
//===----------------------------------------------------------------------===//
def TFFramework_AllocOutputOp
: TFFramework_AllocLikeOp<"alloc_output", DefaultResource> {
let summary = "allocation of tensorsmemory allocation operation";
let description = [{
Allocation of output tensors during kernel execution in the Compute method.
This should be used to allocate any tensor that is going to be used as an
output from the kernel at the end of the current execution.
Defined in third_party/tensorflow/core/framework/op_kernel.cc.
}];
}
//===----------------------------------------------------------------------===//
// AllocTempOp
//===----------------------------------------------------------------------===//
def TFFramework_AllocTempOp
: TFFramework_AllocLikeOp<"alloc_temp", DefaultResource> {
let summary = "memory allocation operation";
let description = [{
Allocation of temp tensors during kernel execution in the Compute method.
This should be used to allocate any scratch storage that is needed while
the kernel is executing, and will not be retained.
Defined in third_party/tensorflow/core/framework/op_kernel.cc.
}];
}
#endif // TF_FRAMEWORK_OPS
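
A minimal C++ sketch of driving the generated builders declared above from a pass; `b`, `loc`, and `ctx_value` are assumed to come from the surrounding rewrite, and -1 is the usual dynamic-dimension sentinel in MemRefType::get:

#include "mlir/IR/Builders.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"

namespace tff = mlir::kernel_gen::tf_framework;

// alloc_output of a static memref<10xi8>: first builder overload, no dynamic
// sizes, default allocator/allocation attributes.
mlir::Value AllocStaticOutput(mlir::OpBuilder &b, mlir::Location loc,
                              mlir::Value ctx_value) {
  auto type = mlir::MemRefType::get({10}, b.getIntegerType(8));
  return b.create<tff::AllocOutputOp>(loc, type, ctx_value);
}

// alloc_temp of a memref<?x10x?xi8>: second builder overload that also takes
// one dynamic size per `?` in the result type.
mlir::Value AllocDynamicTemp(mlir::OpBuilder &b, mlir::Location loc,
                             mlir::Value ctx_value,
                             mlir::ValueRange dyn_sizes) {
  auto type = mlir::MemRefType::get({-1, 10, -1}, b.getIntegerType(8));
  return b.create<tff::AllocTempOp>(loc, type, ctx_value, dyn_sizes);
}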


@@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/tools/kernel_gen:kernel-gen-opt",
"@llvm-project//llvm:FileCheck",
],
)


@@ -0,0 +1,15 @@
// RUN: kernel-gen-opt %s -split-input-file -verify-diagnostics
func @alloc_output(%ctx: !tf_framework.op_kernel_context, %size : index) {
// expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}}
%buf = tf_framework.alloc_output(%ctx, %size) : memref<?x10x?xi8>
return
}
// -----
func @alloc_temp(%ctx: !tf_framework.op_kernel_context, %size : index) {
// expected-error @+1 {{`dyn_sizes` count 1 does not match dynamic dimensions}}
%buf = tf_framework.alloc_temp(%ctx, %size) : memref<10xi8>
return
}


@@ -0,0 +1,21 @@
// RUN: kernel-gen-opt %s | FileCheck %s
// Verify the printed output can be parsed.
// RUN: kernel-gen-opt %s | kernel-gen-opt -allow-unregistered-dialect | FileCheck %s
// Verify the generic form can be parsed.
// RUN: kernel-gen-opt -mlir-print-op-generic %s | kernel-gen-opt -allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: func @alloc_output
func @alloc_output(%ctx: !tf_framework.op_kernel_context,
%size_0 : index , %size_2 : index) {
%buf_0 = tf_framework.alloc_output(%ctx) : memref<10xi8>
%buf_1 = tf_framework.alloc_output(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
return
}
// CHECK-LABEL: func @alloc_temp
func @alloc_temp(%ctx: !tf_framework.op_kernel_context,
%size_0 : index , %size_2 : index) {
%buf_0 = tf_framework.alloc_temp(%ctx) : memref<10xi8>
%buf_1 = tf_framework.alloc_temp(%ctx, %size_0, %size_2) : memref<?x10x?xi8>
return
}