Add initial support for inlining tf operations.

This defines the barebones interfaces for inlining operations in the "tf" dialect, and inlining into executor island operations. There is currently no modeling for side effecting operations, meaning that it is not conservative in the general case.

PiperOrigin-RevId: 274013484
This commit is contained in:
River Riddle 2019-10-10 12:24:16 -07:00 committed by TensorFlower Gardener
parent d8265202b7
commit 444b2aced6
8 changed files with 216 additions and 27 deletions

View File

@ -22,6 +22,7 @@ filegroup(
"ir/tf_op_base.td",
"ir/tf_ops.td",
"@local_config_mlir//:OpBaseTdFiles",
"@local_config_mlir//:include/mlir/Analysis/CallInterfaces.td",
],
)
@ -160,6 +161,8 @@ cc_library(
"ir/tf_types.h",
"transforms/bridge.h",
"transforms/passes.h",
"@local_config_mlir//:include/mlir/Analysis/CallInterfaces.h",
"@local_config_mlir//:include/mlir/Transforms/InliningUtils.h",
],
includes = ["include"],
deps = [
@ -175,6 +178,7 @@ cc_library(
"//tensorflow/core:lib",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:CallOpInterfacesIncGen",
"@local_config_mlir//:Dialect",
"@local_config_mlir//:IR",
"@local_config_mlir//:Parser",
@ -714,6 +718,7 @@ tf_native_cc_binary(
genrule(
name = "derived_attr_populator_inc",
srcs = [
"@local_config_mlir//:include/mlir/Analysis/CallInterfaces.td",
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"ir/tf_generated_ops.td",
"ir/tf_op_base.td",

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Transforms/FoldUtils.h" // TF:local_config_mlir
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
@ -83,6 +84,24 @@ ShapedType DropRefType(ShapedType type) {
namespace {
struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
// Override the inlining hook to determine if 'src' can be inlined into
// 'dest'.
bool isLegalToInline(Region *dest, Region *src,
BlockAndValueMapping &value_mapping) const final {
// Allow inlining into tf.island regions if the incoming region has a single
// block.
return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
std::next(src->begin()) == src->end();
}
};
struct TensorFlowExecutorOpFolderDialectInterface
: public OpFolderDialectInterface {
using OpFolderDialectInterface::OpFolderDialectInterface;
@ -106,7 +125,8 @@ TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context)
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc"
>();
addInterfaces<TensorFlowExecutorOpFolderDialectInterface>();
addInterfaces<TensorFlowExecutorInlinerInterface,
TensorFlowExecutorOpFolderDialectInterface>();
addTypes<ControlType, TokenType>();
}

View File

@ -2903,31 +2903,6 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}
def TF_PartitionedCallOp : TF_Op<"PartitionedCall", [NoSideEffect]> {
let summary = [{
returns `f(inputs)`, where `f`'s body is placed and partitioned.
}];
let description = [{
}];
let arguments = (ins
Variadic<TF_Tensor>:$args,
SymbolRefAttr:$f,
StrAttr:$config,
StrAttr:$config_proto,
StrAttr:$executor_type
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
}
def TF_PowOp : TF_Op<"Pow", [Broadcastable, NoSideEffect]>,
WithBroadcastableBinOpBuilder {
let summary = "Computes the power of one value to another.";

View File

@ -47,6 +47,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Support/STLExtras.h" // TF:local_config_mlir
#include "mlir/Transforms/InliningUtils.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
@ -1404,6 +1405,47 @@ void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc"
//===----------------------------------------------------------------------===//
// TF Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct TFInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
// Defines the legality of inlining TF operations.
bool isLegalToInline(Operation *, Region *,
BlockAndValueMapping &) const final {
// TODO(riverriddle) For now, enable inlining all operations. This isn't
// correct in the face of operations that cannot be duplicated, but this
// requires more intricate side-effect modeling.
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
// Attempts to materialize a conversion for a type mismatch between a call
// from this dialect, and a callable region. This method should generate an
// operation that takes 'input' as the only operand, and produces a single
// result of 'resultType'. If a conversion can not be generated, nullptr
// should be returned.
Operation *materializeCallConversion(OpBuilder &builder, Value *input,
Type result_type,
Location conversion_loc) const final {
if (!result_type.isa<TensorType>() || !input->getType().isa<TensorType>())
return nullptr;
return builder.create<TF::CastOp>(conversion_loc, result_type, input,
/*truncate=*/builder.getBoolAttr(false));
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// TF Dialect
//===----------------------------------------------------------------------===//
@ -1419,6 +1461,7 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
>();
addInterfaces<TFInlinerInterface>();
// Support unknown operations because not all TensorFlow operations are
// registered.

View File

@ -19,6 +19,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_
#include "mlir/Analysis/CallInterfaces.h" // TF:local_config_mlir
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir

View File

@ -30,6 +30,11 @@ limitations under the License.
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
#ifdef MLIR_CALLINTERFACES
#else
include "mlir/Analysis/CallInterfaces.td"
#endif // MLIR_CALLINTERFACES
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
let results = (outs
TF_VariantTensor:$handle
@ -196,6 +201,44 @@ retained with length 1.
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}
def TF_PartitionedCallOp : TF_Op<"PartitionedCall",
[CallOpInterface, NoSideEffect]> {
let summary =
"returns `f(inputs)`, where `f`'s body is placed and partitioned.";
let description = [{
Asynchronously executes a function, potentially across multiple devices but
within a single process. The kernel places and partitions a given function's
underlying graph, and executes each of the partitioned subgraphs as a function.
}];
let arguments = (ins
Variadic<TF_Tensor>:$args,
SymbolRefAttr:$f,
StrAttr:$config,
StrAttr:$config_proto,
StrAttr:$executor_type
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
let extraClassDeclaration = [{
// Gets the argument operands to the called function.
operand_range getArgOperands() { return args(); }
// Returns the callee of this operation.
CallInterfaceCallable getCallableForCallee() {
return getAttrOfType<SymbolRefAttr>("f");
}
}];
}
// In MLIR, the 'tf.Placeholder.input' instruction is used to capture attributes
// of function arguments.
// Note: NoSideEffect trait is not added intentionally to preserve the captured
@ -240,6 +283,44 @@ Inserts a placeholder for a tensor that will be always fed.
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
[CallOpInterface]> {
let summary =
"returns `f(inputs)`, where `f`'s body is placed and partitioned.";
let description = [{
Asynchronously executes a function, potentially across multiple devices but
within a single process. The kernel places and partitions a given function's
underlying graph, and executes each of the partitioned subgraphs as a function.
}];
let arguments = (ins
Variadic<TF_Tensor>:$args,
SymbolRefAttr:$f,
StrAttr:$config,
StrAttr:$config_proto,
StrAttr:$executor_type
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
let extraClassDeclaration = [{
// Gets the argument operands to the called function.
operand_range getArgOperands() { return args(); }
// Returns the callee of this operation.
CallInterfaceCallable getCallableForCallee() {
return getAttrOfType<SymbolRefAttr>("f");
}
}];
}
def TF_WhileOp : TF_Op<"While", []> {
let summary = [{
output = input; While (Cond(output)) { output = Body(output) }

View File

@ -0,0 +1,64 @@
// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail
// Test that simple TF operations can be inlined.
func @inline_simple_callee() -> tensor<2xi32> {
%cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32>
return %cst : tensor<2xi32>
}
// CHECK-LABEL: func @inline_simple(
func @inline_simple() -> tensor<2xi32> {
// CHECK-NEXT: %[[CST:.*]] = "tf.Const"
// CHECK-NEXT: return %[[CST]]
%result = "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @inline_simple_callee} : () -> tensor<2xi32>
return %result : tensor<2xi32>
}
// Check that TF call operations can be inlined, even when the shape of the
// argument or result is different than the called function.
func @inline_shape_cast_callee(%arg : tensor<*xi32>) -> tensor<*xi32> {
return %arg : tensor<*xi32>
}
// CHECK-LABEL: func @inline_shape_cast(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi32>
func @inline_shape_cast(%arg: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %[[ARG_CAST:.*]] = "tf.Cast"(%[[ARG]]) {Truncate = false} : (tensor<2xi32>) -> tensor<*xi32>
// CHECK-NEXT: %[[RESULT_CAST:.*]] = "tf.Cast"(%[[ARG_CAST]]) {Truncate = false} : (tensor<*xi32>) -> tensor<2xi32>
// CHECK-NEXT: return %[[RESULT_CAST]]
%result = "tf.PartitionedCall"(%arg) {config = "", config_proto = "", executor_type = "", f = @inline_shape_cast_callee} : (tensor<2xi32>) -> tensor<2xi32>
return %result : tensor<2xi32>
}
// Check that functions can be inlined into islands.
func @inline_into_island_multi_block_callee() -> tensor<2xi32> {
br ^bb1
^bb1:
%cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32>
return %cst : tensor<2xi32>
}
// CHECK-LABEL: func @inline_into_island(
func @inline_into_island() -> (tensor<2xi32>, tensor<2xi32>) {
%0:2 = tf_executor.graph {
%1:3 = tf_executor.island {
// Single block regions may be inlined.
// CHECK: %[[CST:.*]] = "tf.Const"
%result = "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @inline_simple_callee} : () -> tensor<2xi32>
// Multi block regions may not.
// CHECK-NEXT: %[[CALL:.*]] = "tf.StatefulPartitionedCall"
%result_2 = "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @inline_into_island_multi_block_callee} : () -> tensor<2xi32>
// CHECK-NEXT: tf_executor.yield %[[CST]], %[[CALL]]
tf_executor.yield %result, %result_2 : tensor<2xi32>, tensor<2xi32>
}
tf_executor.fetch %1#1, %1#1 : tensor<2xi32>, tensor<2xi32>
}
return %0#1, %0#1 : tensor<2xi32>, tensor<2xi32>
}

View File

@ -2,7 +2,7 @@
func @main() {
%0:2 = "_tf.VarHandleOp"() {dtype = "tfdtype$DT_FLOAT", shape = "tfshape$"} : () -> (tensor<!tf.resource>, !_tf.control)
%1:2 = "_tf.StatefulPartitionedCall"(%0#0) {Tin = ["tfdtype$DT_RESOURCE"], Tout = ["tfdtype$DT_RESOURCE"], f = @foo} : (tensor<!tf.resource>) -> (tensor<!tf.resource>, !_tf.control) loc("call_foo")
%1:2 = "_tf.StatefulPartitionedCall"(%0#0) {Tin = ["tfdtype$DT_RESOURCE"], Tout = ["tfdtype$DT_RESOURCE"], config = "", config_proto = "", executor_type = "", f = @foo} : (tensor<!tf.resource>) -> (tensor<!tf.resource>, !_tf.control) loc("call_foo")
return
}