Add initial support for inlining tf operations.
This defines the barebones interfaces for inlining operations in the "tf" dialect and for inlining into executor island operations. Side-effecting operations are not yet modeled, so the inliner is not conservative in the general case.

PiperOrigin-RevId: 274013484
commit 444b2aced6 (parent d8265202b7)
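For illustration, a minimal sketch of the behavior this enables, adapted from the inlining.mlir test added below (the function names here are illustrative): when running tf-opt with -inline, a "tf.StatefulPartitionedCall" to a single-block callee is replaced by the callee's body.

    func @callee() -> tensor<2xi32> {
      %cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32>
      return %cst : tensor<2xi32>
    }

    // Caller before inlining:
    func @caller() -> tensor<2xi32> {
      %result = "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> tensor<2xi32>
      return %result : tensor<2xi32>
    }

    // Caller after tf-opt -inline: the call is replaced by the callee body.
    func @caller() -> tensor<2xi32> {
      %cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32>
      return %cst : tensor<2xi32>
    }
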
tensorflow/compiler/mlir/tensorflow/BUILD

@@ -22,6 +22,7 @@ filegroup(
         "ir/tf_op_base.td",
         "ir/tf_ops.td",
         "@local_config_mlir//:OpBaseTdFiles",
+        "@local_config_mlir//:include/mlir/Analysis/CallInterfaces.td",
     ],
 )
 
@@ -160,6 +161,8 @@ cc_library(
         "ir/tf_types.h",
         "transforms/bridge.h",
         "transforms/passes.h",
+        "@local_config_mlir//:include/mlir/Analysis/CallInterfaces.h",
+        "@local_config_mlir//:include/mlir/Transforms/InliningUtils.h",
     ],
     includes = ["include"],
     deps = [
@@ -175,6 +178,7 @@ cc_library(
         "//tensorflow/core:lib",
         "@llvm//:support",
         "@local_config_mlir//:Analysis",
+        "@local_config_mlir//:CallOpInterfacesIncGen",
         "@local_config_mlir//:Dialect",
         "@local_config_mlir//:IR",
         "@local_config_mlir//:Parser",
@@ -714,6 +718,7 @@ tf_native_cc_binary(
 genrule(
     name = "derived_attr_populator_inc",
     srcs = [
+        "@local_config_mlir//:include/mlir/Analysis/CallInterfaces.td",
         "@local_config_mlir//:include/mlir/IR/OpBase.td",
         "ir/tf_generated_ops.td",
         "ir/tf_op_base.td",

tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc

@@ -41,6 +41,7 @@ limitations under the License.
 #include "mlir/IR/Value.h"  // TF:local_config_mlir
 #include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
 #include "mlir/Transforms/FoldUtils.h"  // TF:local_config_mlir
+#include "mlir/Transforms/InliningUtils.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
 namespace mlir {
@@ -83,6 +84,24 @@ ShapedType DropRefType(ShapedType type) {
 
 namespace {
 
+struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  // Override the inlining hook to determine if 'src' can be inlined into
+  // 'dest'.
+  bool isLegalToInline(Region *dest, Region *src,
+                       BlockAndValueMapping &value_mapping) const final {
+    // Allow inlining into tf.island regions if the incoming region has a single
+    // block.
+    return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
+           std::next(src->begin()) == src->end();
+  }
+};
+
 struct TensorFlowExecutorOpFolderDialectInterface
     : public OpFolderDialectInterface {
   using OpFolderDialectInterface::OpFolderDialectInterface;
@@ -106,7 +125,8 @@ TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
       >();
 
-  addInterfaces<TensorFlowExecutorOpFolderDialectInterface>();
+  addInterfaces<TensorFlowExecutorInlinerInterface,
+                TensorFlowExecutorOpFolderDialectInterface>();
 
   addTypes<ControlType, TokenType>();
 }

tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td

@@ -2903,31 +2903,6 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
   TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
 }
 
-def TF_PartitionedCallOp : TF_Op<"PartitionedCall", [NoSideEffect]> {
-  let summary = [{
-returns `f(inputs)`, where `f`'s body is placed and partitioned.
-  }];
-
-  let description = [{
-  }];
-
-  let arguments = (ins
-    Variadic<TF_Tensor>:$args,
-
-    SymbolRefAttr:$f,
-    StrAttr:$config,
-    StrAttr:$config_proto,
-    StrAttr:$executor_type
-  );
-
-  let results = (outs
-    Variadic<TF_Tensor>:$output
-  );
-
-  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
-  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
-}
-
 def TF_PowOp : TF_Op<"Pow", [Broadcastable, NoSideEffect]>,
                WithBroadcastableBinOpBuilder {
   let summary = "Computes the power of one value to another.";

tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc

@@ -47,6 +47,7 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
 #include "mlir/Support/STLExtras.h"  // TF:local_config_mlir
+#include "mlir/Transforms/InliningUtils.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/tensor_format.h"
@@ -1404,6 +1405,47 @@ void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc"
 
+//===----------------------------------------------------------------------===//
+// TF Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TFInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  // Defines the legality of inlining TF operations.
+  bool isLegalToInline(Operation *, Region *,
+                       BlockAndValueMapping &) const final {
+    // TODO(riverriddle) For now, enable inlining all operations. This isn't
+    // correct in the face of operations that cannot be duplicated, but this
+    // requires more intricate side-effect modeling.
+    return true;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  // Attempts to materialize a conversion for a type mismatch between a call
+  // from this dialect, and a callable region. This method should generate an
+  // operation that takes 'input' as the only operand, and produces a single
+  // result of 'resultType'. If a conversion can not be generated, nullptr
+  // should be returned.
+  Operation *materializeCallConversion(OpBuilder &builder, Value *input,
+                                       Type result_type,
+                                       Location conversion_loc) const final {
+    if (!result_type.isa<TensorType>() || !input->getType().isa<TensorType>())
+      return nullptr;
+    return builder.create<TF::CastOp>(conversion_loc, result_type, input,
+                                      /*truncate=*/builder.getBoolAttr(false));
+  }
+};
+}  // end anonymous namespace
+
 //===----------------------------------------------------------------------===//
 // TF Dialect
 //===----------------------------------------------------------------------===//
@@ -1419,6 +1461,7 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
 #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
       >();
+  addInterfaces<TFInlinerInterface>();
 
   // Support unknown operations because not all TensorFlow operations are
   // registered.

tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h

@@ -19,6 +19,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
 
+#include "mlir/Analysis/CallInterfaces.h"  // TF:local_config_mlir
 #include "mlir/Dialect/Traits.h"  // TF:local_config_mlir
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
 #include "mlir/IR/Builders.h"  // TF:local_config_mlir

tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td

@@ -30,6 +30,11 @@ limitations under the License.
 
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
 
+#ifdef MLIR_CALLINTERFACES
+#else
+include "mlir/Analysis/CallInterfaces.td"
+#endif // MLIR_CALLINTERFACES
+
 class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
   let results = (outs
     TF_VariantTensor:$handle
@@ -196,6 +201,44 @@ retained with length 1.
   TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_PartitionedCallOp : TF_Op<"PartitionedCall",
+                                 [CallOpInterface, NoSideEffect]> {
+  let summary =
+    "returns `f(inputs)`, where `f`'s body is placed and partitioned.";
+
+  let description = [{
+Asynchronously executes a function, potentially across multiple devices but
+within a single process. The kernel places and partitions a given function's
+underlying graph, and executes each of the partitioned subgraphs as a function.
+  }];
+
+  let arguments = (ins
+    Variadic<TF_Tensor>:$args,
+
+    SymbolRefAttr:$f,
+    StrAttr:$config,
+    StrAttr:$config_proto,
+    StrAttr:$executor_type
+  );
+
+  let results = (outs
+    Variadic<TF_Tensor>:$output
+  );
+
+  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
+  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
+
+  let extraClassDeclaration = [{
+    // Gets the argument operands to the called function.
+    operand_range getArgOperands() { return args(); }
+
+    // Returns the callee of this operation.
+    CallInterfaceCallable getCallableForCallee() {
+      return getAttrOfType<SymbolRefAttr>("f");
+    }
+  }];
+}
+
 // In MLIR, the 'tf.Placeholder.input' instruction is used to capture attributes
 // of function arguments.
 // Note: NoSideEffect trait is not added intentionally to preserve the captured
@@ -240,6 +283,44 @@ Inserts a placeholder for a tensor that will be always fed.
   TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
 }
 
+def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
+                                         [CallOpInterface]> {
+  let summary =
+    "returns `f(inputs)`, where `f`'s body is placed and partitioned.";
+
+  let description = [{
+Asynchronously executes a function, potentially across multiple devices but
+within a single process. The kernel places and partitions a given function's
+underlying graph, and executes each of the partitioned subgraphs as a function.
+  }];
+
+  let arguments = (ins
+    Variadic<TF_Tensor>:$args,
+
+    SymbolRefAttr:$f,
+    StrAttr:$config,
+    StrAttr:$config_proto,
+    StrAttr:$executor_type
+  );
+
+  let results = (outs
+    Variadic<TF_Tensor>:$output
+  );
+
+  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
+  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
+
+  let extraClassDeclaration = [{
+    // Gets the argument operands to the called function.
+    operand_range getArgOperands() { return args(); }
+
+    // Returns the callee of this operation.
+    CallInterfaceCallable getCallableForCallee() {
+      return getAttrOfType<SymbolRefAttr>("f");
+    }
+  }];
+}
+
 def TF_WhileOp : TF_Op<"While", []> {
   let summary = [{
 output = input; While (Cond(output)) { output = Body(output) }

tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir (new file, 64 lines)

@@ -0,0 +1,64 @@
+// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail
+
+// Test that simple TF operations can be inlined.
+
+func @inline_simple_callee() -> tensor<2xi32> {
+  %cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32>
+  return %cst : tensor<2xi32>
+}
+
+// CHECK-LABEL: func @inline_simple(
+func @inline_simple() -> tensor<2xi32> {
+  // CHECK-NEXT: %[[CST:.*]] = "tf.Const"
+  // CHECK-NEXT: return %[[CST]]
+  %result = "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @inline_simple_callee} : () -> tensor<2xi32>
+  return %result : tensor<2xi32>
+}
+
+// Check that TF call operations can be inlined, even when the shape of the
+// argument or result is different than the called function.
+
+func @inline_shape_cast_callee(%arg : tensor<*xi32>) -> tensor<*xi32> {
+  return %arg : tensor<*xi32>
+}
+
+// CHECK-LABEL: func @inline_shape_cast(
+// CHECK-SAME: %[[ARG:.*]]: tensor<2xi32>
+func @inline_shape_cast(%arg: tensor<2xi32>) -> tensor<2xi32> {
+  // CHECK-NEXT: %[[ARG_CAST:.*]] = "tf.Cast"(%[[ARG]]) {Truncate = false} : (tensor<2xi32>) -> tensor<*xi32>
+  // CHECK-NEXT: %[[RESULT_CAST:.*]] = "tf.Cast"(%[[ARG_CAST]]) {Truncate = false} : (tensor<*xi32>) -> tensor<2xi32>
+  // CHECK-NEXT: return %[[RESULT_CAST]]
+  %result = "tf.PartitionedCall"(%arg) {config = "", config_proto = "", executor_type = "", f = @inline_shape_cast_callee} : (tensor<2xi32>) -> tensor<2xi32>
+  return %result : tensor<2xi32>
+}
+
+// Check that functions can be inlined into islands.
+
+func @inline_into_island_multi_block_callee() -> tensor<2xi32> {
+  br ^bb1
+
+^bb1:
+  %cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32>
+  return %cst : tensor<2xi32>
+}
+
+// CHECK-LABEL: func @inline_into_island(
+func @inline_into_island() -> (tensor<2xi32>, tensor<2xi32>) {
+  %0:2 = tf_executor.graph {
+    %1:3 = tf_executor.island {
+      // Single block regions may be inlined.
+      // CHECK: %[[CST:.*]] = "tf.Const"
+      %result = "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @inline_simple_callee} : () -> tensor<2xi32>
+
+      // Multi block regions may not.
+      // CHECK-NEXT: %[[CALL:.*]] = "tf.StatefulPartitionedCall"
+      %result_2 = "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @inline_into_island_multi_block_callee} : () -> tensor<2xi32>
+
+      // CHECK-NEXT: tf_executor.yield %[[CST]], %[[CALL]]
+      tf_executor.yield %result, %result_2 : tensor<2xi32>, tensor<2xi32>
+    }
+    tf_executor.fetch %1#1, %1#1 : tensor<2xi32>, tensor<2xi32>
+  }
+  return %0#1, %0#1 : tensor<2xi32>, tensor<2xi32>
+}
+

@@ -2,7 +2,7 @@
 
 func @main() {
   %0:2 = "_tf.VarHandleOp"() {dtype = "tfdtype$DT_FLOAT", shape = "tfshape$"} : () -> (tensor<!tf.resource>, !_tf.control)
-  %1:2 = "_tf.StatefulPartitionedCall"(%0#0) {Tin = ["tfdtype$DT_RESOURCE"], Tout = ["tfdtype$DT_RESOURCE"], f = @foo} : (tensor<!tf.resource>) -> (tensor<!tf.resource>, !_tf.control) loc("call_foo")
+  %1:2 = "_tf.StatefulPartitionedCall"(%0#0) {Tin = ["tfdtype$DT_RESOURCE"], Tout = ["tfdtype$DT_RESOURCE"], config = "", config_proto = "", executor_type = "", f = @foo} : (tensor<!tf.resource>) -> (tensor<!tf.resource>, !_tf.control) loc("call_foo")
   return
 }
 