Splitting up tf_ops

Given the switch intended for SOT, just split these alphabetically.

PiperOrigin-RevId: 321465891
Change-Id: Ie261af7b2bd4a5a2dd80fc3865ed3159f0be155d
Jacques Pienaar 2020-07-15 16:55:57 -07:00 committed by TensorFlower Gardener
parent 093491a562
commit 0d38f3e81c
12 changed files with 5077 additions and 4503 deletions


@@ -88,6 +88,7 @@ gentbl(
cc_library(
name = "tensorflow_op_interfaces",
srcs = [
"ir/tf_op_interfaces.cc",
"ir/tf_op_interfaces.cc.inc",
"ir/tf_op_interfaces.h.inc",
"ir/tf_verifiers.cc",
@@ -105,15 +106,67 @@ cc_library(
)
gentbl(
name = "tensorflow_ops_inc_gen",
name = "tensorflow_all_ops_inc_gen",
tbl_outs = [
(
"-gen-op-decls",
"ir/tf_ops.h.inc",
"ir/tf_all_ops.h.inc",
),
(
"-gen-op-defs",
"ir/tf_ops.cc.inc",
"ir/tf_all_ops.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tf_ops.td",
td_srcs = [
":tensorflow_ops_td_files",
],
)
# We only shard tf_ops on name for build performance reasons.
tf_ops_category_list = [
{
"name": "ops_a_m",
"include": "tf.[A-M].*$$",
},
{
"name": "ops_n_z",
"include": "tf.[N-Z].*$$",
},
]
[[
gentbl(
name = "tensorflow_" + target["name"] + "_inc_gen",
tbl_outs = [
(
"-gen-op-decls -op-include-regex='" + target["include"] + "'",
"ir/tf_" + target["name"] + ".h.inc",
),
(
"-gen-op-defs -op-include-regex='" + target["include"] + "'",
"ir/tf_" + target["name"] + ".cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tf_ops.td",
td_srcs = [
":tensorflow_ops_td_files",
],
),
] for target in tf_ops_category_list]
gentbl(
name = "tensorflow_remaining_ops_inc_gen",
tbl_outs = [
(
"-gen-op-decls -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ",
"ir/tf_remaining_ops.h.inc",
),
(
"-gen-op-defs -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ",
"ir/tf_remaining_ops.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
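The shard filters above rely on mlir-tblgen's -op-include-regex / -op-exclude-regex flags; the doubled "$$" is Bazel's escape for a literal "$". A minimal standalone sketch of how the two regexes partition op names (using std::regex for illustration only; mlir-tblgen matches with LLVM's own regex support):

#include <cassert>
#include <regex>

int main() {
  std::regex a_m("tf.[A-M].*$");  // the ops_a_m include filter
  std::regex n_z("tf.[N-Z].*$");  // the ops_n_z include filter
  assert(std::regex_search("tf.AddV2", a_m));   // lands in the ops_a_m shard
  assert(!std::regex_search("tf.AddV2", n_z));
  assert(std::regex_search("tf.Sub", n_z));     // lands in the ops_n_z shard
  return 0;
}

Ops whose names match neither shard regex fall through to tensorflow_remaining_ops via the joined -op-exclude-regex in the rule above.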
@@ -179,7 +232,7 @@ gentbl(
name = "tensorflow_device_ops_inc_gen",
tbl_outs = [
(
"-gen-op-decls",
"-gen-op-decls ",
"ir/tf_device.h.inc",
),
(
@@ -284,24 +337,67 @@ cc_library(
],
)
[[
cc_library(
name = "tensorflow_" + target["name"],
srcs = [
"ir/tf_ops.h",
"ir/tf_remaining_ops.h",
"ir/tf_" + target["name"] + ".cc",
"ir/tf_" + target["name"] + ".cc.inc",
] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list],
hdrs = [
],
textual_hdrs = [
"ir/tf_all_ops.h.inc",
"ir/tf_ops_helpers.inc",
"ir/tf_remaining_ops.h.inc",
] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list],
deps = [
":tensorflow_attributes",
":tensorflow_canonicalize_inc_gen",
":tensorflow_op_interfaces",
":tensorflow_op_interfaces_inc_gen",
":tensorflow_side_effects",
":tensorflow_structs",
":tensorflow_traits",
":tensorflow_types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
] + [":tensorflow_" + target["name"] + "_inc_gen"],
),
] for target in tf_ops_category_list]
cc_library(
name = "tensorflow_ops",
name = "tensorflow_remaining_ops",
srcs = [
"ir/tf_ops.cc",
"ir/tf_ops.cc.inc",
"ir/tf_ops.h",
],
"ir/tf_remaining_ops.h",
"ir/tf_remaining_ops.cc",
] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list],
hdrs = [
],
textual_hdrs = [
"ir/tf_ops.h.inc",
],
"ir/tf_all_ops.h.inc",
"ir/tf_ops_helpers.inc",
"ir/tf_remaining_ops.h.inc",
] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list],
deps = [
":tensorflow_attributes",
":tensorflow_canonicalize_inc_gen",
":tensorflow_op_interfaces",
":tensorflow_op_interfaces_inc_gen",
":tensorflow_ops_inc_gen",
":tensorflow_remaining_ops_inc_gen",
":tensorflow_side_effects",
":tensorflow_structs",
":tensorflow_traits",
@@ -321,6 +417,43 @@ cc_library(
],
)
cc_library(
name = "tensorflow_ops",
srcs = [
"ir/tf_ops.cc",
"ir/tf_ops.h",
],
textual_hdrs = [
"ir/tf_all_ops.h.inc",
"ir/tf_remaining_ops.h",
] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list],
deps = [
":tensorflow_all_ops_inc_gen",
":tensorflow_remaining_ops_inc_gen",
":tensorflow_attributes",
":tensorflow_canonicalize_inc_gen",
":tensorflow_op_interfaces",
":tensorflow_op_interfaces_inc_gen",
":tensorflow_side_effects",
":tensorflow_structs",
":tensorflow_traits",
":tensorflow_types",
":tensorflow_remaining_ops",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
] + [":tensorflow_" + target["name"] for target in tf_ops_category_list],
)
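Note that the aggregate :tensorflow_ops target keeps its original name, so existing dependents should pick up the shard libraries (:tensorflow_ops_a_m, :tensorflow_ops_n_z) and :tensorflow_remaining_ops through its deps without BUILD changes of their own.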
cc_library(
name = "tensorflow_structs",
srcs = [
@@ -393,12 +526,14 @@ cc_library(
includes = ["include"],
deps = [
":error_util",
":tensorflow_all_ops_inc_gen",
":tensorflow_attributes",
":tensorflow_canonicalize_inc_gen",
":tensorflow_device_ops_inc_gen",
":tensorflow_executor_inc_gen",
":tensorflow_op_interfaces",
":tensorflow_ops",
":tensorflow_side_effects",
":tensorflow_structs",
":tensorflow_traits",
":tensorflow_types",


@@ -332,7 +332,7 @@ class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
// This returns a list of shapes so it is used for variadic operands that
// can have different shapes.
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
"mlir::TF::OperandShapeRange",
"::mlir::TF::OperandShapeRange",
"auto values = getODSOperands(" # idx # ");\n"
"return {mlir::TF::OperandShapeIterator(values.begin()), "
"mlir::TF::OperandShapeIterator(values.end())};",


@@ -0,0 +1,20 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
namespace mlir::TF {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc"
} // namespace mlir::TF

File diff suppressed because it is too large.


@@ -35,6 +35,9 @@ limitations under the License.
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@@ -112,17 +115,6 @@ class TensorFlowDialect : public Dialect {
static std::vector<AdditionalOpFunction> *additional_operation_hooks_;
};
// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose
// purpose is to catch bugs in uses of `tensorflow::mutex_lock`. We don't use
// `tensorflow::mutex_lock` here, but we have ops (`tf.MutexLock` and
// `tf.ConsumeMutexLock`) with getter methods named `mutex_lock()`. We need to
// undefine the macro here to avoid expanding the getter symbol as a macro when
// including both mutex.h and this header file.
#undef mutex_lock
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"
} // namespace TF
} // namespace mlir

File diff suppressed because it is too large.


@@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
namespace mlir {
namespace TF {
class YieldOp;
// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose
// purpose is to catch bugs in uses of `tensorflow::mutex_lock`. We don't use
// `tensorflow::mutex_lock` here, but we have ops (`tf.MutexLock` and
// `tf.ConsumeMutexLock`) with getter methods named `mutex_lock()`. We need to
// undefine the macro here to avoid expanding the getter symbol as a macro when
// including both mutex.h and this header file.
#undef mutex_lock
#define GET_OP_FWD_DEFINES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h.inc"
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_
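The GET_OP_FWD_DEFINES / GET_OP_CLASSES pairing above is what lets each shard header stand alone: forward declarations for every op in the dialect come from the monolithic tf_all_ops.h.inc, while full class definitions come only from this shard's generated header. Roughly, the generated code looks like this (a sketch, not the actual tblgen output):

// tf_all_ops.h.inc under GET_OP_FWD_DEFINES: one forward declaration per op.
class AddV2Op;
class SubOp;

// tf_ops_a_m.h.inc under GET_OP_CLASSES: full definitions for the A-M ops.
class AddV2Op : public ::mlir::Op<AddV2Op /*, traits... */> {
  // builders, accessors, verifier and folder hooks...
};

This way, code in one shard can refer to ops defined in another shard through the forward declarations alone.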


@@ -0,0 +1,580 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is a simple include file used to simplify the splitting of the
// tf_ops.cc file. The helpers in here should be refactored and moved to
// tf_verifiers or tf_ops.
// TODO(jpienaar): Remove this file post refactoring.
// Propagates underscore and device attributes from src to dst.
// TODO(b/158769932): This should be a general feature instead post some policy
// discussion.
static void PropagateAttributes(Operation *src, Operation *dst) {
auto device = mlir::Identifier::get("device", src->getContext());
for (auto named_attr : src->getAttrs()) {
if (*named_attr.first.begin() == '_' || named_attr.first == device)
dst->setAttr(named_attr.first, named_attr.second);
}
}
//===----------------------------------------------------------------------===//
// TF op helper functions
//===----------------------------------------------------------------------===//
// Returns the RankedTensorType for the given operand. TensorFlow constant ops
// may have non-static shape because the shape is not propagated during constant
// folding. If the defining op for the given operand is a constant op, this
// routine uses the constant op's attribute to get the actual shape.
static RankedTensorType GetRankedTensorTypeForOperand(Value operand) {
DenseElementsAttr attr;
if (matchPattern(operand, m_Constant(&attr))) {
return attr.getType().dyn_cast<RankedTensorType>();
}
return operand.getType().dyn_cast<RankedTensorType>();
}
// Returns true if the given `type` is a ranked float tensor type with the
// given `rank`.
static inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) {
return type && type.getRank() == rank &&
type.getElementType().isa<FloatType>();
}
// Returns true if the given `value` has the specified rank or has unranked
// type.
static inline bool IsOfRankOrUnranked(Value value, int64_t rank) {
RankedTensorType type = GetRankedTensorTypeForOperand(value);
return !type || type.getRank() == rank;
}
// Returns true if the given `value` has at least the specified rank or has
// unranked type.
static inline bool HasRankAtLeast(Value value, int64_t rank) {
RankedTensorType type = GetRankedTensorTypeForOperand(value);
return !type || type.getRank() >= rank;
}
// Returns true if the given `value` has at most the specified rank or has
// unranked type.
static inline bool HasRankAtMost(Value value, int64_t rank) {
RankedTensorType type = GetRankedTensorTypeForOperand(value);
return !type || type.getRank() <= rank;
}
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
return dim_or_rank == -1;
}
// Returns the tf.Equal/tf.NotEqual result type given the `x` and `y` inputs.
// If `incompatible_shape_error` is true, reports an error if `x` and `y` have
// incompatible shapes. Otherwise, returns a tensor type with unknown rank.
static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x,
Value y, BoolAttr incompatible_shape_error) {
auto result_type =
OpTrait::util::getBroadcastedType(x.getType(), y.getType());
if (!result_type) {
if (incompatible_shape_error.getValue()) {
mlir::emitError(loc, "non-broadcastable operands");
} else {
return UnrankedTensorType::get(builder->getI1Type());
}
}
auto ranked_type = result_type.dyn_cast<RankedTensorType>();
if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type());
return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type());
}
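// Worked example: x of type tensor<?x4xf32> and y of type tensor<4xf32>
// broadcast to tensor<?x4xf32>, so the deduced result type is tensor<?x4xi1>.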
// Returns dimension index for the given TensorFlow axis that supports negative
// indexing.
static int64_t GetDimForAxis(int64_t axis, int64_t rank) {
return axis >= 0 ? axis : axis + rank;
}
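// For example, with rank 4, axis 1 maps to dimension 1 while axis -1 maps to
// dimension 3 (-1 + 4).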
// Infers output type for reduction ops such as SumOp, MaxOp etc.
// TODO(b/e667204a): Move this logic to shape inference once it supports custom
// inference functions.
static Type InferReductionOpType(Value input, Value reduction_indices,
BoolAttr keep_dims, Builder *builder) {
Type input_ty = input.getType();
Type element_ty = getElementTypeOrSelf(input_ty);
// Output type is unranked if input type is not ranked.
auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) return UnrankedTensorType::get(element_ty);
int64_t rank = ranked_ty.getRank();
DenseIntElementsAttr indices;
if (!matchPattern(reduction_indices, m_Constant(&indices))) {
// Output type is unranked if reduction indices are not constant and reduced
// dimensions are not kept.
if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty);
// Otherwise, output type has same rank as the input.
return RankedTensorType::get(SmallVector<int64_t, 4>(rank, -1), element_ty);
}
int64_t num_reduce_dim = 0;
llvm::SmallVector<bool, 4> is_reduce_dim(rank, false);
for (const APInt &index : indices.getValues<APInt>()) {
int64_t dim = GetDimForAxis(index.getSExtValue(), rank);
// Invalid input.
if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty);
if (!is_reduce_dim[dim]) {
is_reduce_dim[dim] = true;
num_reduce_dim++;
}
}
ArrayRef<int64_t> shape = ranked_ty.getShape();
SmallVector<int64_t, 4> out_shape;
out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim));
for (int64_t i = 0; i < rank; ++i) {
if (!is_reduce_dim[i])
out_shape.push_back(shape[i]);
else if (keep_dims.getValue())
out_shape.push_back(1);
}
return RankedTensorType::get(out_shape, element_ty);
}
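// Worked example: for an input of type tensor<2x3x4xf32> with constant
// reduction_indices [0, 2], the inferred type is tensor<3xf32> when keep_dims
// is false and tensor<1x3x1xf32> when keep_dims is true.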
// Verifies that the given types are cast compatible. If not, emits an
// appropriate error for the given op. If mask_one_dim is set to true, then the
// types are allowed to have one mismatching dimension. Masking one of the
// dimensions is useful for ops like Concat that require all ranked inputs to
// have the same rank and matching dimension sizes for all but one of the
// dimensions.
static LogicalResult VerifyTypesCompatibility(
Operation::operand_type_range types, bool mask_one_dim, Operation *op) {
constexpr int64_t kUninitialized = -1;
int64_t common_rank = kUninitialized;
llvm::SmallVector<int64_t, 4> common_dims;
int64_t dim_to_mask = kUninitialized;
// Initialize common_rank with rank of the first ranked type and verify that
// following ranked types have the same rank.
// Similarly, initialize each of the dimensions with the first type that has
// the dimension size available and verify that all following types have the
// same size for the dimension. However, if mask_one_dim is true, note down
// the dimension index on the first mismatch and ignore dimension at that
// index in following types.
for (Type ty : types) {
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) continue;
int64_t rank = ranked_ty.getRank();
if (common_rank == kUninitialized) {
common_rank = rank;
common_dims.resize(common_rank, kUninitialized);
} else if (common_rank != rank) {
return op->emitError()
<< "operand type " << ranked_ty
<< " is not compatible with preceding operands; expected rank: "
<< common_rank;
}
for (int64_t i = 0, e = common_rank; i != e; i++) {
if (i == dim_to_mask) continue;
int64_t dim = ranked_ty.getDimSize(i);
if (dim == kUninitialized) continue;
int64_t &common_dim = common_dims[i];
if (common_dim == kUninitialized) {
common_dim = dim;
} else if (common_dim != dim) {
// If mask_one_dim is true, do not emit an error if this is the only
// dimension with mismatches. Note down the dimension to mask it from
// the following types.
if (mask_one_dim && dim_to_mask == kUninitialized) {
dim_to_mask = i;
continue;
}
return op->emitError() << "operand type " << ranked_ty
<< " is not compatible with preceding operands; "
"expected dimension at index "
<< i << ": " << common_dim;
}
}
}
return success();
}
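// Worked example: with mask_one_dim set (as for Concat), operand types
// tensor<2x3xf32> and tensor<2x5xf32> are compatible (dimension 1 becomes the
// masked dimension), but a further tensor<4x5xf32> operand then fails on
// dimension 0.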
// This is a helper for the Select to SelectV2 canonicalization. The `data` rank
// refers to the rank of `t`/`e` (these two inputs have equal rank; this is
// checked in the verifier).
//
// In most cases, the predicate for Select can be used directly as the predicate
// for SelectV2. However, there is one case that varies, which is when the
// predicate is a tensor and the data is multidimensional. In this case, Select
// op semantics dictate that the predicate tensor length must match the size of
// the first data dimension. This varies from normal broadcasting semantics
// (which are used in SelectV2), so we must reshape the tensor in this case to
// be compatible.
static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc,
Value cond, int data_rank) {
auto cond_tensor = cond.getType().cast<RankedTensorType>();
// Reshape is only needed in the case that the cond rank is 1 (i.e. it is
// a vector) AND t/e rank is > 1.
if (cond_tensor.getRank() != 1 || data_rank <= 1) {
// No reshape necessary. Leave cond as it is.
return cond;
}
// This is the case where a reshape is needed. We want to construct the
// shape [x,1,...,1], where x is the length of the pred vector and the
// length of the shape is equal to data_rank.
SmallVector<int64_t, 8> shape(data_rank, 1);
shape[0] = cond_tensor.getShape().front();
auto new_shape_type =
RankedTensorType::get({data_rank}, builder->getIntegerType(64));
auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape);
auto new_shape = builder->create<ConstOp>(loc, shape_attr);
return builder->create<ReshapeOp>(loc, cond, new_shape);
}
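// Worked example: a pred of type tensor<8xi1> selecting between rank-3
// operands is reshaped to tensor<8x1x1xi1>, so that SelectV2's broadcasting
// reproduces Select's first-dimension semantics.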
//===----------------------------------------------------------------------===//
// Helper functions detect device capabilities from RuntimeDevices.
//===----------------------------------------------------------------------===//
namespace {
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;
bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) {
return device.type == ::tensorflow::DEVICE_GPU;
}
} // namespace
// Returns true if at least one GPU device is available at runtime.
bool CanUseGpuDevice(const RuntimeDevices &devices) {
return llvm::any_of(devices.device_names(), IsGpuDevice);
}
// Returns true if all of the GPUs available at runtime support TensorCores
// (NVIDIA compute capability >= 7.0).
bool CanUseTensorCores(const RuntimeDevices &devices) {
auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) {
auto md = devices.GetGpuDeviceMetadata(device);
return md ? md->cc_major().getInt() >= 7 : false;
};
return llvm::all_of(
llvm::make_filter_range(devices.device_names(), IsGpuDevice),
has_tensor_cores);
}
// Returns true if the operation does not have an explicit device placement
// that would prevent it from running on a GPU device.
bool CanUseGpuDevice(Operation *op) {
auto device_attr = op->getAttrOfType<StringAttr>("device");
if (!device_attr || device_attr.getValue().empty()) return true;
DeviceNameUtils::ParsedName device;
if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device))
return false;
// We can't use the GPU if the operation is explicitly placed on a non-GPU
// device.
return !device.has_type || device.type == ::tensorflow::DEVICE_GPU;
}
//===----------------------------------------------------------------------===//
// TF op helper functions to work with layout transformation.
//===----------------------------------------------------------------------===//
SmallVector<int64_t, 4> ReversePermutation(ArrayRef<int64_t> permutation) {
SmallVector<int64_t, 4> reverse(permutation.size());
for (size_t i = 0; i < permutation.size(); ++i) {
reverse[permutation[i]] = i;
}
return reverse;
}
SmallVector<int64_t, 4> GetDataFormatPermutation(StringRef from, StringRef to) {
if (from == "NHWC" && to == "NCHW") {
return {0, 3, 1, 2};
} else if (from == "NCHW" && to == "NHWC") {
return {0, 2, 3, 1};
} else {
return {};
}
}
// Shuffle elements in the `attr` according to the permutation. The optional
// `inner_size` allows shuffling array attributes created from rank 2 tensors
// on the outer dimension only.
ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation,
int inner_size = 1) {
if (attr.size() == 0) return attr;
assert(attr.size() % inner_size == 0);
assert(attr.size() / inner_size == permutation.size());
SmallVector<Attribute, 8> values{attr.begin(), attr.end()};
SmallVector<Attribute, 8> shuffled(values.size());
for (size_t i = 0; i < permutation.size(); ++i) {
for (size_t j = 0; j < inner_size; ++j) {
shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j];
}
}
return ArrayAttr::get(shuffled, attr.getContext());
}
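// Worked example: with permutation {0, 3, 1, 2} and inner_size = 2, the
// attribute [a0, a1, b0, b1, c0, c1, d0, d1] becomes
// [a0, a1, d0, d1, b0, b1, c0, c1]; pairs move together, their contents stay
// intact.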
// Shuffle ranked tensor dimensions according to the permutation.
Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation) {
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
ArrayRef<int64_t> shape = ranked_type.getShape();
assert(permutation.size() == shape.size());
SmallVector<int64_t, 4> new_shape(permutation.size());
for (size_t i = 0; i < permutation.size(); ++i)
new_shape[i] = shape[permutation[i]];
return RankedTensorType::get(new_shape, ranked_type.getElementType());
}
return type;
}
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
DenseIntElementsAttr perm1) {
if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
if (perm0.getNumElements() != perm1.getNumElements()) return false;
SmallVector<int64_t, 8> perm0_values;
for (const auto &value : perm0.getIntValues())
perm0_values.push_back(value.getSExtValue());
SmallVector<int64_t, 8> perm1_values;
for (const auto &value : perm1.getIntValues())
perm1_values.push_back(value.getSExtValue());
for (int i = 0; i < perm0_values.size(); ++i) {
if (perm0_values[perm1_values[i]] != i) return false;
}
return true;
}
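// Worked example: perm0 = {0, 2, 3, 1} (NCHW->NHWC) and perm1 = {0, 3, 1, 2}
// (NHWC->NCHW) are cancellable: perm0[perm1[i]] == i for every i.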
// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for
// layout sensitive operations that do not have any additional layout dependent
// attributes besides `data_format` string.
template <typename Op>
LogicalResult UpdateDataFormat(StringRef data_format, Op *op) {
auto perm = GetDataFormatPermutation(op->data_format(), data_format);
if (perm.empty()) return failure();
// Update data format attribute.
op->setAttr("data_format", StringAttr::get(data_format, op->getContext()));
// Update types for all layout sensitive results.
auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) {
OpResult result = op->getOperation()->getResult(idx);
result.setType(ShuffleRankedTensorType(result.getType(), perm));
}
return success();
}
// Default implementation for folding operand transpose into the operation.
// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`.
template <typename Op>
LogicalResult FoldOperandsPermutation(
ArrayRef<int64_t> permutation, Op *op,
ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
MLIRContext *context = op->template getParentOfType<ModuleOp>().getContext();
// We only support NHWC <-> NCHW permutations.
static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2};
// Operation data format after folding `permutation`.
StringRef target_data_format = [&]() -> StringRef {
if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) {
return "NCHW"; // cancel NCHW->NHWC operand permutation
} else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) {
return "NHWC"; // cancel NHWC->NCHW operand permutation
} else {
return "";
}
}();
if (target_data_format.empty()) return failure();
// To fold the operand `permutation` into the `op` we need to shuffle all
// layout dependent attributes and types with a reverse permutation, and change
// the operation data format to `target_data_format`.
//
// Example:
// %1 = SomeOp(...) {data_format = NHWC}
// %2 = Transpose(%1) {permutation = NHWC->NCHW}
// %3 = Op(%2) {data_format = NCHW}
//
// To bypass %2 we have to shuffle layout dependent attributes and types from
// NCHW to NHWC, which is the reverse of the operand permutation (the function
// argument).
auto reverse_permutation =
GetDataFormatPermutation(op->data_format(), target_data_format);
if (reverse_permutation.empty()) return failure();
op->setAttr("data_format", StringAttr::get(target_data_format, context));
for (auto pair : shuffle_attrs) {
StringRef attr_name = pair.first;
ArrayAttr attr_value = pair.second;
op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
}
auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
for (unsigned idx : fold.GetLayoutDependentResults()) {
OpResult result = op->getOperation()->getResult(idx);
result.setType(
ShuffleRankedTensorType(result.getType(), reverse_permutation));
}
return success();
}
//===----------------------------------------------------------------------===//
// Rewrite Pattern for removing trivial Arithmetic op.
//===----------------------------------------------------------------------===//
namespace {
// Fold Arithmetic Op if one of the operands is a constant known to be an
// Identity (e.g. X+0, X*1, etc...). For commutative operations fold if
// known identity value is either lhs or rhs.
template <
typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
ArrayRef<Attribute> operands) {
auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
auto result_type =
arithmetic_op.getResult().getType().template cast<ShapedType>();
// We can fold an arithmetic operation only if we can prove that we will not
// accidentally hide a broadcasting error.
auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty,
ShapedType result_ty) -> bool {
// A scalar identity is broadcastable to any operand shape; we only need to
// check that the operand has the same shape as the result.
bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0;
if (scalar_identity) return operand_ty == result_ty;
// If identity is not a scalar, we must verify that all shapes are equal
// and statically known.
//
// TODO(ezhulenev): Fold if identity shape is statically known to be
// broadcastable to the operand shape.
return operand_ty == result_ty && identity_ty == result_ty &&
result_ty.hasStaticShape();
};
// Check that we have a constant operand on one side (candidate for identity).
const bool is_commutative =
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (!rhs_attr && !(is_commutative && lhs_attr)) return {};
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
// value zero.
const int identity =
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
std::is_same<OpT, RealDivOp>::value)
? 1
: 0;
Type element_ty = lhs_type.getElementType();
Attribute identity_attr;
if (auto ty = element_ty.template dyn_cast<FloatType>()) {
identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
} else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
} else {
return {};
}
// Fold: Op(Operand, Identity) -> Operand.
if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) {
if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr)
return arithmetic_op.x();
}
// Fold: Op(Identity, Operand) -> Operand for commutative operations.
if (lhs_attr && is_commutative &&
is_valid_broadcasting(rhs_type, lhs_type, result_type)) {
if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr)
return arithmetic_op.y();
}
return {};
}
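// Worked example: tf.AddV2(%x, %zeros) folds to %x when %zeros is a splat
// constant of 0 and the fold provably cannot hide a broadcasting error (for
// a scalar zero, it suffices that the operand and result types match).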
} // namespace
// Verifies a reduction op's `input` and reduction `dims`.
static LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
Location loc) {
auto dims_type = dims.getType().dyn_cast<RankedTensorType>();
if (!dims_type) return success();
if (dims_type.getRank() > 1)
return emitError(loc, "dimensions can only be 0D or 1D tensor");
auto input_type = input.getType().dyn_cast<RankedTensorType>();
if (!input_type) return success();
int64_t rank = input_type.getRank();
DenseIntElementsAttr dims_attr;
if (!matchPattern(dims, m_Constant(&dims_attr))) return success();
for (const auto &dim_pair : llvm::enumerate(dims_attr)) {
int64_t cur_dim = dim_pair.value().getSExtValue();
if (cur_dim < -rank || cur_dim >= rank)
return emitError(loc)
<< dim_pair.index() << "-th dimension should be in the range of [-"
<< rank << ", " << rank << ")";
}
return success();
}
LogicalResult VerifyRegionResults(Operation *op, Region &region,
StringRef region_name) {
auto op_name = op->getName().getStringRef();
// Verify that op outputs match yield inputs.
YieldOp yield = cast<YieldOp>(region.front().getTerminator());
unsigned expected_num_results = op->getNumResults();
if (yield.getNumOperands() != expected_num_results)
return op->emitOpError()
<< region_name + " should have same number (" << expected_num_results
<< ") of results as " << op_name << " but has "
<< yield.getNumOperands() << " results";
for (int idx : llvm::seq<int>(0, expected_num_results)) {
auto op_result_type = op->getResult(idx).getType().cast<TensorType>();
auto region_result_type =
yield.getOperand(idx).getType().cast<TensorType>();
if (!AreCastCompatible({region_result_type, op_result_type}))
return op->emitError(llvm::formatv(
"{0} result type {1} is incompatible with {2} "
"result type {3} at index {4}",
region_name, region_result_type, op_name, op_result_type, idx));
}
return success();
}

File diff suppressed because it is too large.


@@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
namespace mlir {
namespace TF {
#define GET_OP_FWD_DEFINES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h.inc"
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_


@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
namespace mlir {
namespace TF {
namespace {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc"
} // namespace TF
} // namespace mlir
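Because tf_ops_helpers.inc is pulled in inside an anonymous namespace, each shard's .cc file compiles its own private copy of these helpers; per the TODO in that file, this is a stopgap until the helpers are refactored into tf_verifiers or tf_ops proper.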


@@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
namespace mlir {
namespace TF {
#define GET_OP_FWD_DEFINES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h.inc"
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_