diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 14d7faecdca..c5e2b089c0d 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -88,6 +88,7 @@ gentbl( cc_library( name = "tensorflow_op_interfaces", srcs = [ + "ir/tf_op_interfaces.cc", "ir/tf_op_interfaces.cc.inc", "ir/tf_op_interfaces.h.inc", "ir/tf_verifiers.cc", @@ -105,15 +106,67 @@ cc_library( ) gentbl( - name = "tensorflow_ops_inc_gen", + name = "tensorflow_all_ops_inc_gen", tbl_outs = [ ( "-gen-op-decls", - "ir/tf_ops.h.inc", + "ir/tf_all_ops.h.inc", ), ( "-gen-op-defs", - "ir/tf_ops.cc.inc", + "ir/tf_all_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/tf_ops.td", + td_srcs = [ + ":tensorflow_ops_td_files", + ], +) + +# We only shard tf_op on name for build performance reasons. +tf_ops_category_list = [ + { + "name": "ops_a_m", + "include": "tf.[A-M].*$$", + }, + { + "name": "ops_n_z", + "include": "tf.[N-Z].*$$", + }, +] + +[[ + gentbl( + name = "tensorflow_" + target["name"] + "_inc_gen", + tbl_outs = [ + ( + "-gen-op-decls -op-include-regex='" + target["include"] + "'", + "ir/tf_" + target["name"] + ".h.inc", + ), + ( + "-gen-op-defs -op-include-regex='" + target["include"] + "'", + "ir/tf_" + target["name"] + ".cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/tf_ops.td", + td_srcs = [ + ":tensorflow_ops_td_files", + ], + ), +] for target in tf_ops_category_list] + +gentbl( + name = "tensorflow_remaining_ops_inc_gen", + tbl_outs = [ + ( + "-gen-op-decls -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ", + "ir/tf_remaining_ops.h.inc", + ), + ( + "-gen-op-defs -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ", + "ir/tf_remaining_ops.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", @@ -179,7 +232,7 @@ gentbl( name = "tensorflow_device_ops_inc_gen", tbl_outs = [ ( - "-gen-op-decls", + "-gen-op-decls ", "ir/tf_device.h.inc", ), ( @@ -284,24 +337,67 @@ cc_library( ], ) +[[ + cc_library( + name = "tensorflow_" + target["name"], + srcs = [ + "ir/tf_ops.h", + "ir/tf_remaining_ops.h", + "ir/tf_" + target["name"] + ".cc", + "ir/tf_" + target["name"] + ".cc.inc", + ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], + hdrs = [ + ], + textual_hdrs = [ + "ir/tf_all_ops.h.inc", + "ir/tf_ops_helpers.inc", + "ir/tf_remaining_ops.h.inc", + ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], + deps = [ + ":tensorflow_attributes", + ":tensorflow_canonicalize_inc_gen", + ":tensorflow_op_interfaces", + ":tensorflow_op_interfaces_inc_gen", + ":tensorflow_side_effects", + ":tensorflow_structs", + ":tensorflow_traits", + ":tensorflow_types", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ] + [":tensorflow_" + target["name"] + "_inc_gen"], + ), +] for target in tf_ops_category_list] + cc_library( - name = "tensorflow_ops", + name = "tensorflow_remaining_ops", srcs = [ - "ir/tf_ops.cc", - "ir/tf_ops.cc.inc", "ir/tf_ops.h", - ], + "ir/tf_remaining_ops.h", 
+ "ir/tf_remaining_ops.cc", + ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], hdrs = [ ], textual_hdrs = [ - "ir/tf_ops.h.inc", - ], + "ir/tf_all_ops.h.inc", + "ir/tf_ops_helpers.inc", + "ir/tf_remaining_ops.h.inc", + ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", ":tensorflow_op_interfaces", ":tensorflow_op_interfaces_inc_gen", - ":tensorflow_ops_inc_gen", + ":tensorflow_remaining_ops_inc_gen", ":tensorflow_side_effects", ":tensorflow_structs", ":tensorflow_traits", @@ -321,6 +417,43 @@ cc_library( ], ) +cc_library( + name = "tensorflow_ops", + srcs = [ + "ir/tf_ops.cc", + "ir/tf_ops.h", + ], + textual_hdrs = [ + "ir/tf_all_ops.h.inc", + "ir/tf_remaining_ops.h", + ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], + deps = [ + ":tensorflow_all_ops_inc_gen", + ":tensorflow_remaining_ops_inc_gen", + ":tensorflow_attributes", + ":tensorflow_canonicalize_inc_gen", + ":tensorflow_op_interfaces", + ":tensorflow_op_interfaces_inc_gen", + ":tensorflow_side_effects", + ":tensorflow_structs", + ":tensorflow_traits", + ":tensorflow_types", + ":tensorflow_remaining_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ] + [":tensorflow_" + target["name"] for target in tf_ops_category_list], +) + cc_library( name = "tensorflow_structs", srcs = [ @@ -393,12 +526,14 @@ cc_library( includes = ["include"], deps = [ ":error_util", + ":tensorflow_all_ops_inc_gen", ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", ":tensorflow_device_ops_inc_gen", ":tensorflow_executor_inc_gen", ":tensorflow_op_interfaces", ":tensorflow_ops", + ":tensorflow_side_effects", ":tensorflow_structs", ":tensorflow_traits", ":tensorflow_types", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 17424b54fc2..aaaf9c2fc5c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -332,7 +332,7 @@ class TF_DerivedOperandTypeListAttr : DerivedAttr< // This returns a list of shapes so it is used for variadic operands that // can have different shapes. class TF_DerivedOperandShapeListAttr : DerivedAttr< - "mlir::TF::OperandShapeRange", + "::mlir::TF::OperandShapeRange", "auto values = getODSOperands(" # idx # ");\n" "return {mlir::TF::OperandShapeIterator(values.begin()), " "mlir::TF::OperandShapeIterator(values.end())};", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc new file mode 100644 index 00000000000..3e99f1e162b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" + +namespace mlir::TF { +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc" +} // namespace mlir::TF diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 98bc6b3089a..61935153c18 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -69,4483 +69,6 @@ limitations under the License. namespace mlir { namespace TF { -// Propagates underscore and device attributes from src to dst. -// TODO(b/158769932): This should be a general feature instead post some policy -// discussion. -static void PropagateAttributes(Operation *src, Operation *dst) { - auto device = mlir::Identifier::get("device", src->getContext()); - for (auto named_attr : src->getAttrs()) { - if (*named_attr.first.begin() == '_' || named_attr.first == device) - dst->setAttr(named_attr.first, named_attr.second); - } -} - -//===----------------------------------------------------------------------===// -// TF op helper functions -//===----------------------------------------------------------------------===// - -// Returns the RankedTensorType for the given operand. TensorFlow constant ops -// may have non-static shape because the shape is not propagated during constant -// folding. If the defining op for the given operand is a constant op, this -// routine uses the constant op's attribute to get the actual shape. -static RankedTensorType GetRankedTensorTypeForOperand(Value operand) { - DenseElementsAttr attr; - if (matchPattern(operand, m_Constant(&attr))) { - return attr.getType().dyn_cast(); - } - return operand.getType().dyn_cast(); -} - -// Returns true if the given `value` is of ranked float tensor type with the -// given `rank`. -static inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) { - return type && type.getRank() == rank && - type.getElementType().isa(); -} - -// Returns true if the given `value` has the specified rank or has unranked -// type. -static inline bool IsOfRankOrUnranked(Value value, int64_t rank) { - RankedTensorType type = GetRankedTensorTypeForOperand(value); - return !type || type.getRank() == rank; -} - -// Returns true if the given `value` has at least the specified rank or has -// unranked type. -static inline bool HasRankAtLeast(Value value, int64_t rank) { - RankedTensorType type = GetRankedTensorTypeForOperand(value); - return !type || type.getRank() >= rank; -} - -// Returns true if the given `value` has at most the specified rank or has -// unranked type. -static inline bool HasRankAtMost(Value value, int64_t rank) { - RankedTensorType type = GetRankedTensorTypeForOperand(value); - return !type || type.getRank() <= rank; -} - -static bool IsUnknownDimOrRank(int64_t dim_or_rank) { - return dim_or_rank == -1; -} - -// Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. 
If -// `incompatible_shape_error` is true, reports error if `x` and `y` has -// incompatible shapes. Otherwise, returns a tensor type with unknown rank. -static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, - Value y, BoolAttr incompatible_shape_error) { - auto result_type = - OpTrait::util::getBroadcastedType(x.getType(), y.getType()); - if (!result_type) { - if (incompatible_shape_error.getValue()) { - mlir::emitError(loc, "non-broadcastable operands"); - } else { - return UnrankedTensorType::get(builder->getI1Type()); - } - } - - auto ranked_type = result_type.dyn_cast(); - if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type()); - - return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type()); -} - -// Returns dimension index for the given TensorFlow axis that supports negative -// indexing. -static int64_t GetDimForAxis(int64_t axis, int64_t rank) { - return axis >= 0 ? axis : axis + rank; -} - -// Infers output type for reduction ops such as SumOp, MaxOp etc. -// TODO(b/e667204a): Move this logic to shape inference once it supports custom -// inference functions. -static Type InferReductionOpType(Value input, Value reduction_indices, - BoolAttr keep_dims, Builder *builder) { - Type input_ty = input.getType(); - Type element_ty = getElementTypeOrSelf(input_ty); - - // Output type is unranked if input type is not ranked. - auto ranked_ty = input_ty.dyn_cast(); - if (!ranked_ty) return UnrankedTensorType::get(element_ty); - int64_t rank = ranked_ty.getRank(); - - DenseIntElementsAttr indices; - if (!matchPattern(reduction_indices, m_Constant(&indices))) { - // Output type is unranked if reduction indices are not constant and reduced - // dimensions are not kept. - if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty); - - // Otherwise, output type has same rank as the input. - return RankedTensorType::get(SmallVector(rank, -1), element_ty); - } - - int64_t num_reduce_dim = 0; - llvm::SmallVector is_reduce_dim(rank, false); - for (const APInt &index : indices.getValues()) { - int64_t dim = GetDimForAxis(index.getSExtValue(), rank); - // Invalid input. - if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty); - - if (!is_reduce_dim[dim]) { - is_reduce_dim[dim] = true; - num_reduce_dim++; - } - } - - ArrayRef shape = ranked_ty.getShape(); - SmallVector out_shape; - out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim)); - for (int64_t i = 0; i < rank; ++i) { - if (!is_reduce_dim[i]) - out_shape.push_back(shape[i]); - else if (keep_dims.getValue()) - out_shape.push_back(1); - } - return RankedTensorType::get(out_shape, element_ty); -} - -// Verifies that the given types are cast compatible. If not, emits appropriate -// error for the given op. If mask_one_dim is set to true, then the types are -// allowed to have one mismatching dimension. Masking one of the dimensions is -// useful for ops like Concat that requires all ranked inputs to have the same -// rank and match dimension sizes for all but one of the dimensions. -static LogicalResult VerifyTypesCompatibility( - Operation::operand_type_range types, bool mask_one_dim, Operation *op) { - constexpr int64_t kUninitialized = -1; - int64_t common_rank = kUninitialized; - llvm::SmallVector common_dims; - int64_t dim_to_mask = kUninitialized; - - // Initialize common_rank with rank of the first ranked type and verify that - // following ranked types have the same rank. 
- // Similarly, initialize each of the dimensions with the first type that has - // the dimension size available and verify that all following types have the - // same size for the dimension. However, if mask_one_dim is true, note down - // the dimension index on the first mismatch and ignore dimension at that - // index in following types. - for (Type ty : types) { - RankedTensorType ranked_ty = ty.dyn_cast(); - if (!ranked_ty) continue; - - int64_t rank = ranked_ty.getRank(); - if (common_rank == kUninitialized) { - common_rank = rank; - common_dims.resize(common_rank, kUninitialized); - } else if (common_rank != rank) { - return op->emitError() - << "operand type " << ranked_ty - << " is not compatible with preceding operands; expected rank: " - << common_rank; - } - - for (int64_t i = 0, e = common_rank; i != e; i++) { - if (i == dim_to_mask) continue; - - int64_t dim = ranked_ty.getDimSize(i); - if (dim == kUninitialized) continue; - - int64_t &common_dim = common_dims[i]; - if (common_dim == kUninitialized) { - common_dim = dim; - } else if (common_dim != dim) { - // If mask_one_dim is true, do not emit an error if this is the only - // dimension with mismatches. Note down the dimension to mask it from - // the following types. - if (mask_one_dim && dim_to_mask == kUninitialized) { - dim_to_mask = i; - continue; - } - - return op->emitError() << "operand type " << ranked_ty - << " is not compatible with preceding operands; " - "expected dimension at index " - << i << ": " << common_dim; - } - } - } - return success(); -} - -// This is a helper for the Select to SelectV2 canonicalization. The `data` rank -// refers to the rank of `t`/`e` (these two inputs have equal rank; this is -// checked in the verifier). -// -// In most cases, the predicate for Select can be used directly as the predicate -// for SelectV2. However, there is one case that varies, which is when the -// predicate is a tensor and the data is multidimensional. In this case, Select -// op semantics dictate that the predicate tensor length must match the size of -// the first data dimension. This varies from normal broadcasting semantics -// (which are used in SelectV2), so we must reshape the tensor in this case to -// be compatible. -static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc, - Value cond, int data_rank) { - auto cond_tensor = cond.getType().cast(); - // Reshape is only needed in the case that the cond rank is 1 (i.e. it is - // a vector) AND t/e rank is > 1. - if (cond_tensor.getRank() != 1 || data_rank <= 1) { - // No reshape necessary. Leave cond as it is. - return cond; - } - - // This is the case where a reshape is needed. We want to construct the - // shape [x,1,...1], where x is the value in the pred tensor and the - // length of the shape is equal to data_rank. - SmallVector shape(data_rank, 1); - shape[0] = cond_tensor.getShape().front(); - auto new_shape_type = - RankedTensorType::get({data_rank}, builder->getIntegerType(64)); - auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape); - auto new_shape = builder->create(loc, shape_attr); - return builder->create(loc, cond, new_shape); -} - -//===----------------------------------------------------------------------===// -// Helper functions detect device capabilities from RuntimeDevices. 
-//===----------------------------------------------------------------------===// - -namespace { -using DeviceNameUtils = ::tensorflow::DeviceNameUtils; -using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName; - -bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) { - return device.type == ::tensorflow::DEVICE_GPU; -} - -} // namespace - -// Returns true if at least one GPU device is available at runtime. -bool CanUseGpuDevice(const RuntimeDevices &devices) { - return llvm::any_of(devices.device_names(), IsGpuDevice); -} - -// Returns true if all of the GPUs available at runtime support TensorCores -// (NVIDIA compute capability >= 7.0). -bool CanUseTensorCores(const RuntimeDevices &devices) { - auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) { - auto md = devices.GetGpuDeviceMetadata(device); - return md ? md->cc_major().getInt() >= 7 : false; - }; - return llvm::all_of( - llvm::make_filter_range(devices.device_names(), IsGpuDevice), - has_tensor_cores); -} - -// Returns true if operation does not have explicit device placement that would -// prevent it from running on GPU device. -bool CanUseGpuDevice(Operation *op) { - auto device_attr = op->getAttrOfType("device"); - if (!device_attr || device_attr.getValue().empty()) return true; - - DeviceNameUtils::ParsedName device; - if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device)) - return false; - - // We can't use GPU if operation explicitly placed on non-GPU device. - return !device.has_type || device.type == ::tensorflow::DEVICE_GPU; -} - -//===----------------------------------------------------------------------===// -// TF op helper functions to work with layout transformation. -//===----------------------------------------------------------------------===// - -SmallVector ReversePermutation(ArrayRef permutation) { - SmallVector reverse(permutation.size()); - for (size_t i = 0; i < permutation.size(); ++i) { - reverse[permutation[i]] = i; - } - return reverse; -} - -SmallVector GetDataFormatPermutation(StringRef from, StringRef to) { - if (from == "NHWC" && to == "NCHW") { - return {0, 3, 1, 2}; - } else if (from == "NCHW" && to == "NHWC") { - return {0, 2, 3, 1}; - } else { - return {}; - } -} - -// Shuffle elements in the `attr` according to the permutation. Optional -// `inner_size` allows to shuffle array attributes created from rank 2 tensors -// on outer dimension only. -ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef permutation, - int inner_size = 1) { - if (attr.size() == 0) return attr; - - assert(attr.size() % inner_size == 0); - assert(attr.size() / inner_size == permutation.size()); - - SmallVector values{attr.begin(), attr.end()}; - SmallVector shuffled(values.size()); - - for (size_t i = 0; i < permutation.size(); ++i) { - for (size_t j = 0; j < inner_size; ++j) { - shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j]; - } - } - - return ArrayAttr::get(shuffled, attr.getContext()); -} - -// Shuffle ranked tensor dimensions according to the permutation. 
-Type ShuffleRankedTensorType(Type type, ArrayRef permutation) { - if (auto ranked_type = type.dyn_cast()) { - ArrayRef shape = ranked_type.getShape(); - assert(permutation.size() == shape.size()); - - SmallVector new_shape(permutation.size()); - for (size_t i = 0; i < permutation.size(); ++i) - new_shape[i] = shape[permutation[i]]; - - return RankedTensorType::get(new_shape, ranked_type.getElementType()); - } - - return type; -} - -static bool AreCancellablePermutations(DenseIntElementsAttr perm0, - DenseIntElementsAttr perm1) { - if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false; - if (perm0.getNumElements() != perm1.getNumElements()) return false; - - SmallVector perm0_values; - for (const auto &value : perm0.getIntValues()) - perm0_values.push_back(value.getSExtValue()); - - SmallVector perm1_values; - for (const auto &value : perm1.getIntValues()) - perm1_values.push_back(value.getSExtValue()); - - for (int i = 0; i < perm0_values.size(); ++i) { - if (perm0_values[perm1_values[i]] != i) return false; - } - - return true; -} - -// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for -// layout sensitive operations that do not have any additional layout dependent -// attributes besides `data_format` string. -template -LogicalResult UpdateDataFormat(StringRef data_format, Op *op) { - auto perm = GetDataFormatPermutation(op->data_format(), data_format); - if (perm.empty()) return failure(); - - // Update data format attribute. - op->setAttr("data_format", StringAttr::get(data_format, op->getContext())); - - // Update types for all layout sensitive results. - auto layout_sensitive = cast(op->getOperation()); - for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) { - OpResult result = op->getOperation()->getResult(idx); - result.setType(ShuffleRankedTensorType(result.getType(), perm)); - } - - return success(); -} - -// Default implementation for folding operand transpose into the operation. -// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`. -template -LogicalResult FoldOperandsPermutation( - ArrayRef permutation, Op *op, - ArrayRef> shuffle_attrs = {}) { - MLIRContext *context = op->template getParentOfType().getContext(); - - // We only support NHWC <-> NCHW permutations. - static constexpr std::array kNchwToNhwc = {0, 2, 3, 1}; - static constexpr std::array kNhwcToNchw = {0, 3, 1, 2}; - - // Operation data format after folding `permutation`. - StringRef target_data_format = [&]() -> StringRef { - if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) { - return "NCHW"; // cancel NCHW->NHWC operand permutation - } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) { - return "NHWC"; // cancel NHWC->NCHW operand permutation - } else { - return ""; - } - }(); - if (target_data_format.empty()) return failure(); - - // To fold operand `permutation` into the `op` we need shuffle all layout - // dependent attributes and types with a reverse permutation, and change - // operation data format to `target_data_format`. - // - // Example: - // %1 = SomeOp(...) {data_format = NHWC} - // %2 = Transpose(%1) {permutation = NHWC->NCHW} - // %3 = Op(%2) {data_format = NCHW} - // - // To bypass %2 we have to change data format to shuffle data format from NCHW - // to NHWC, which is the reverse of operand permutation (function argument). 
- auto reverse_permutation = - GetDataFormatPermutation(op->data_format(), target_data_format); - if (reverse_permutation.empty()) return failure(); - - op->setAttr("data_format", StringAttr::get(target_data_format, context)); - - for (auto pair : shuffle_attrs) { - StringRef attr_name = pair.first; - ArrayAttr attr_value = pair.second; - op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation)); - } - - auto fold = cast(op->getOperation()); - for (unsigned idx : fold.GetLayoutDependentResults()) { - OpResult result = op->getOperation()->getResult(idx); - result.setType( - ShuffleRankedTensorType(result.getType(), reverse_permutation)); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// Rewrite Pattern for removing trivial Arithmetic op. -//===----------------------------------------------------------------------===// - -namespace { -// Fold Arithmetic Op if one of the operands is a constant known to be an -// Identity (e.g. X+0, X*1, etc...). For commutative operations fold if -// known identity value is either lhs or rhs. -template < - typename OpT, - typename std::enable_if::value>::type * = nullptr> -OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, - ArrayRef operands) { - auto lhs_type = arithmetic_op.x().getType().template cast(); - auto rhs_type = arithmetic_op.y().getType().template cast(); - auto result_type = - arithmetic_op.getResult().getType().template cast(); - - // We can fold arithmetic operation only of we can prove that we will not - // accidentally hide a broadcasting error. - auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty, - ShapedType result_ty) -> bool { - // Scalar identity is broadcastable to any operand shape, we only need to - // check that operand has the same shape as a result. - bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0; - if (scalar_identity) return operand_ty == result_ty; - - // If identity is not a scalar, we must verify that all shapes are equal - // and statically known. - // - // TODO(ezhulenev): Fold if identity shape is statically know to be - // broadcastable to the operand shape. - return operand_ty == result_ty && identity_ty == result_ty && - result_ty.hasStaticShape(); - }; - - // Check that we have a constant operand on one side (candidate for identity). - const bool is_commutative = - (std::is_same::value || std::is_same::value); - auto lhs_attr = operands[0].dyn_cast_or_null(); - auto rhs_attr = operands[1].dyn_cast_or_null(); - if (!rhs_attr && !(is_commutative && lhs_attr)) return {}; - - // Mul and Div ops have identity value one while AddV2 and SubOp have identity - // value zero. - const int identity = - (std::is_same::value || std::is_same::value || - std::is_same::value) - ? 1 - : 0; - - Type element_ty = lhs_type.getElementType(); - Attribute identity_attr; - if (auto ty = element_ty.template dyn_cast()) { - identity_attr = FloatAttr::get(ty, static_cast(identity)); - } else if (auto ty = element_ty.template dyn_cast()) { - identity_attr = IntegerAttr::get(ty, static_cast(identity)); - } else { - return {}; - } - - // Fold: Op(Operand, Identity) -> Operand. - if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) { - if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr) - return arithmetic_op.x(); - } - - // Fold: Op(Identity, Operand) -> Operand for commutative operations. 
- if (lhs_attr && is_commutative && - is_valid_broadcasting(rhs_type, lhs_type, result_type)) { - if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr) - return arithmetic_op.y(); - } - - return {}; -} -} // namespace - -namespace { -#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" -} // namespace - -//===----------------------------------------------------------------------===// -// AddOp -//===----------------------------------------------------------------------===// - -void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// AddNOp -//===----------------------------------------------------------------------===// - -OpFoldResult AddNOp::fold(ArrayRef operands) { - if (operands.size() == 1) return *inputs().begin(); - return {}; -} - -//===----------------------------------------------------------------------===// -// AddV2Op -//===----------------------------------------------------------------------===// - -void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult AddV2Op::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// AllOp -//===----------------------------------------------------------------------===// - -// Verifies an reduction op's `input` and reduction `dims`. -static LogicalResult VerifyReductionInputAndDims(Value input, Value dims, - Location loc) { - auto dims_type = dims.getType().dyn_cast(); - if (!dims_type) return success(); - if (dims_type.getRank() > 1) - return emitError(loc, "dimensions can only be 0D or 1D tensor"); - - auto input_type = input.getType().dyn_cast(); - if (!input_type) return success(); - int64_t rank = input_type.getRank(); - - DenseIntElementsAttr dims_attr; - if (!matchPattern(dims, m_Constant(&dims_attr))) return success(); - for (const auto &dim_pair : llvm::enumerate(dims_attr)) { - int64_t cur_dim = dim_pair.value().getSExtValue(); - if (cur_dim < -rank || cur_dim >= rank) - return emitError(loc) - << dim_pair.index() << "-th dimension should be in the range of [-" - << rank << ", " << rank << ")"; - } - - return success(); -} - -static LogicalResult Verify(AllOp op) { - return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), - op.getLoc()); -} - -//===----------------------------------------------------------------------===// -// AnyOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(AnyOp op) { - return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), - op.getLoc()); -} - -//===----------------------------------------------------------------------===// -// AssertOp -//===----------------------------------------------------------------------===// - -namespace { - -// Removes Assert with constant true predicate. 
-struct AssertWithTrue : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(AssertOp op, - PatternRewriter &rewriter) const override { - ElementsAttr cst; - if (matchPattern(op.condition(), m_Constant(&cst))) { - if (cst.getValue({}).getValue()) { - rewriter.eraseOp(op); - return success(); - } - } - return failure(); - } -}; -} // namespace - -void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BatchMatMulOp -//===----------------------------------------------------------------------===// - -void BatchMatMulOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BatchMatMulV2Op -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(BatchMatMulV2Op op) { - if (!HasRankAtLeast(op.x(), 2)) { - return op.emitOpError("requires lhs operand to have rank at least two"); - } - if (!HasRankAtLeast(op.y(), 2)) { - return op.emitOpError("requires rhs operand to have rank at least two"); - } - return success(); -} - -void BatchMatMulV2Op::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BatchToSpaceOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(BatchToSpaceOp op) { - // Op already has a constraint that block_size >= 2. 
- int64_t block_size = op.block_size().getSExtValue(); - - llvm::SmallVector input_shape(4, ShapedType::kDynamicSize); - auto input_type = op.input().getType().cast(); - if (input_type.hasRank()) { - if (input_type.getRank() != 4) - return op.emitOpError() - << "requires input to be a 4D tensor, but got " << input_type; - - int64_t input_batch = input_type.getDimSize(0); - if (input_batch != ShapedType::kDynamicSize && - input_batch % (block_size * block_size) != 0) { - return op.emitOpError() - << "requires input batch (dimension 0) to be evenly divisible " - "by (block_size * block_size), but got input batch " - << input_batch << " and block_size " << block_size; - } - - input_shape.assign(input_type.getShape().begin(), - input_type.getShape().end()); - } - - auto crops_type = op.crops().getType().cast(); - if (crops_type.hasRank()) { - if (crops_type.getRank() != 2) - return op.emitOpError() - << "requires crops to be a 2D tensor, but got " << crops_type; - - auto dim_of_size = [&](int64_t dim, int64_t size) { - if (crops_type.isDynamicDim(dim)) return true; - return crops_type.getDimSize(dim) == size; - }; - if (!dim_of_size(0, 2) || !dim_of_size(1, 2)) - return op.emitOpError() - << "requires crops to be a tensor<2x2>, but got " << crops_type; - } - - DenseIntElementsAttr crops_attr; - // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]], - // and flattened as [crop_top, crop_bottom, crop_left, crop_right] - llvm::SmallVector crops_values; - if (matchPattern(op.crops(), m_Constant(&crops_attr))) { - assert(crops_attr.getNumElements() == 4 && - "tf.BatchToSpace crops must have 4 elements"); - - auto crops_range = crops_attr.getIntValues(); - for (const auto &crops_value : crops_range) { - int64_t crops_value_int = crops_value.getSExtValue(); - if (crops_value_int < 0) - return op.emitOpError() - << "requires all crop values to be nonnegative, but got " - << crops_attr; - - crops_values.push_back(crops_value_int); - } - } - - auto output_type = op.output().getType().cast(); - if (output_type.hasRank()) { - if (output_type.getRank() != 4) - return op.emitOpError() - << "requires output to be a 4D tensor, but got " << output_type; - - auto static_dims = [](int64_t dim_a, int64_t dim_b) { - return dim_a != ShapedType::kDynamicSize && - dim_b != ShapedType::kDynamicSize; - }; - - auto output_shape = output_type.getShape(); - - // output batch = input batch / (block_size * block_size). - int64_t input_batch = input_shape[0]; - int64_t output_batch = output_shape[0]; - if (static_dims(input_batch, output_batch) && - (output_batch * block_size * block_size) != input_batch) - return op.emitOpError() - << "requires output batch (dimension 0) to be equal to input " - "batch (dimension 0) / (block_size * block_size), but got " - "output batch " - << output_batch << ", input batch " << input_batch - << ", and block_size " << block_size; - - auto check_spatial_dim = [&](int64_t spatial_dim_index, - llvm::StringRef dim_name, - llvm::StringRef crop_a_name, - llvm::StringRef crop_b_name) -> LogicalResult { - int64_t input_dim = input_shape[spatial_dim_index]; - int64_t output_dim = output_shape[spatial_dim_index]; - if (!static_dims(input_dim, output_dim)) return success(); - - int64_t input_dim_pad = input_dim * block_size; - // If crops are unknown, the maximum output spatial dim size is input - // spatial dim size * block_size, as crops can be minimum 0. 
- if (crops_values.empty() && output_dim > input_dim * block_size) - return op.emitOpError() - << "requires output " << dim_name << " (dimension " - << spatial_dim_index << ") to be less than or equal to input " - << dim_name << " (dimension " << spatial_dim_index - << ") * block_size, but got output " << dim_name << " " - << output_dim << ", input " << dim_name << " " << input_dim - << ", and block_size " << block_size; - - if (!crops_values.empty()) { - // output spatial dim = input spatial dim * block_size - crops. - int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)]; - int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1]; - if (output_dim != input_dim_pad - crop_a - crop_b) - return op.emitOpError() - << "requires output " << dim_name << " (dimension " - << spatial_dim_index << ") to be equal to input " << dim_name - << " (dimension " << spatial_dim_index << ") * block_size - " - << crop_a_name << " - " << crop_b_name << ", but got output " - << dim_name << " " << output_dim << ", input " << dim_name - << " " << input_dim << ", " << crop_a_name << " " << crop_a - << ", " << crop_b_name << " " << crop_b << ", and block_size " - << block_size; - } - - return success(); - }; - - if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) || - failed(check_spatial_dim(2, "width", "crop_left", "crop_right"))) - return failure(); - - int64_t input_depth = input_shape[3]; - int64_t output_depth = output_shape[3]; - if (static_dims(input_depth, output_depth) && output_depth != input_depth) - return op.emitOpError() - << "requires output depth (dimension 3) to be equal to input " - "depth (dimension 3), but got output depth " - << output_depth << " and input depth " << input_depth; - } - - return success(); -} - -void BatchToSpaceOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BiasAddOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// * the value and bias operands have valid ranks or are unranked. -// * Channel dimension of the value operand and length of bias matches if they -// are not unknown. -// -static LogicalResult Verify(BiasAddOp op) { - StringRef format = op.data_format(); - if (format == "NHWC") { - if (!HasRankAtLeast(op.value(), 2)) - return op.emitOpError( - "requires value operand to have rank at least two with `NHWC` data " - "format"); - } else { - // Op definition requires data_format to be either NHWC or NCHW. - DCHECK_EQ(format.str(), "NCHW"); - if (!HasRankAtLeast(op.value(), 3)) - return op.emitOpError( - "requires value operand to have rank at least three with `NCHW` data " - "format"); - } - - if (!IsOfRankOrUnranked(op.bias(), 1)) - return op.emitOpError("requires bias operand to have rank exactly one"); - - RankedTensorType value_ty = op.value().getType().dyn_cast(); - RankedTensorType bias_ty = op.bias().getType().dyn_cast(); - if (!bias_ty || !value_ty) return success(); - - // TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute - // dimension indices based on format. - int64_t feature_dim_idx = format == "NHWC" ? 
value_ty.getRank() - 1 : 1; - int64_t feature_dim = value_ty.getDimSize(feature_dim_idx); - int64_t bias_len = bias_ty.getDimSize(0); - if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) { - return op.emitOpError() - << "requires channel dimension and feature dimension to match; " - "found " - << feature_dim << " and " << bias_len << ", respectively"; - } - return success(); -} - -//===----------------------------------------------------------------------===// -// BiasAddGradOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// * the out_backprop operands have valid ranks or are unranked. -// -static LogicalResult Verify(BiasAddGradOp op) { - StringRef format = op.data_format(); - if (format == "NHWC") { - if (!HasRankAtLeast(op.out_backprop(), 2)) - return op.emitOpError( - "requires out_backprop operand to have rank at least two with `NHWC` " - "data format"); - } else { - // Op definition requires data_format to be either NHWC or NCHW. - DCHECK_EQ(format.str(), "NCHW"); - if (!HasRankAtLeast(op.out_backprop(), 3)) - return op.emitOpError( - "requires out_backprop operand to have rank at least three with " - "`NCHW` data format"); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// BiasAddV1Op -//===----------------------------------------------------------------------===// - -void BiasAddV1Op::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BitcastOp -//===----------------------------------------------------------------------===// - -void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BroadcastToOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(BroadcastToOp op) { - // TODO(antiagainst): check that - // * The 'shape' input is an 1-D int tensor. - // * Each dimension pair of the source and target shapes are either equal - // or one of them is one. - return success(); -} - -//===----------------------------------------------------------------------===// -// CaseOp -//===----------------------------------------------------------------------===// - -class FoldConstantCaseOp : public OpRewritePattern { - public: - explicit FoldConstantCaseOp(MLIRContext *context) - : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(TF::CaseOp op, - PatternRewriter &rewriter) const override; -}; - -LogicalResult FoldConstantCaseOp::matchAndRewrite( - TF::CaseOp op, PatternRewriter &rewriter) const { - // Extract the constant cond value. - DenseIntElementsAttr branch; - if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure(); - - // Only attempt to fold scalar valued case statements. - // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. - if (!branch.getType().cast().getShape().empty()) - return failure(); - - int index = *branch.getValues().begin(); - // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. 
- if (index >= op.branches().size()) return failure(); - - auto func = op.branches()[index].cast(); - auto empty = rewriter.getStringAttr(""); - auto call_op = rewriter.create( - op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func, - /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); - PropagateAttributes(op.getOperation(), call_op); - rewriter.replaceOp(op, call_op.getResults()); - return success(); -} - -void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// CastOp -//===----------------------------------------------------------------------===// - -OpFoldResult CastOp::fold(ArrayRef operands) { - // Cast with the same type is a no-op. - Value operand = getOperand(); - if (getType() == operand.getType()) return operand; - return {}; -} - -//===----------------------------------------------------------------------===// -// ConcatOp and ConcatV2Op -//===----------------------------------------------------------------------===// - -template ::value>::type * = nullptr> -static LogicalResult Verify(OpT op) { - // TODO(hinsu): Convert variadic length attributes to derived attributes. - Operation::operand_range values = op.values(); - - int axis_idx = std::is_same() ? 0 : 1; - Value axis = *op.getODSOperands(axis_idx).begin(); - if (!HasRankAtMost(axis, 1)) { - return op.emitOpError( - "requires axis to be of scalar type (or vector type for older " - "versions)"); - } - - return VerifyTypesCompatibility(values, - /*mask_one_dim=*/true, op.getOperation()); -} - -void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// ConcatOffsetOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(ConcatOffsetOp op) { - if (op.N() < 2) - return op.emitOpError() << "requires N to be at least 2, got " << op.N(); - - if (op.shape().size() != op.offset().size()) - return op.emitOpError() - << "requires sizes of shapes and offsets to be the same, got sizes " - << op.shape().size() << " and " << op.offset().size(); - - auto ranked_dim = op.concat_dim().getType().dyn_cast(); - if (ranked_dim && ranked_dim.getRank() != 0) - return op.emitOpError() - << "requires concat_dim to be a scalar, got tensor of rank " - << ranked_dim.getRank(); - - int64_t num_dims = -1; - for (auto shape_offset_idx : - llvm::enumerate(llvm::zip(op.shape(), op.offset()))) { - Value shape = std::get<0>(shape_offset_idx.value()); - Value offset = std::get<1>(shape_offset_idx.value()); - const size_t idx = shape_offset_idx.index(); - - if (failed(verifyCompatibleShape(shape.getType(), offset.getType()))) - return op.emitOpError() << "requires operand and result " << idx - << " to have compatible shapes"; - - auto ranked_shape = shape.getType().dyn_cast(); - if (!ranked_shape) continue; - - if (ranked_shape.getRank() != 1) - return op.emitOpError() << "requires shape tensor operand " << idx - << " to be of rank 1, got tensor of rank " - << ranked_shape.getRank(); - - if (!ranked_shape.hasStaticShape()) continue; - - int64_t ranked_shape_dim = ranked_shape.getDimSize(0); - if (num_dims == -1) - num_dims = ranked_shape_dim; - else if (ranked_shape_dim != num_dims) - return op.emitOpError() - << "requires shape tensor (rank 1) 
operand " << idx - << " to be of length " << num_dims - << ", got tensor (rank 1) of length " << ranked_shape_dim; - } - - return success(); -} - -LogicalResult ConcatOffsetOp::fold(ArrayRef operands, - SmallVectorImpl &results) { - // ConcatOffset must have its first operand be concat_dim and at least two - // shape tensors in variadic shapes operand. - if (operands.size() < 3) return failure(); - - // Check concat_dim is a scalar. - auto concat_dim_attr = operands[0].dyn_cast_or_null(); - if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0) - return failure(); - - llvm::SmallVector shapes; - shapes.reserve(operands.size() - 1); - for (Attribute shape : llvm::drop_begin(operands, 1)) - if (auto shape_attr = shape.dyn_cast_or_null()) - shapes.push_back(shape_attr); - else - return failure(); - - // Check all shapes are vectors of the same length. - if (shapes.front().getType().getRank() != 1) return success(); - const int64_t num_dims = shapes.front().getNumElements(); - for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) - if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims) - return failure(); - - // Check concat_dim is within [-num_dims, num_dims). - int32_t concat_dim = (*concat_dim_attr.getValues().begin()); - if (concat_dim < 0) concat_dim += num_dims; - if (concat_dim >= num_dims || concat_dim < 0) return failure(); - - // Check all elements besides at concat_dim match across all shape tensors. - SmallVector shape0; - shape0.reserve(num_dims); - for (int32_t dim : shapes.front().getValues()) shape0.push_back(dim); - - for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) { - for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) { - if (dims_and_idx.index() == concat_dim) continue; - - if (std::get<0>(dims_and_idx.value()) != - std::get<1>(dims_and_idx.value()).getSExtValue()) - return failure(); - } - } - - // Compute an exclusive cumulative sum of elements at concat_dim. - results.reserve(shapes.size()); - SmallVector cumulative_sum(num_dims, 0); - RankedTensorType offset_type = - RankedTensorType::get({num_dims}, IntegerType::get(32, getContext())); - for (DenseIntElementsAttr shape : shapes) { - results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum)); - cumulative_sum[concat_dim] += shape.getValue(concat_dim); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// ConjOp -//===----------------------------------------------------------------------===// - -void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// ConstOp -//===----------------------------------------------------------------------===// - -OpFoldResult ConstOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); - - // Return the held attribute value. - return value(); -} - -// Builds a constant op with the specified attribute `value`. The result -// op's type is deduced from `value`; if `value` is of scalar type, -// wraps it up with a tensor type of empty shape. -// TODO(jpienaar): This one differs from the autogenerated one as it takes an -// attribute but always creates an ElementsAttr internally. 
-void ConstOp::build(OpBuilder &builder, OperationState &result, - Attribute value) { - ShapedType type; - if (auto elem_attr = value.dyn_cast()) { - return ConstOp::build(builder, result, elem_attr); - } else if (value.isa()) { - // All TensorFlow types must be tensor types. In the build() method, - // we want to provide more flexibility by allowing attributes of scalar - // types. But we need to wrap it up with ElementsAttr to construct - // valid TensorFlow constants. - type = RankedTensorType::get(/*shape=*/{}, value.getType()); - return ConstOp::build(builder, result, DenseElementsAttr::get(type, value)); - } - // TODO(jpienaar): support other TensorFlow specific types. - llvm_unreachable("unsupported attribute type for building tf.Const"); -} - -void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, - Attribute value) { - // Handle the case where the type and value are already tensors. - if (type.isa() && value.isa()) { - result.addTypes(type); - result.addAttribute("value", value); - return; - } - - // Otherwise, default to the attribute builder. - ConstOp::build(builder, result, value); - assert(type == result.types[0] && "type mismatch in construction"); -} - -LogicalResult ConstOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - auto value = attributes.get("value"); - if (!value) return emitOptionalError(location, "missing attribute 'value'"); - if (auto elem_attr = value.dyn_cast()) { - inferredReturnTypes.assign({elem_attr.getType()}); - return success(); - } - return emitOptionalError(location, - "attribute 'value' failed to satisfy constraint: " - "constant vector/tensor"); -} - -//===----------------------------------------------------------------------===// -// Conv2DOp and Conv3DOp -//===----------------------------------------------------------------------===// - -template -static LogicalResult VerifyConvOpAttributes(OpT op, int num_dims) { - if (!IsOfRankOrUnranked(op.getResult(), num_dims)) - return op.emitOpError() - << "requires result to be " << num_dims << "D tensor"; - - auto is_not_positive = [](Attribute val) { - return val.cast().getValue().getSExtValue() <= 0; - }; - - int64_t strides_size = op.strides().size(); - if (strides_size != num_dims) - return op.emitOpError() << "requires strides attribute length to be " - << num_dims << "; actual length " << strides_size; - if (llvm::any_of(op.strides().getValue(), is_not_positive)) - return op.emitOpError("requires positive strides"); - - int64_t dilations_size = op.strides().size(); - if (op.dilations().size() != num_dims) - return op.emitOpError() << "requires dilations attribute length to be " - << num_dims << "; actual length " << dilations_size; - if (llvm::any_of(op.dilations().getValue(), is_not_positive)) - return op.emitOpError("requires positive dilations"); - - return success(); -} - -// Verifies that, -// * Ranks of operands and result are valid -// * Number of input channels is divisible by the number of filter input -// channels -// * Length of explicit_paddings attribute is valid and has non negative -// elements -// * strides and dilations attributes have positive elements -template ::value>::type * = nullptr> -static LogicalResult Verify(OpT op) { - int num_spatial_dims = std::is_same() ? 
2 : 3; - int num_dims = 2 + num_spatial_dims; - - if (!IsOfRankOrUnranked(op.input(), num_dims) || - !IsOfRankOrUnranked(op.filter(), num_dims)) - return op.emitOpError() - << "requires operands to be " << num_dims << "D tensor"; - - // EXPLICIT padding mode and the associated attribute is limited to Conv2D. - // So, fetch attribute by string instead of the op.explicit_paddings() - // attribute getter. - if (op.padding() == "EXPLICIT") { - auto paddings = op.template getAttrOfType("explicit_paddings"); - if (!paddings) - return op.emitOpError() << "requires attribute 'explicit_paddings' with " - "'EXPLICIT' padding mode"; - - int64_t paddings_size = paddings.size(); - int64_t expected_size = 2 * num_dims; - - if (paddings_size != expected_size) - return op.emitOpError() - << "requires explicit_paddings attribute length to be " - << expected_size << "; actual length " << paddings_size; - - auto is_negative = [](Attribute val) { - return val.cast().getValue().getSExtValue() < 0; - }; - if (llvm::any_of(paddings.getValue(), is_negative)) - return op.emitOpError("requires non negative explicit paddings"); - } - - LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims); - if (failed(verify_result)) { - return verify_result; - } - - int64_t input_channels = -1; - if (auto ty = op.input().getType().template dyn_cast()) { - std::string data_format = op.data_format().str(); - tensorflow::TensorFormat format; - auto is_valid = FormatFromString(data_format, &format); - DCHECK(is_valid) << data_format; - int idx = tensorflow::GetTensorFeatureDimIndex(num_dims, format); - input_channels = ty.getDimSize(idx); - } - - int64_t filter_channels = -1; - if (auto ty = op.filter().getType().template dyn_cast()) { - int idx = tensorflow::GetFilterTensorInputChannelsDimIndex( - num_dims, tensorflow::FORMAT_HWIO); - filter_channels = ty.getDimSize(idx); - } - - if (input_channels != -1 && filter_channels != -1 && - input_channels % filter_channels != 0) - return op.emitOpError() - << "requires the number of input channels to be divisible by the " - "number of filter input channels; found " - << input_channels << " and " << filter_channels << ", respectively"; - - return success(); -} - -LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) { - auto perm = GetDataFormatPermutation(this->data_format(), data_format); - if (perm.empty()) return failure(); - - // Update data_format attribute and result types. - if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); - - // Update convolution attributes. - setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - setAttr("strides", ShuffleArrayAttr(strides(), perm)); - setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); - - return success(); -} - -StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); - - // Input must be a tensor. - auto input_ty = input().getType().dyn_cast(); - if (!input_ty) return data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. 
- const bool is_f16 = input_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // For f32/f16 data type decision depends on the filter size in spatial - // dimensions, for other data types we keep current data format. - if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16()) - return data_format(); - - // Keep current data format if filter rank is unknown or not equal to 4. - auto filter_ty = filter().getType().dyn_cast(); - if (!filter_ty || filter_ty.getRank() != 4) return data_format(); - - const int64_t d0 = filter_ty.getDimSize(0); - const int64_t d1 = filter_ty.getDimSize(1); - - auto all_ones = [](ArrayAttr arr) -> bool { - return llvm::all_of(arr, [](Attribute attr) -> bool { - return attr.cast().getInt() == 1; - }); - }; - - // Convolutions with 1x1 filter and with strides and dilations all ones, can - // be computed as a GEMM in NHWC data format, and can be up to ~2x times - // faster than convolution in NCHW. - const bool one_by_one = d0 == 1 && d1 == 1; - const bool trivial_strides = all_ones(strides()); - const bool trivial_dilations = all_ones(dilations()); - - // TODO(ezhulenev): This might lead to excessive transposes in the final IR, - // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1. - // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all - // users can efficiently use NHWC data format? - if (one_by_one && trivial_strides && trivial_dilations) { - return "NHWC"; - } - - // If filter spatial dimensions are unknown or not 1x1 we prefer NCHW, because - // it's the fastest option on NVIDIA GPUs with cuDNN library support. - return "NCHW"; -} - -//===----------------------------------------------------------------------===// -// Conv2dBackpropFilterOp -//===----------------------------------------------------------------------===// - -LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { - StringRef src_data_format = this->data_format(); - - auto perm = GetDataFormatPermutation(src_data_format, data_format); - if (perm.empty()) return failure(); - - // Update data_format attribute and result types. - if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); - - // Update convolution attributes. - setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - setAttr("strides", ShuffleArrayAttr(strides(), perm)); - setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); - - // Permute filter sizes operand. - OpBuilder builder(getOperation()); - auto filter_sizes_permuted = builder.create( - getLoc(), filter_sizes(), StringAttr::get(src_data_format, getContext()), - StringAttr::get(data_format, getContext())); - setOperand(1, filter_sizes_permuted); - - return success(); -} - -StringRef Conv2DBackpropFilterOp::GetOptimalLayout( - const RuntimeDevices &devices) { - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); - - // Input must be a tensor. - auto input_ty = input().getType().dyn_cast(); - if (!input_ty) return data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. - const bool is_f16 = input_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // Otherwise always use "NCHW". 
- return "NCHW"; -} - -//===----------------------------------------------------------------------===// -// Conv2DBackpropInputOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(Conv2DBackpropInputOp op) { - int num_spatial_dims = 2; - int num_dims = 2 + num_spatial_dims; - - if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) || - !IsOfRankOrUnranked(op.filter(), num_dims)) - return op.emitOpError() - << "requires operands to be " << num_dims << "D tensor"; - - LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims); - if (failed(verify_result)) { - return verify_result; - } - - return success(); -} - -LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { - StringRef src_data_format = this->data_format(); - - auto perm = GetDataFormatPermutation(src_data_format, data_format); - if (perm.empty()) return failure(); - - // Update data_format attribute and result types. - if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); - - // Update convolution attributes. - setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - setAttr("strides", ShuffleArrayAttr(strides(), perm)); - setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); - - // Permute input sizes operand. - OpBuilder builder(getOperation()); - auto input_sizes_permuted = builder.create( - getLoc(), input_sizes(), StringAttr::get(src_data_format, getContext()), - StringAttr::get(data_format, getContext())); - setOperand(0, input_sizes_permuted); - - return success(); -} - -StringRef Conv2DBackpropInputOp::GetOptimalLayout( - const RuntimeDevices &devices) { - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); - - // Filter must be a tensor. - auto filter_ty = filter().getType().dyn_cast(); - if (!filter_ty) return data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. - const bool is_f16 = filter_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // Otherwise always use "NCHW". 
- return "NCHW"; -} - -//===----------------------------------------------------------------------===// -// DataFormatVecPermuteOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(DataFormatVecPermuteOp op) { - auto input_ty = op.x().getType().dyn_cast(); - if (!input_ty) return success(); - - int rank = input_ty.getRank(); - if (rank != 1 && rank != 2) - return op.emitOpError("requires input of rank 1 or 2"); - - if (rank == 1) { - int64_t dim0 = input_ty.getDimSize(0); - if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2) - return op.emitOpError("requires 1D input of size 4 or size 2"); - } - - if (rank == 2) { - int64_t dim0 = input_ty.getDimSize(0); - if (dim0 != ShapedType::kDynamicSize && dim0 != 4) - return op.emitOpError( - "requires first dimensions of 2D input to be of size 4"); - - int64_t dim1 = input_ty.getDimSize(1); - if (dim1 != ShapedType::kDynamicSize && dim1 != 2) - return op.emitOpError( - "requires second dimensions of 2D input to be of size 2"); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// DivOp -//===----------------------------------------------------------------------===// - -void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult DivOp::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// DynamicStitchOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(DynamicStitchOp op) { - if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1"); - - if (RankedTensorType out_ty = op.getType().dyn_cast()) { - if (out_ty.getRank() == 0) { - return op.emitOpError("requires non scalar output"); - } - } - - llvm::SmallDenseSet index_values; - bool all_indices_const = true; - int32_t max_index = -1; - llvm::Optional> inferred_item_shape; - for (auto it : llvm::zip(op.indices(), op.data())) { - Value index = std::get<0>(it); - - DenseIntElementsAttr index_attr; - if (matchPattern(index, m_Constant(&index_attr))) { - for (int32_t index : index_attr.getValues()) { - if (index < 0) - return op.emitOpError() - << "requires non-negative index values; found " << index; - max_index = std::max(index, max_index); - index_values.insert(index); - } - } else { - all_indices_const = false; - } - - Value data = std::get<1>(it); - RankedTensorType index_ty = index.getType().dyn_cast(); - RankedTensorType data_ty = data.getType().dyn_cast(); - if (!index_ty || !data_ty) continue; - - int64_t index_rank = index_ty.getRank(); - ArrayRef data_shape = data_ty.getShape(); - ArrayRef index_shape = index_ty.getShape(); - if (failed(mlir::verifyCompatibleShape(index_shape, - data_shape.take_front(index_rank)))) - return op.emitOpError() << "requires shape of data with type " << data_ty - << " to have prefix matching with shape of the " - "corresponding index type " - << index_ty; - - ArrayRef item_shape = data_shape.drop_front(index_rank); - if (!inferred_item_shape) { - inferred_item_shape = llvm::to_vector<4>(item_shape); - continue; - } - - if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape))) - return op.emitOpError() << "has inconsistent shaped data and index " - "pairs; inferred item shapes [" - << llvm::makeArrayRef(*inferred_item_shape) - << "] 
and [" << item_shape << "] don't match"; - for (int i = 0, e = item_shape.size(); i < e; ++i) { - int64_t &inferred_dim = (*inferred_item_shape)[i]; - int64_t dim = item_shape[i]; - if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim; - } - } - - // If all indices are constants, then verify that they cover all indices in - // the range [0, max_index] and the output type is legal. - if (all_indices_const) { - for (int32_t i = 0; i <= max_index; i++) { - if (!index_values.count(i)) - return op.emitOpError() << "missing index " << i; - } - - if (inferred_item_shape) { - SmallVector expected_shape; - expected_shape.push_back(max_index + 1); - expected_shape.append(inferred_item_shape->begin(), - inferred_item_shape->end()); - - auto out_ty = op.getType().cast(); - auto expected_out_ty = - RankedTensorType::get(expected_shape, out_ty.getElementType()); - - if (!AreCastCompatible({out_ty, expected_out_ty})) { - return op.emitOpError() << "has invalid output type; should be " - "compatible with inferred type " - << expected_out_ty; - } - } - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// EinsumOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// * Arity of the op is at most two. -// -// TODO(hinsu): Verify einsum equation attribute. -static LogicalResult Verify(EinsumOp op) { - if (op.N() > 2) { - return op.emitOpError("supports at most two operands"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// EmptyOp -//===----------------------------------------------------------------------===// - -OpFoldResult EmptyOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "empty op has one operand"); - - Attribute attr = operands.front(); - if (!attr) return {}; - - auto int_attr = attr.cast(); - SmallVector out_shape; - for (const auto val : int_attr.getValues()) { - out_shape.push_back(val); - } - - auto type = getResult().getType().cast(); - auto etype = type.getElementType(); - - // We can not fold if the result is not static. - if (!type.hasStaticShape()) return {}; - - if (auto float_type = etype.dyn_cast()) { - auto out_type = RankedTensorType::get(out_shape, float_type); - return DenseElementsAttr::get(out_type, - {APFloat(float_type.getFloatSemantics())}); - } - - if (auto int_type = etype.dyn_cast()) { - auto out_type = RankedTensorType::get(out_shape, etype); - APInt val(int_type.getWidth(), 0, int_type.getSignedness()); - return DenseElementsAttr::get(out_type, val); - } - - return {}; -} - -//===----------------------------------------------------------------------===// -// EmptyTensorListOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(EmptyTensorListOp op) { - if (!IsOfRankOrUnranked(op.element_shape(), 0) && - !IsOfRankOrUnranked(op.element_shape(), 1)) { - return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); - } - - if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) { - return op.emitOpError("requires max_num_elements operand to be 0D tensor"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// EqualOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(EqualOp op) { - // If we allow inputs to have incompatible type, then nothing to do. 
- if (!op.incompatible_shape_error()) return success(); - - // Otherwise, check inputs are broadcastable. - return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( - op.getOperation()); -} - -void EqualOp::build(OpBuilder &builder, OperationState &result, Value x, - Value y, BoolAttr incompatible_shape_error) { - auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, - incompatible_shape_error); - return build(builder, result, result_type, x, y, incompatible_shape_error); -} - -//===----------------------------------------------------------------------===// -// ExpandDimsOp -//===----------------------------------------------------------------------===// - -Type InferExpandDimsOpType(Value input, Value dim) { - Type element_ty = input.getType().cast().getElementType(); - auto unranked_ty = UnrankedTensorType::get(element_ty); - - auto input_ty = input.getType().dyn_cast(); - if (!input_ty) return unranked_ty; - - DenseIntElementsAttr dim_attr; - if (!matchPattern(dim, m_Constant(&dim_attr)) || - dim_attr.getNumElements() != 1) - return unranked_ty; - int64_t dim_val = (*dim_attr.begin()).getSExtValue(); - int64_t input_rank = input_ty.getRank(); - - if (dim_val < -input_rank - 1 || dim_val > input_rank + 1) return unranked_ty; - if (dim_val < 0) dim_val += input_rank + 1; - - SmallVector shape = llvm::to_vector<4>(input_ty.getShape()); - shape.insert(shape.begin() + dim_val, 1); - return RankedTensorType::get(shape, element_ty); -} - -void ExpandDimsOp::build(OpBuilder &builder, OperationState &result, - Value input, Value dim) { - return build(builder, result, InferExpandDimsOpType(input, dim), input, dim); -} - -//===----------------------------------------------------------------------===// -// FakeQuantWithMinMaxArgsOp -//===----------------------------------------------------------------------===// -static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) { - // TODO(fengliuai): moving the following to an utility method. - const llvm::fltSemantics &semantics = op.min().getSemantics(); - float rmin, rmax; - if (&semantics == &APFloat::IEEEsingle()) { - rmin = op.min().convertToFloat(); - rmax = op.max().convertToFloat(); - } else { - rmin = op.min().convertToDouble(); - rmax = op.max().convertToDouble(); - } - // Range boundaries must be valid. 
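  // (Illustrative: min = -6.0, max = 6.0 is a valid range; min == max, i.e. an
  // empty range, is rejected by the check below.)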
- if (rmin >= rmax) { - return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) + - "," + Twine(std::to_string(rmax)) + "]"); - } - int64_t num_bits = op.num_bits().getSExtValue(); - if (num_bits < 2 || num_bits > 16) { - return op.emitOpError( - "requires num_bits to be between 2 and 16, inclusive"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// FakeQuantWithMinMaxVarsOp -//===----------------------------------------------------------------------===// -static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) { - auto min = GetRankedTensorTypeForOperand(op.min()); - if (min && !IsOfRankedFloatTensorType(min, 0)) - return op.emitOpError("requires min to be a 0d float tensor"); - - auto max = GetRankedTensorTypeForOperand(op.max()); - if (max && !IsOfRankedFloatTensorType(max, 0)) - return op.emitOpError("requires max to be a 0d float tensor"); - - int64_t num_bits = op.num_bits().getSExtValue(); - if (num_bits < 2 || num_bits > 16) { - return op.emitOpError( - "requires num_bits to be between 2 and 16, inclusive"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// FakeQuantWithMinMaxVarsPerChannelOp -//===----------------------------------------------------------------------===// -static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) { - auto min = GetRankedTensorTypeForOperand(op.min()); - if (min && !IsOfRankedFloatTensorType(min, 1)) - return op.emitOpError("requires min to be a 1d float tensor"); - - auto max = GetRankedTensorTypeForOperand(op.max()); - if (max && !IsOfRankedFloatTensorType(max, 1)) - return op.emitOpError("requires max to be a 1d float tensor"); - - Value inputs = op.inputs(); - if (!HasRankAtLeast(inputs, 1)) - return op.emitError("requires inputs to be at least 1d float tensor"); - - int64_t num_bits = op.num_bits().getSExtValue(); - if (num_bits < 2 || num_bits > 16) { - return op.emitOpError( - "requires num_bits to be between 2 and 16, inclusive"); - } - - auto inputs_type = inputs.getType().dyn_cast(); - if (!inputs_type) return success(); - int depth = inputs_type.getDimSize(inputs_type.getRank() - 1); - if ((min && min.getDimSize(0) != depth) || - (max && max.getDimSize(0) != depth)) { - return op.emitOpError( - "requires min and max to have same size as last dimension of inputs"); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// FillOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(FillOp op) { - if (!IsOfRankOrUnranked(op.dims(), 1)) - return op.emitOpError() << "requires dims to be a 1D tensor"; - if (!IsOfRankOrUnranked(op.value(), 0)) - return op.emitOpError() << "requires value to be a scalar"; - - return success(); -} - -static ShapedType InferFillOpType(Value dims, Value value) { - Type etype = value.getType().cast().getElementType(); - - DenseIntElementsAttr dims_attr; - if (!matchPattern(dims, m_Constant(&dims_attr))) { - return UnrankedTensorType::get(etype); - } - - llvm::SmallVector shape; - shape.reserve(dims_attr.getNumElements()); - for (const APInt dim : dims_attr.getValues()) { - shape.push_back(dim.getSExtValue()); - } - return RankedTensorType::get(shape, etype); -} - -void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, - Value value) { - FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); -} - 
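// Illustrative sketch, not part of the original file: how the inference above
// behaves for a constant `dims` operand. The helper name is hypothetical and
// assumes only the MLIR builder APIs already used in this file. With
// dims = dense<[2, 3]> : tensor<2xi32> and a scalar f32 value, InferFillOpType
// returns tensor<2x3xf32>; a non-constant dims operand falls back to an
// unranked tensor of the value's element type.
static ShapedType ExampleInferredFillType(OpBuilder &builder) {
  return RankedTensorType::get({2, 3}, builder.getF32Type());
}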
-OpFoldResult FillOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "fill op has two operand"); - - auto type = getType().cast(); - // DenseElementsAttr that is used in this folder only supports int and float - // types. - // TODO(hinsu): Handle complex types once there is a attribute kind for - // complex. - if (!type.getElementType().isIntOrFloat()) return {}; - - auto value = operands[1].dyn_cast_or_null(); - if (!value) return {}; - - if (type.hasStaticShape()) - return DenseElementsAttr::get(type, value.getValue({})); - - auto dims = operands[0].dyn_cast_or_null(); - if (!dims) return {}; - - llvm::SmallVector shape; - shape.reserve(dims.getNumElements()); - for (const APInt dim : dims.getValues()) { - shape.push_back(dim.getSExtValue()); - } - type = RankedTensorType::get(shape, type.getElementType()); - - return DenseElementsAttr::get(type, value.getValue({})); -} - -//===----------------------------------------------------------------------===// -// FusedBatchNormGradOp -//===----------------------------------------------------------------------===// - -// TODO(b/150954845): Add benchmarks to verify that layout preference didn't -// change in the latest GPU generations. - -LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) { - return ::mlir::TF::UpdateDataFormat(data_format, this); -} - -StringRef FusedBatchNormGradV3Op::GetOptimalLayout( - const RuntimeDevices &devices) { - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. - auto x_ty = x().getType().cast(); - const bool is_f16 = x_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // For all other data types prefer NCHW. 
- return "NCHW"; -} - -//===----------------------------------------------------------------------===// -// FusedBatchNormOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(FusedBatchNormOp op) { - auto x = GetRankedTensorTypeForOperand(op.x()); - if (x && !IsOfRankedFloatTensorType(x, 4)) - return op.emitOpError("requires x to be a 4D float tensor"); - - auto scale = GetRankedTensorTypeForOperand(op.scale()); - if (scale && !IsOfRankedFloatTensorType(scale, 1)) - return op.emitOpError("requires scale to be a 1D float tensor"); - - auto offset = GetRankedTensorTypeForOperand(op.offset()); - if (offset && !IsOfRankedFloatTensorType(offset, 1)) - return op.emitOpError("requires offset to be a 1D float tensor"); - - auto mean = GetRankedTensorTypeForOperand(op.mean()); - if (mean && !IsOfRankedFloatTensorType(mean, 1)) - return op.emitOpError("requires mean to be a 1D float tensor"); - - auto variance = GetRankedTensorTypeForOperand(op.variance()); - if (variance && !IsOfRankedFloatTensorType(variance, 1)) - return op.emitOpError("requires variance to be a 1D float tensor"); - - // TODO(antiagainst): check attributes - - return success(); -} - -//===----------------------------------------------------------------------===// -// FusedBatchNormV2Op / FusedBatchNormV3Op -//===----------------------------------------------------------------------===// - -template -static LogicalResult InferenceFoldOperandsPermutation( - ArrayRef permutation, Op *op) { - // FusedBatchNorm in training mode is a layout sentitive operation, and should - // have already assigned an optimal data format. - if (op->is_training()) return failure(); - return ::mlir::TF::FoldOperandsPermutation(permutation, op); -} - -template -static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) { - // In inference mode FusedBatchNorm is not sensitive to data layout. - if (!op->is_training()) return op->data_format(); - - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation())) - return op->data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. - auto x_ty = op->x().getType().template cast(); - const bool is_f16 = x_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // For all other data types prefer NCHW. 
- return "NCHW"; -} - -LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation( - ArrayRef permutation) { - return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this); -} - -LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) { - return ::mlir::TF::UpdateDataFormat(data_format, this); -} - -StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) { - return ::mlir::TF::GetOptimalLayout(devices, this); -} - -LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation( - ArrayRef permutation) { - return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this); -} - -LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) { - return ::mlir::TF::UpdateDataFormat(data_format, this); -} - -StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { - return ::mlir::TF::GetOptimalLayout(devices, this); -} - -//===----------------------------------------------------------------------===// -// GatherV2Op -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(GatherV2Op op) { - int64_t batch_dims = op.batch_dims().getSExtValue(); - if (auto ty = op.indices().getType().dyn_cast()) { - int64_t rank = ty.getRank(); - if (batch_dims > rank || batch_dims < -rank) - return op.emitOpError() - << "batch_dims (" << batch_dims << ") must be in range [" << -rank - << ", " << rank + 1 << ")"; - if (batch_dims < 0) batch_dims += rank; - } - - if (!HasRankAtMost(op.axis(), 1)) - return op.emitOpError("requires axis to have rank at most 1"); - - DenseIntElementsAttr axis_attr; - if (matchPattern(op.axis(), m_Constant(&axis_attr))) { - int64_t axis = (*axis_attr.begin()).getSExtValue(); - if (auto ty = op.params().getType().dyn_cast()) { - int64_t rank = ty.getRank(); - if (axis >= rank || axis < -rank) - return op.emitOpError() << "axis (" << axis << ") must be in range [" - << -rank << ", " << rank << ")"; - if (axis < 0) axis += rank; - } - - if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) { - return op.emitOpError() << "requires axis (" << axis - << ") to be greater than or equal to batch_dims (" - << batch_dims << ")"; - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// IfOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(IfOp op) { - auto module = op.getParentOfType(); - auto then_fn = module.lookupSymbol(op.then_branch()); - if (!then_fn) - return op.emitOpError("then_branch refers to an undefined function : ") - << op.then_branch(); - auto else_fn = module.lookupSymbol(op.else_branch()); - if (!else_fn) - return op.emitOpError("else_branch refers to an undefined function : ") - << op.else_branch(); - auto then_fn_type = then_fn.getType(); - auto else_fn_type = else_fn.getType(); - - // Non-conditional operands starting with the second operand are passed to - // branches and should be pair-wise compatible with branches' inputs. 
- unsigned expected_num_inputs = op.getNumOperands() - 1; - if (then_fn_type.getNumInputs() != expected_num_inputs || - else_fn_type.getNumInputs() != expected_num_inputs) - return op.emitError("branches should have " + Twine(expected_num_inputs) + - " inputs"); - - for (unsigned i = 0; i < expected_num_inputs; ++i) { - auto operand_type = op.getOperand(i + 1).getType().cast(); - auto then_input_type = then_fn_type.getInput(i).cast(); - if (!AreCastCompatible({operand_type, then_input_type})) - return op.emitError( - llvm::formatv("then branch input type {0} is incompatible with " - "operand type {1} at index {2}", - then_input_type, operand_type, i)); - - auto else_input_type = else_fn_type.getInput(i).cast(); - if (!AreCastCompatible({operand_type, else_input_type})) - return op.emitError( - llvm::formatv("else branch input type {0} is incompatible with " - "operand type {1} at index {2}", - else_input_type, operand_type, i)); - - // If branches have incompatible input types that means that no tensor can - // serve as input to both the functions. Hence, the op is invalid. - if (!AreCastCompatible({then_input_type, else_input_type})) - return op.emitError(llvm::formatv( - "branches inputs have incompatible types {0} and {1} at index {2}", - then_input_type, else_input_type, i)); - } - - // Branches' results should be pair-wise compatible with the op results. - unsigned expected_num_results = op.getNumResults(); - if (then_fn_type.getNumResults() != expected_num_results || - else_fn_type.getNumResults() != expected_num_results) - return op.emitError("branches should have " + Twine(expected_num_results) + - " results"); - - for (unsigned i = 0; i < expected_num_results; ++i) { - auto result_type = op.getResult(i).getType().cast(); - auto then_result_type = then_fn_type.getResult(i).cast(); - if (!AreCastCompatible({then_result_type, result_type})) - return op.emitError( - llvm::formatv("then branch result type {0} is incompatible with op " - "result type {1} at index {2}", - then_result_type, result_type, i)); - - auto else_result_type = else_fn_type.getResult(i).cast(); - if (!AreCastCompatible({else_result_type, result_type})) - return op.emitError( - llvm::formatv("else branch result type {0} is incompatible with op " - "result type {1} at index {2}", - else_result_type, result_type, i)); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// IfRegionOp -//===----------------------------------------------------------------------===// - -LogicalResult VerifyRegionResults(Operation *op, Region ®ion, - StringRef region_name) { - auto op_name = op->getName().getStringRef(); - // verify that op outputs match yield inputs - YieldOp yield = cast(region.front().getTerminator()); - unsigned expected_num_results = op->getNumResults(); - if (yield.getNumOperands() != expected_num_results) - return op->emitOpError() - << region_name + " should have same number (" << expected_num_results - << ") of results as " << op_name << " but has " - << yield.getNumOperands() << " results"; - - for (int idx : llvm::seq(0, expected_num_results)) { - auto op_result_type = op->getResult(idx).getType().cast(); - auto region_result_type = - yield.getOperand(idx).getType().cast(); - if (!AreCastCompatible({region_result_type, op_result_type})) - return op->emitError(llvm::formatv( - "{0} result type {1} is incompatible with {2} " - "result type {3} at index {4}", - region_name, region_result_type, op_name, op_result_type, idx)); - } - return success(); -} - 
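// Illustrative sketch, not part of the original file: the cast-compatibility
// rule used by the IfOp verifier above (and the IfRegionOp verifier below),
// expressed with this file's AreCastCompatible helper. Dynamic dimensions are
// compatible with static ones of the same element type.
static bool ExampleBranchTypesCompatible(MLIRContext *context) {
  auto f32 = FloatType::getF32(context);
  Type operand_ty = RankedTensorType::get({2}, f32);  // tensor<2xf32>
  Type branch_arg_ty =
      RankedTensorType::get({ShapedType::kDynamicSize}, f32);  // tensor<?xf32>
  return AreCastCompatible({operand_ty, branch_arg_ty});  // true
}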
-static LogicalResult Verify(IfRegionOp op) { - if (failed(VerifyRegionResults(op, op.then_branch(), "then"))) - return failure(); - if (failed(VerifyRegionResults(op, op.else_branch(), "else"))) - return failure(); - return success(); -} - -//===----------------------------------------------------------------------===// -// InvertOp -//===----------------------------------------------------------------------===// - -void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// InvertPermutationOp -//===----------------------------------------------------------------------===// - -// Verifies that the input is 1D. -static LogicalResult Verify(InvertPermutationOp op) { - auto x_type = op.x().getType().cast(); - if (!x_type.hasRank()) return success(); - if (x_type.getShape().size() != 1) - return op.emitOpError() << "requires input x to be 1-dimensional"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// LeakyReluOp -//===----------------------------------------------------------------------===// - -OpFoldResult LeakyReluOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "leaky relu has one operand"); - - // leaky_relu(x, alpha: 1) -> x - if (alpha().convertToFloat() == 1.0f) return getOperand(); - - auto calculate = [&](FloatAttr arg) { - APFloat val = arg.getValue(); - if (val.isNegative()) val = alpha() * val; - return FloatAttr::get(arg.getType(), val); - }; - - if (auto arg = operands[0].dyn_cast_or_null()) { - return calculate(arg); - } else if (auto arg = operands[0].dyn_cast_or_null()) { - if (auto elementAttr = arg.getSplatValue().dyn_cast()) - return DenseElementsAttr::get(arg.getType(), calculate(elementAttr)); - } - return {}; -} - -//===----------------------------------------------------------------------===// -// LogOp -//===----------------------------------------------------------------------===// - -void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// ReadVariableOp -//===----------------------------------------------------------------------===// - -void ReadVariableOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// VarIsInitializedOp -//===----------------------------------------------------------------------===// - -namespace { - -/// Erase VarIsInitializedOp operations with no uses. This op has side effect on -/// resources (read-only), but can still be deleted if it has zero uses. -struct EraseDeadVarIsInitializedOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(VarIsInitializedOp op, - PatternRewriter &rewriter) const override { - if (!op.use_empty()) return failure(); - rewriter.eraseOp(op); - return success(); - } -}; -} // end anonymous namespace. 
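// Usage note (illustrative, not part of the original file): rewrite patterns
// such as EraseDeadVarIsInitializedOp are exposed through the op's
// getCanonicalizationPatterns hook below and are typically applied by MLIR's
// -canonicalize pass, roughly:
//
//   OwningRewritePatternList patterns;
//   VarIsInitializedOp::getCanonicalizationPatterns(patterns, context);
//   // hand `patterns` to the greedy rewrite driver
//   // (applyPatternsAndFoldGreedily) or just run -canonicalize.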
- -void VarIsInitializedOp::getCanonicalizationPatterns( - OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert(context); -} - -//===----------------------------------------------------------------------===// -// LogicalNotOp -//===----------------------------------------------------------------------===// - -void LogicalNotOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// MatrixBandPartOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(MatrixBandPartOp op) { - if (!HasRankAtLeast(op.input(), 2)) { - return op.emitOpError() - << "requires `input` to have rank of at least 2, but found " - << op.input().getType(); - } - if (!IsOfRankOrUnranked(op.num_lower(), 0)) { - return op.emitOpError() - << "requires `num_lower` to have 0 dimensions, but found " - << op.num_lower().getType(); - } - if (!IsOfRankOrUnranked(op.num_upper(), 0)) { - return op.emitOpError() - << "requires `num_upper` to have 0 dimensions, but found " - << op.num_upper().getType(); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// MaxOp -//===----------------------------------------------------------------------===// - -void MaxOp::build(OpBuilder &builder, OperationState &result, Value input, - Value reduction_indices, BoolAttr keep_dims) { - Type out_ty = - InferReductionOpType(input, reduction_indices, keep_dims, &builder); - build(builder, result, out_ty, input, reduction_indices, keep_dims); -} - -//===----------------------------------------------------------------------===// -// MaxPoolOp -//===----------------------------------------------------------------------===// - -LogicalResult MaxPoolOp::FoldOperandsPermutation( - ArrayRef permutation) { - return ::mlir::TF::FoldOperandsPermutation( - permutation, this, {{"strides", strides()}, {"ksize", ksize()}}); -} - -//===----------------------------------------------------------------------===// -// MaxPoolGradOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(MaxPoolGradOp op) { - if (!IsOfRankOrUnranked(op.orig_input(), 4)) { - return op.emitOpError() << "requires orig_input to be rank 4"; - } - if (!IsOfRankOrUnranked(op.orig_output(), 4)) { - return op.emitOpError() << "requires orig_output to be rank 4"; - } - if (!IsOfRankOrUnranked(op.grad(), 4)) { - return op.emitOpError() << "requires grad to be rank 4"; - } - return success(); -} - -//===----------------------------------------------------------------------===// -// MeanOp -//===----------------------------------------------------------------------===// - -LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { - // Reduction indices must be defined by a constant operation. - auto reduction_op = - dyn_cast_or_null(reduction_indices().getDefiningOp()); - if (!reduction_op) return failure(); - - auto reductions_value = reduction_op.value().dyn_cast(); - if (!reductions_value) return failure(); - - // Prepare new reduction indices according to operand permutation. - SmallVector shuffled_reduction; - llvm::transform(reductions_value.getIntValues(), - std::back_inserter(shuffled_reduction), - [&](APInt idx) { return permutation[idx.getSExtValue()]; }); - - // Add constant operation with a new reduction indices. 
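  // (Illustrative: with permutation = {0, 2, 3, 1}, reduction indices {1, 2}
  // from the original constant become {2, 3} in the constant created below.)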
- OpBuilder builder(getOperation()); - auto type = mlir::RankedTensorType::get(shuffled_reduction.size(), - builder.getIntegerType(32)); - auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction); - auto shuffled_reduction_op = builder.create(getLoc(), values); - - // Use new reduction indices. - setOperand(1, shuffled_reduction_op); - - return success(); -} - -//===----------------------------------------------------------------------===// -// MulOp -//===----------------------------------------------------------------------===// - -OpFoldResult MulOp::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// NegOp -//===----------------------------------------------------------------------===// - -void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// NotEqualOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(NotEqualOp op) { - // If we allow inputs to have incompatible type, then nothing to do. - if (!op.incompatible_shape_error()) return success(); - - // Otherwise, check inputs are broadcastable. - return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( - op.getOperation()); -} - -void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x, - Value y, BoolAttr incompatible_shape_error) { - auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, - incompatible_shape_error); - return build(builder, result, result_type, x, y, incompatible_shape_error); -} - -//===----------------------------------------------------------------------===// -// OneHotOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(OneHotOp op) { - int64_t axis = op.axis().getSExtValue(); - - auto indices_ty = op.indices().getType().dyn_cast(); - if (indices_ty && - !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) { - return op.emitOpError() - << "expected axis (" << axis << ") to be -1 or between [0, " - << indices_ty.getShape().size() << "]"; - } - - if (axis < -1) { - return op.emitOpError() << "expected axis (" << axis - << ") to be -1 or between [0, rank(indices()))"; - } - - if (!IsOfRankOrUnranked(op.depth(), 0)) { - return op.emitOpError() << "requires depth to be a scalar"; - } - if (!IsOfRankOrUnranked(op.on_value(), 0)) { - return op.emitOpError() << "requires on_value to be a scalar"; - } - if (!IsOfRankOrUnranked(op.off_value(), 0)) { - return op.emitOpError() << "requires off_value to be a scalar"; - } - - DenseIntElementsAttr depth_attr; - if (matchPattern(op.depth(), m_Constant(&depth_attr))) { - if (depth_attr.getType().getRank() != 0) - return op.emitOpError() << "requires depth to be a scalar"; - int64_t depth = depth_attr.getValue({}).getSExtValue(); - if (depth < 0) { - return op.emitOpError() << "depth must be non-negative, got: " << depth; - } - } - - return success(); -} - -static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, - Value off_value, IntegerAttr axis) { - int64_t axis_val = axis.getInt(); - Type element_ty = on_value.getType().cast().getElementType(); - auto unranked_ty = UnrankedTensorType::get(element_ty); - if (axis_val < -1) return unranked_ty; - - auto indices_ty = 
indices.getType().dyn_cast(); - if (!indices_ty) return unranked_ty; - - auto shape = llvm::to_vector<2>(indices_ty.getShape()); - if (axis_val == -1) axis_val = shape.size(); - - int64_t depth_val = ShapedType::kDynamicSize; - DenseIntElementsAttr depth_attr; - if (matchPattern(depth, m_Constant(&depth_attr)) && - depth_attr.getNumElements() == 1) - depth_val = (*depth_attr.begin()).getSExtValue(); - shape.insert(shape.begin() + axis_val, depth_val); - return RankedTensorType::get(shape, element_ty); -} - -void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices, - Value depth, Value on_value, Value off_value, - IntegerAttr axis) { - build(builder, result, - InferOneHotOpType(indices, depth, on_value, off_value, axis), indices, - depth, on_value, off_value, axis); -} - -//===----------------------------------------------------------------------===// -// PackOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(PackOp op) { - // TODO(hinsu): Convert variadic length attributes to derived attributes. - Operation::operand_range values = op.values(); - - if (failed(VerifyTypesCompatibility(values, - /*mask_one_dim=*/false, - op.getOperation()))) { - return failure(); - } - - int64_t inputs_rank = -1; - for (Value value : values) { - if (auto ty = value.getType().dyn_cast()) { - // Exit early as input types are verified to be compatible so all ranked - // tensors have the same rank. - inputs_rank = ty.getRank(); - break; - } - } - if (inputs_rank == -1) return success(); - - // The values can be packed along any of the dimensions between 0 and - // inputs rank, inclusive. Also, as the negative axis values wrap around so - // the axis value range is [-(R+1), R+1). - int64_t range_begin = -inputs_rank - 1; // Inclusive - int64_t range_end = inputs_rank + 1; // Exclusive - int64_t axis = op.axis().getSExtValue(); - if (axis < range_begin || axis >= range_end) { - return op.emitError() << "attribute 'axis' should be within range [" - << range_begin << ", " << range_end - << "); actual value: " << axis; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// PadOp -//===----------------------------------------------------------------------===// - -LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { - // Paddings must be defined by a constant operation. - auto paddings_op = dyn_cast_or_null(paddings().getDefiningOp()); - if (!paddings_op) return failure(); - - auto paddings_value = paddings_op.value().dyn_cast(); - if (!paddings_value || - paddings_value.getNumElements() != permutation.size() * 2) - return failure(); - - SmallVector shuffled_paddings(paddings_value.getNumElements()); - for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) { - size_t outer_idx = index_pair.index() / 2; - size_t inner_idx = index_pair.index() % 2; - - shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] = - index_pair.value().getSExtValue(); - } - - // Add constant operation with a new paddings. - OpBuilder builder(getOperation()); - auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(), - builder.getIntegerType(32)); - auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings); - auto shuffled_paddings_op = builder.create(getLoc(), values); - - // Use new paddings. - setOperand(1, shuffled_paddings_op); - - // Change the result type. 
- getResult().setType(ShuffleRankedTensorType(getResult().getType(), - ReversePermutation(permutation))); - - return success(); -} - -//===----------------------------------------------------------------------===// -// ParseExampleV2Op -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(ParseExampleV2Op op) { - // NOTE(mrry): This validates properties of an op that would previously be - // validated by the TensorFlow OpDef type checker. In addition to these - // checks, the shape inference function for ParseExampleV2 validates the - // consistency of the argument and result types. - - // Validate dense variadic input and output lengths. - // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we - // do not need to validate dense_defaults. - auto dense_types_count = - std::distance(op.Tdense().begin(), op.Tdense().end()); - auto dense_values_count = - std::distance(op.dense_values().begin(), op.dense_values().end()); - if (dense_values_count != dense_types_count) { - return op.emitError() << "output 'dense_values' should have same length " - << "as attribute 'Tdense'"; - } - - // Validate sparse variadic output lengths. - // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we - // do not need to validate sparse_values. - auto sparse_types_count = - std::distance(op.sparse_types().begin(), op.sparse_types().end()); - if (op.num_sparse() != sparse_types_count) { - return op.emitError() << "attribute 'num_sparse' should be the same as " - << "the length of attribute 'sparse_types'"; - } - if (op.sparse_indices().size() != sparse_types_count) { - return op.emitError() << "output 'sparse_indices' should have same length " - << "as attribute 'sparse_types'"; - } - if (op.sparse_shapes().size() != sparse_types_count) { - return op.emitError() << "output 'sparse_shapes' should have same length " - << "as attribute 'sparse_types'"; - } - - // Validate ragged variadic output lengths. 
- auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(), - op.ragged_value_types().end()); - auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(), - op.ragged_split_types().end()); - if (ragged_value_types_count != ragged_split_types_count) { - return op.emitError() << "attribute 'ragged_value_types' should have same " - << "length as attribute 'ragged_split_types'"; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// PartitionedCallOp -//===----------------------------------------------------------------------===// - -template -static LogicalResult VerifyPartitionedCall(OpClass op) { - auto module = op.template getParentOfType(); - SymbolRefAttr func = op.getAttr("f").template cast(); - - auto function = - dyn_cast_or_null(SymbolTable::lookupSymbolIn(module, func)); - - if (!function) { - return op.emitError("'f' attribute refers to an undefined function: ") - << func; - } - - FunctionType function_ty = function.getType(); - int func_arg_count = function_ty.getNumInputs(); - int arg_count = op.args().size(); - - if (arg_count != func_arg_count) { - return op.emitError() << "argument count mismatch: 'args' has " << arg_count - << " arguments, but '" << func << "' expects " - << func_arg_count; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// PowOp -//===----------------------------------------------------------------------===// - -OpFoldResult PowOp::fold(ArrayRef operands) { - auto constant_y = operands[1].dyn_cast_or_null(); - if (constant_y && constant_y.isSplat()) { - APFloat y_value = constant_y.getSplatValue(); - auto output_type = getType().cast(); - if (y_value.isZero() && output_type.hasStaticShape()) { - return DenseElementsAttr::get( - output_type, - FloatAttr::get(output_type.getElementType(), /*value=*/1.0)); - } - if (y_value.isExactlyValue(1.0)) { - return x(); - } - } - return {}; -} - -//===----------------------------------------------------------------------===// -// QrOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// -// * Input type, if ranked, must have at least 2 dimensions and at most -// INT32_MAX dimensions. 
-// -static LogicalResult Verify(QrOp op) { - auto ttype = op.input().getType().cast(); - if (!ttype.hasRank()) return success(); - if (!HasRankAtLeast(op.input(), 2)) - return op.emitOpError( - "requires ranked input tensor to be of rank 2 or more"); - if (!HasRankAtMost(op.input(), std::numeric_limits::max())) - return op.emitOpError( - "requires ranked input tensor to be of rank INT32_MAX or less"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// ReciprocalOp -//===----------------------------------------------------------------------===// - -void ReciprocalOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// RandomUniformOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(RandomUniformOp op) { - if (!IsOfRankOrUnranked(op.shape(), 1)) - return op.emitOpError("shape must be 1D tensor"); - return success(); -} - -//===----------------------------------------------------------------------===// -// RangeOp -//===----------------------------------------------------------------------===// - -void RangeOp::build(OpBuilder &builder, OperationState &result, Value start, - Value limit, Value delta) { - assert(start.getType() == limit.getType()); - assert(start.getType() == delta.getType()); - DenseIntElementsAttr start_val; - DenseIntElementsAttr limit_val; - DenseIntElementsAttr delta_val; - if (matchPattern(start, m_Constant(&start_val)) && - matchPattern(limit, m_Constant(&limit_val)) && - matchPattern(delta, m_Constant(&delta_val))) { - auto size = llvm::APIntOps::RoundingSDiv( - *limit_val.begin() - *start_val.begin(), *delta_val.begin(), - llvm::APInt::Rounding::DOWN); - return RangeOp::build( - builder, result, - RankedTensorType::get( - size.getSExtValue(), - start.getType().cast().getElementType()), - start, limit, delta); - } - return RangeOp::build( - builder, result, - RankedTensorType::get( - {-1}, start.getType().cast().getElementType()), - start, limit, delta); -} -//===----------------------------------------------------------------------===// -// RankOp -//===----------------------------------------------------------------------===// - -void RankOp::build(OpBuilder &builder, OperationState &result, Value input) { - return RankOp::build(builder, result, - RankedTensorType::get({}, builder.getIntegerType(32)), - input); -} - -// This will create a constant value for RankOp of a ranked tensor. 
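// (Illustrative: folding tf.Rank of an operand of type tensor<2x3x4xf32>
// yields the constant dense<3> : tensor<i32>.)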
-OpFoldResult RankOp::fold(ArrayRef operands) { - auto type = input().getType(); - auto ranked_type = type.dyn_cast(); - if (!ranked_type) return {}; - - auto output_type = getType().cast(); - int32_t rank = ranked_type.getRank(); - return DenseIntElementsAttr::get(output_type, rank); -} - -//===----------------------------------------------------------------------===// -// RealDivOp -//===----------------------------------------------------------------------===// - -void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult RealDivOp::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// ReshapeOp -//===----------------------------------------------------------------------===// - -// TODO(b/128020684): Verify the output type. -static LogicalResult Verify(ReshapeOp op) { - auto shape_type = op.shape().getType().cast(); - if (!shape_type.hasRank()) return success(); - if (shape_type.getRank() != 1) - return op.emitOpError("shape must be 1D tensor"); - auto rank_by_shape = shape_type.getShape()[0]; - auto type_of_tensor = op.tensor().getType().cast(); - // No compile time verification for unknown sized shape. - if (rank_by_shape == -1 || !type_of_tensor.hasStaticShape()) return success(); - int64_t num_by_tensor = type_of_tensor.getNumElements(); - - auto out_ty = op.getType().dyn_cast(); - if (out_ty && out_ty.hasStaticShape()) { - int64_t num_output_elements = out_ty.getNumElements(); - if (num_by_tensor != num_output_elements) - return op.emitOpError() - << "number of output elements (" << num_output_elements - << ") does not match expected number of elements (" - << num_by_tensor << ")"; - } - - // Check values if constant shape. No compiling time verification for - // non-constant shape. - auto *shape_op = op.shape().getDefiningOp(); - if (!shape_op) return success(); - Attribute shape_cst; - if (!matchPattern(shape_op, m_Constant(&shape_cst))) return success(); - auto shape_cst_attr = shape_cst.dyn_cast(); - if (!shape_cst_attr) return op.emitOpError("shape must be a valid tensor"); - - if (auto opaque_attr = shape_cst_attr.dyn_cast()) { - opaque_attr.decode(shape_cst_attr); - } - - // We know the shape is a 1-D Tensor, then let us get the number of - // elements it implies. - unsigned num_by_shape = 1; - unsigned unknown_dim_count = 0; - for (int i = 0, e = rank_by_shape; i != e; ++i) { - auto num = shape_cst_attr.getValue(i).getInt(); - // The dimension size value can be -1, and that the real size needs to - // be computed so that the total size remains constant. At most one - // component of shape can be -1. - if (num == -1) { - if (++unknown_dim_count > 1) { - return op.emitOpError("more than one component of shape are -1"); - } - } else { - num_by_shape *= num; - } - } - // If there is one component of shape is -1, the dimension should be - // computed so that the total size remains constant. - if (unknown_dim_count == 1) { - if (num_by_tensor % num_by_shape != 0) - return op.emitOpError( - "one component of shape is -1 but couldn't infer the dimension"); - return success(); - } - // If the elements by the tensor and implies by the shape don't match, - // fail this static check. 
- if (num_by_tensor != num_by_shape) { - return op.emitOpError( - "mismatch in tensor elements and shape implied elements"); - } - return success(); -} - -void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, - Value shape) { - auto ttype = tensor.getType().cast(); - auto etype = ttype.getElementType(); - - auto unranked = [&builder, etype, &result, shape, tensor]() { - return ReshapeOp::build(builder, result, UnrankedTensorType::get(etype), - tensor, shape); - }; - - // If tensor is unranked then we have no info about output of shape. - if (!ttype.hasRank()) return unranked(); - - DenseIntElementsAttr attr_shape; - if (matchPattern(shape, m_Constant(&attr_shape))) { - llvm::SmallVector const_shape; - const_shape.reserve(attr_shape.getNumElements()); - - // Detect if reshape output shape is folded. - bool flatten = false; - int unknown_index = -1; - // The product of constant shape argument excluding unknown dimension. - int64_t product_cshape = 1; - for (auto e : llvm::enumerate(attr_shape)) { - int64_t val = e.value().getSExtValue(); - if (IsUnknownDimOrRank(val)) { - if (flatten) { - mlir::emitError(result.location) - << "only one unknown dimension allowed"; - return; - } - flatten = true; - unknown_index = e.index(); - } else { - product_cshape *= val; - } - const_shape.push_back(val); - } - - // Compute the value of the unknown dimension. - if (flatten) { - // Compute number of elements in tensor shape. - auto tshape = ttype.getShape(); - int64_t product_tshape = std::accumulate(tshape.begin(), tshape.end(), 1, - std::multiplies()); - // Set the unknown dimension such that total number of elements remain - // constant. - // Note: The case where the ratio is not integral, and so the total size - // of reshape not constant, is checked in verify function. - const_shape[unknown_index] = product_tshape / product_cshape; - } - return ReshapeOp::build(builder, result, - RankedTensorType::get(const_shape, etype), tensor, - shape); - } - return unranked(); -} - -void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult ReshapeOp::fold(ArrayRef operands) { - Value tensor = this->tensor(); - Value shape = this->shape(); - - // Fold reshape if operand and result types are the same and all dimensions - // are statically known (no-op reshape). - // TODO(ezhulenev): Add the same folding for BroadcastToOp. - auto result_ty = getType().dyn_cast(); - if (result_ty && result_ty.hasStaticShape() && - result_ty == tensor.getType()) { - return tensor; - } - - // Fold reshape if the shape is computed from the input tensor: - // - // %shape = tf.Shape(%arg) // [? x ...] - // %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim value - // %new_shape = tf.Pack(dim0, ...) { axis = 0 } // [? x ...] - // %reshape = tf.Reshape(%arg, %new_shape) // this is no-op - // - // Where `...` are some statically known dimensions. In this case reshape is - // a no-op and can be replaced by %arg (assuming `...` are equal). - auto pack_op = dyn_cast_or_null(shape.getDefiningOp()); - if (!pack_op || pack_op.values().size() < 2) return {}; - - // Dimensions packed along axis = 0 (pack scalars into vector). - if (pack_op.axis().getSExtValue() != 0) return {}; - - // First packed value is defined by a strided slice operation. - auto slice_op = - dyn_cast_or_null(pack_op.values()[0].getDefiningOp()); - if (!slice_op) return {}; - - // Input to the slice op is defined by shape operation. 
- auto shape_op = dyn_cast_or_null(slice_op.input().getDefiningOp()); - if (!shape_op || shape_op.input() != tensor) return {}; - - // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing - // scalar value from input vector). - if (slice_op.begin_mask().getSExtValue() != 0 || - slice_op.ellipsis_mask().getSExtValue() != 0 || - slice_op.end_mask().getSExtValue() != 0 || - slice_op.new_axis_mask().getSExtValue() != 0 || - slice_op.shrink_axis_mask().getSExtValue() != 1) - return {}; - - // Returns a value if the `value` is defined by a ConstOp with a single - // integer element in it and has an expected rank. - auto get_value = [](Value value, int expected_rank) -> Optional { - auto const_op = dyn_cast_or_null(value.getDefiningOp()); - if (!const_op) return None; - - auto value_attr = const_op.value().dyn_cast(); - if (!value_attr || value_attr.getNumElements() != 1) return None; - - auto value_ty = value_attr.getType(); - if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None; - - auto splat = value_attr.getSplatValue(); - return splat.getValue().getSExtValue(); - }; - - // All other packed values are scalar constants. - SmallVector packed_dims; - packed_dims.reserve(pack_op.values().size() - 1); - for (Value operand : llvm::drop_begin(pack_op.values(), 1)) { - if (auto dim = get_value(operand, /*expected_rank=*/0)) { - packed_dims.push_back(*dim); - } else { - return {}; - } - } - - // Slice exactly the first shape dimension: - // begin = [0] end = [1], strides = [1] - auto begin = get_value(slice_op.begin(), /*expected_rank=*/1); - auto end = get_value(slice_op.end(), /*expected_rank=*/1); - auto strides = get_value(slice_op.strides(), /*expected_rank=*/1); - if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() || - *begin != 0 || *end != 1 || *strides != 1) - return {}; - - // First tensor dimension is dynamic. - auto arg_ty = tensor.getType().dyn_cast(); - if (!arg_ty || arg_ty.getNumDynamicDims() != 1 || !arg_ty.isDynamicDim(0)) - return {}; - - // Argument tensor rank is equal to the number of packed dimensions. - if (arg_ty.getRank() != pack_op.values().size()) return {}; - - // All other dimensions are statically known and equal to packed dims. - auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1); - if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin())) - return {}; - - return tensor; -} - -//===----------------------------------------------------------------------===// -// SelectOp -//===----------------------------------------------------------------------===// - -void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -// Verifies a few extra requirements on SelectOp: -// (1) `then` and `else` must have same shape -// (2) At least one of the following must be true: -// (a) `cond` has the same rank as `then` and `else` -// (b) `cond` is a scalar -// (c) `cond` is a vector AND `then` and `else` are non-scalar with their -// first dimension equal to `cond`. -static LogicalResult Verify(SelectOp op) { - auto then_tensor = op.t().getType().cast(); - auto else_tensor = op.e().getType().cast(); - // Check (1). - if (!AreCastCompatible({then_tensor, else_tensor})) - return op.emitOpError() << "requires t and e have compatible shapes"; - - // Get data rank (if exists). - int data_rank; - // If data is unranked or data_rank is 0, this will remain -2. Otherwise - // refers to first dimension of then and/or else. 
- int data_first_dim = -2; - bool then_has_rank = then_tensor.hasRank(); - bool else_has_rank = else_tensor.hasRank(); - if (then_has_rank && else_has_rank) { - data_rank = then_tensor.getRank(); - if (then_tensor.getRank() > 0) - data_first_dim = then_tensor.getShape().front(); - if (else_tensor.getRank() > 0) - data_first_dim = std::max( - static_cast(else_tensor.getShape().front()), data_first_dim); - } else if (then_has_rank) { - data_rank = then_tensor.getRank(); - if (then_tensor.getRank() > 0) - data_first_dim = then_tensor.getShape().front(); - } else if (else_has_rank) { - data_rank = else_tensor.getRank(); - if (else_tensor.getRank() > 0) - data_first_dim = else_tensor.getShape().front(); - } else { - // Neither has a rank. - return success(); - } - - auto cond_tensor = op.condition().getType().dyn_cast(); - if (!cond_tensor) return success(); - auto cond_rank = cond_tensor.getRank(); - // Check (2a) and (2b). - if (cond_rank == 0 || cond_rank == data_rank) return success(); - // Check (2c). - if (cond_rank == 1) { - auto cond_shape = cond_tensor.getShape().front(); - if (data_rank == 0) { - return op.emitOpError() - << "requires that t and e are nonscalar when pred is a vector"; - } - // We know `data` tensor has a rank of at least 1. - if (data_first_dim != -1 && cond_shape != -1 && - data_first_dim != cond_shape) { - return op.emitOpError() << "requires that, when pred is a vector, the " - "shape matches the first dimension of t and e"; - } - return success(); - } - // None of (2a,b,c) were true; fail. - return op.emitOpError() << "requires that pred is a scalar OR has the same " - "rank as t and e OR is a vector"; -} - -//===----------------------------------------------------------------------===// -// SelectV2Op -//===----------------------------------------------------------------------===// - -static Type InferSelectV2OpType(Value condition, Value e, Value t) { - Type element_ty = e.getType().cast().getElementType(); - auto unranked_ty = UnrankedTensorType::get(element_ty); - - Type broadcasted_ty = - OpTrait::util::getBroadcastedType(e.getType(), t.getType()); - if (!broadcasted_ty) return unranked_ty; - - auto cond_ranked_ty = condition.getType().dyn_cast(); - auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast(); - if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty; - - // Explicitly get broadcasted output type as element types of condition may - // not be same as the broadcated type's element type. - SmallVector result_shape; - if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(), - broadcasted_ranked_ty.getShape(), - result_shape)) - return unranked_ty; - return RankedTensorType::get(result_shape, element_ty); -} - -void SelectV2Op::build(OpBuilder &builder, OperationState &result, - Value condition, Value e, Value t) { - build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t); -} - -//===----------------------------------------------------------------------===// -// ShapeOp -//===----------------------------------------------------------------------===// - -namespace { -// Validates Shape/ShapeN/VariableShape operand and associated result types. -LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, - Type result_type, - int variadic_idx = -1) { - std::string variadic_idx_str = - variadic_idx < 0 ? 
"" : llvm::formatv(" #{0}", variadic_idx).str(); - - auto result_ranked_type = result_type.dyn_cast(); - if (!result_ranked_type) return success(); - if (result_ranked_type.getShape().size() != 1) - return op->emitOpError("requires 1D type for result") << variadic_idx_str; - - auto operand_ranked_type = operand_type.dyn_cast_or_null(); - if (operand_ranked_type) { - // The operand is a ranked tensor. - if (result_ranked_type.hasStaticShape() && - !operand_ranked_type.getShape().empty() && - result_ranked_type.getDimSize(0) != - operand_ranked_type.getShape().size()) - return op->emitOpError("requires dimension size of result") - << variadic_idx_str << " to match rank of operand" - << variadic_idx_str; - } else if (result_ranked_type.hasStaticShape()) { - // The operand is an unranked tensor, print a warning if the result - // is static. - // Note: We do not handle this situation as an error, this would be too - // restrictive due to incompleteness of shape inference at this point. - op->emitWarning("has static shape result") - << variadic_idx_str << " for unranked operand" << variadic_idx_str; - } - - Type element_type = result_ranked_type.getElementType(); - if (!element_type.isSignlessInteger(32) && - !element_type.isSignlessInteger(64)) - return op->emitOpError("requires int32 or int64 return type for result") - << variadic_idx_str; - - return success(); -} -} // anonymous namespace - -static LogicalResult Verify(ShapeOp op) { - return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType()); -} - -// Converts shape of the given type to attribute if it is of ranked tensor type. -// Returned attribute has integer elements of the given width. -static Attribute ConvertShapeToAttr(Type input_ty, int out_width) { - auto ranked_ty = input_ty.dyn_cast(); - if (!ranked_ty || !ranked_ty.hasStaticShape()) return {}; - - auto shape = ranked_ty.getShape(); - int rank = shape.size(); - - SmallVector dimensions; - dimensions.reserve(rank); - for (int i = 0; i < rank; ++i) - dimensions.push_back(APInt(out_width, shape[i])); - - auto result_type = RankedTensorType::get( - {rank}, IntegerType::get(out_width, input_ty.getContext())); - return DenseElementsAttr::get(result_type, dimensions); -} - -OpFoldResult ShapeOp::fold(ArrayRef operands) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - return ConvertShapeToAttr(getOperand().getType(), width); -} - -void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input, - BoolAttr use32Bit) { - auto rankedTensorType = input.getType().dyn_cast(); - int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1; - auto out_type = use32Bit.getValue() ? 
builder.getIntegerType(32) - : builder.getIntegerType(64); - return ShapeOp::build(builder, result, - RankedTensorType::get({rank}, out_type), input); -} - -//===----------------------------------------------------------------------===// -// ShapeNOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(ShapeNOp op) { - const size_t num_tensors = op.N(); - - if (op.getNumOperands() != num_tensors) - return op.emitOpError() << "requires " << num_tensors << " operand(s), got " - << op.getNumOperands() << " operand(s)"; - - if (op.getNumResults() != num_tensors) - return op.emitOpError() << "requires " << num_tensors << " result(s), got " - << op.getNumResults() << " result(s)"; - - for (auto i : llvm::seq(0, num_tensors)) { - auto verification = VerifyShapeOperandAndResult( - op, op.getOperand(i).getType(), op.getResult(i).getType(), i); - if (failed(verification)) return verification; - } - - return success(); -} - -LogicalResult ShapeNOp::fold(ArrayRef operands, - SmallVectorImpl &results) { - if (getNumOperands() == 0) return success(); - int width = - getType(0).cast().getElementType().getIntOrFloatBitWidth(); - - for (Type input_ty : getOperandTypes()) { - OpFoldResult result = ConvertShapeToAttr(input_ty, width); - if (!result) return failure(); - - results.push_back(result); - } - return success(); -} - -// TODO(hinsu): Add canonicalization pattern for ShapeN ops that don't have all -// static input shapes. Replacing output values corresponding to static input -// types may enable optimizations in users of the values. - -//===----------------------------------------------------------------------===// -// SizeOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// -// * Input type, if is a ranked tensor, has at most INT32_MAX dimensions. -// -static LogicalResult Verify(SizeOp op) { - if (!HasRankAtMost(op.input(), std::numeric_limits::max())) - return op.emitOpError( - "requires ranked input tensor to be of rank INT32_MAX or less"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// SliceOp -//===----------------------------------------------------------------------===// - -// Verifies that: -// -// - operands begin and size are 1D with the same number of elements. -// - if the input is a ranked tensor, the rank of the input equals the number -// of elements in operands begin and size. 
-// - if begin are constants, that -// 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i] -// - if begins aren't constant but the input is a ranked tensor, that -// size[i] <= input_ty.getShape()[i] -// -static LogicalResult Verify(SliceOp op) { - RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin()); - if (begin_ty && begin_ty.getRank() != 1) { - return op.emitOpError() << "requires begin operand to be 1D tensor"; - } - - RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size()); - if (size_ty && size_ty.getRank() != 1) { - return op.emitOpError() << "requires size operand to be 1D tensor"; - } - - if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() || - !size_ty.hasStaticShape()) - return success(); - - if (begin_ty.getNumElements() != size_ty.getNumElements()) { - return op.emitOpError() << "requires begin and size operands to have the" - " same number of elements"; - } - - auto input_ty = op.input().getType().dyn_cast(); - if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) { - return op.emitOpError() << "requires number of elements in begin and size" - "are equal to input rank"; - } - - DenseIntElementsAttr begin_indices; - if (matchPattern(op.begin(), m_Constant(&begin_indices))) { - DenseIntElementsAttr slice_sizes; - bool constant_slice_sizes = - matchPattern(op.size(), m_Constant(&slice_sizes)); - int dim = 0; - for (const APInt &raw_begin_index : begin_indices.getValues()) { - int64_t begin_index = raw_begin_index.getSExtValue(); - int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1; - int64_t slice_size = constant_slice_sizes - ? slice_sizes.getValue(dim).getSExtValue() - : 0; - if (slice_size == -1 && input_size != -1) { - slice_size = input_size - begin_index; - } - if (begin_index < 0 || - (input_size != -1 && begin_index + slice_size > input_size)) { - return op.emitOpError() - << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di"; - } - ++dim; - } - } else if (input_ty) { - // If the inputs are ranked, we can do a few more sanity checks. - DenseIntElementsAttr slice_sizes; - if (matchPattern(op.size(), m_Constant(&slice_sizes))) { - auto input_shape = input_ty.getShape(); - for (int64_t i = 0; i < input_ty.getRank(); ++i) { - int64_t slice_size = slice_sizes.getValue(i).getInt(); - int64_t input_size = input_shape[i]; - if (slice_size != -1 && input_size != -1 && slice_size > input_size) { - return op.emitOpError() << "requires size[i] <= Di, even if begin[i] " - "is unknown at compile time"; - } - } - } - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// SoftmaxOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(SoftmaxOp op) { - if (!HasRankAtLeast(op.logits(), 1)) { - return op.emitOpError("requires operand to have rank at least 1"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// SoftmaxCrossEntropyWithLogitsOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// -// * Input types are broadcast compatible and the broadcasted type has rank two. 
-// -static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { - auto broadcasted_ty = OpTrait::util::getBroadcastedType( - op.features().getType(), op.labels().getType()) - .dyn_cast_or_null(); - if (!broadcasted_ty || - (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2)) - return op.emitOpError( - "requires features and labels to be broadcast compatible to rank two"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// SparseSoftmaxCrossEntropyWithLogitsOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) { - if (!IsOfRankOrUnranked(op.features(), 2)) { - return op.emitOpError("requires features operand of rank two"); - } - if (!IsOfRankOrUnranked(op.labels(), 1)) { - return op.emitOpError("requires labels operand of rank one"); - } - auto features_ty = op.features().getType().dyn_cast(); - auto labels_ty = op.labels().getType().dyn_cast(); - if (features_ty && labels_ty) { - int64_t features_batches = features_ty.getDimSize(0); - int64_t labels_batches = labels_ty.getDimSize(0); - if (!ShapedType::isDynamic(features_batches) && - !ShapedType::isDynamic(labels_batches) && - features_batches != labels_batches) - return op.emitOpError( - "requires features and labels with matching first dimension"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// SplitOp -//===----------------------------------------------------------------------===// - -// Verifies the input and split dimension operands for tf.Split/tf.SplitV. -// Writes the split dimension's index (adjusted with input rank) via `dim_index` -// if it's a constant. -template -LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { - *dim_index = llvm::None; - - Value split_dim = op.split_dim(); - if (auto split_dim_type = split_dim.getType().dyn_cast()) - if (split_dim_type.getRank() != 0) - return op.emitOpError( - "split dimension should be an integer scalar tensor"); - - // We can perform further verification if the input tensor to be split has - // known rank and the split dimension tensor is a constant. 
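The constant split_dim handling that follows normalizes a possibly negative axis into [0, rank) and rejects out-of-range values. A standalone sketch of that arithmetic, with a hypothetical helper name:

#include <cstdint>
#include <iostream>
#include <optional>

// Normalizes a split dimension into [0, rank), rejecting out-of-range values,
// mirroring the checks performed below once split_dim is a known constant.
std::optional<int64_t> NormalizeSplitDim(int64_t index, int64_t input_rank) {
  if (index + input_rank < 0 || index >= input_rank) return std::nullopt;
  if (index < 0) index += input_rank;
  return index;
}

int main() {
  std::cout << *NormalizeSplitDim(-1, 3) << '\n';             // 2
  std::cout << NormalizeSplitDim(3, 3).has_value() << '\n';   // 0 (rejected)
}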
- - auto input_type = op.value().getType().template dyn_cast(); - if (!input_type) return success(); - - int64_t input_rank = input_type.getRank(); - if (input_rank == 0) - return op.emitOpError("cannot split scalar input tensor"); - - DenseIntElementsAttr split_dim_attr; - if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success(); - - int64_t index = (*split_dim_attr.begin()).getSExtValue(); - - if (index + input_rank < 0 || index >= input_rank) { - return op.emitOpError("split dimension must be in range [-") - << input_rank << ", " << input_rank << ")"; - } - - if (index < 0) index += input_rank; - *dim_index = index; - - return success(); -} - -static LogicalResult Verify(SplitOp op) { - Optional dim_index; - if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); - if (!dim_index) return success(); - - int64_t input_dim_size = - op.value().getType().cast().getDimSize(*dim_index); - if (input_dim_size == ShapedType::kDynamicSize) return success(); - - if (input_dim_size % op.getNumResults() != 0) - return op.emitOpError("dimension #") - << *dim_index << " not divisible by the number of result tensors"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// SplitVOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(SplitVOp op) { - auto split_sizes_type = - op.size_splits().getType().dyn_cast(); - if (!split_sizes_type) return success(); - - if (split_sizes_type.getRank() != 1 || - split_sizes_type.getDimSize(0) != op.getNumResults()) - return op.emitOpError("split sizes should be a 1D tensor of ") - << op.getNumResults() << " elements"; - - Optional dim_index = 0; - if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); - if (!dim_index) return success(); - - int64_t input_dim_size = - op.value().getType().cast().getDimSize(*dim_index); - if (input_dim_size == ShapedType::kDynamicSize) return success(); - - // If split sizes come from a constant, they must sum to the dimension size - // along split_dim, and we can have no more than one dynamic dimension. - DenseIntElementsAttr split_sizes_attr; - if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr))) - return success(); - - int64_t total_dim_size = 0; // Total dimension size assigned to splits - llvm::Optional dynamic_dim_index; - - SmallVector split_sizes; - split_sizes.reserve( - split_sizes_attr.getType().cast().getNumElements()); - - for (auto dim : llvm::enumerate(split_sizes_attr)) { - int64_t dim_val = dim.value().getSExtValue(); - split_sizes.push_back(dim_val); - if (dim_val == ShapedType::kDynamicSize) { - // We cannot have more than one dynamic dimension. 
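The loop that follows enforces the SplitV size constraint: at most one -1 entry, and the remaining sizes must account for the dimension being split. A standalone sketch of that check under the same rules; the helper name is hypothetical.

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = -1;

// Checks constant `size_splits` against the dimension size along the split
// axis: at most one entry may be -1, and the known entries must sum to the
// dimension size (or to at most that size when a -1 absorbs the remainder).
bool ValidateSplitSizes(const std::vector<int64_t> &split_sizes,
                        int64_t input_dim_size) {
  std::optional<size_t> dynamic_index;
  int64_t total = 0;
  for (size_t i = 0; i < split_sizes.size(); ++i) {
    if (split_sizes[i] == kDynamic) {
      if (dynamic_index) return false;  // more than one dynamic split size
      dynamic_index = i;
    } else {
      total += split_sizes[i];
    }
  }
  if (!dynamic_index) return total == input_dim_size;
  return total <= input_dim_size;  // the dynamic entry absorbs the remainder
}

int main() {
  std::cout << ValidateSplitSizes({2, 3, 5}, 10) << '\n';         // 1
  std::cout << ValidateSplitSizes({2, kDynamic, 5}, 10) << '\n';  // 1
  std::cout << ValidateSplitSizes({2, 3, 4}, 10) << '\n';         // 0
}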
- if (dynamic_dim_index) - return op.emitOpError( - "cannot have more than one dynamic dimension in split sizes"); - dynamic_dim_index = dim.index(); - } else { - total_dim_size += dim_val; - } - } - - if (!dynamic_dim_index && total_dim_size != input_dim_size) - return op.emitOpError( - "split sizes must sum up to the dimension size along split " - "dimension, found ") - << total_dim_size << " vs " << input_dim_size; - - if (dynamic_dim_index && total_dim_size > input_dim_size) - return op.emitOpError( - "split sizes must sum up to be less than or equal to the " - "dimension size along split dimension, found ") - << total_dim_size << " vs " << input_dim_size; - - return success(); -} - -//===----------------------------------------------------------------------===// -// SquareOp -//===----------------------------------------------------------------------===// - -void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// SubOp -//===----------------------------------------------------------------------===// - -void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult SubOp::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// SumOp -//===----------------------------------------------------------------------===// - -void SumOp::build(OpBuilder &builder, OperationState &result, Value input, - Value reduction_indices, BoolAttr keep_dims) { - Type out_ty = - InferReductionOpType(input, reduction_indices, keep_dims, &builder); - build(builder, result, out_ty, input, reduction_indices, keep_dims); -} - -//===----------------------------------------------------------------------===// -// StridedSliceOp -//===----------------------------------------------------------------------===// - -// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to -// tf.SliceOp if both of the following are true: -// - All strides have a known value equal to 1 -// - No masks are set (or masks can be applied by transforming the inputs to -// Slice) - -// Verifies that, -// -// - begin, end and strides operands are 1D and they have the same number of -// elements. Here, the number of elements should be less than 32 to support -// 32-bit mask attributes. -// - None of the strides values are zero. -// - Ellipsis mask can have at most one bit set. - -template -static LogicalResult VerifyStridedSliceBase(OpTy op) { - // Expected size for operands begin, end and strides vector operands. - int64_t expected_size = -1; - - for (Value val : {op.begin(), op.end(), op.strides()}) { - auto operand_ty = val.getType().dyn_cast(); - if (!operand_ty || !operand_ty.hasStaticShape()) { - // TensorFlow constant ops may have non-static shape because the shape is - // not propagated during constant folding. If the defining op for this - // operand is a constant op, use the constant op's attribute to get the - // actual shape. 
- DenseIntElementsAttr attr; - if (!matchPattern(val, m_Constant(&attr))) continue; - operand_ty = attr.getType(); - } - - if (operand_ty.getRank() != 1) - return op.emitOpError() - << "requires begin, end and strides to be 1D tensors"; - - int64_t length = operand_ty.getDimSize(0); - if (length == -1) continue; - - if (expected_size == -1) { - // This op uses 32-bit masks. - if (length >= 32) - return op.emitOpError( - "requires begin, end and strides operands with less than 32 " - "elements"); - - expected_size = length; - } else if (length != expected_size) { - return op.emitOpError() << "requires begin, end and strides to have the " - "same number of elements"; - } - } - - // If strides are constants, verify that none of the element is zero. - DenseIntElementsAttr strides; - if (matchPattern(op.strides(), m_Constant(&strides))) { - if (llvm::is_contained(strides.getValues(), 0)) - return op.emitOpError("requires non-zero strides"); - } - - // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there - // exists only no more than one ellipsis. - uint32_t ellipsis_mask = op.ellipsis_mask().getZExtValue(); - if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask)) - return op.emitOpError("cannot have multiple ellipses"); - - return success(); -} - -// Clamps the given `val`: returns `low` if `val` is less than `low`; returns -// `high` if `high` is less than `val`; otherwise returns `val`. -template -constexpr const T &Clamp(const T &val, const T &low, const T &high) { - assert(!(high < low)); - return (val < low) ? low : (high < val) ? high : val; -} - -// Checks if the `index` bit of `val` is set. -template -constexpr bool IsSet(const T &val, unsigned index) { - return (val & (1 << index)) != 0; -} - -// Sets the `index` bit of `val`. -template -constexpr void Set(T &val, unsigned index) { - val |= (1 << index); -} - -// Unset the `index` bit of `val`. -template -constexpr void Unset(T &val, unsigned index) { - val &= ~(1 << index); -} - -// Copy the `src_index` bit of `src` to `dst_index` bit of `dst`. -template -constexpr void CopyBit(const T &src, unsigned src_index, T &dst, - unsigned dst_index) { - if (IsSet(src, src_index)) - Set(dst, dst_index); - else - Unset(dst, dst_index); -} - -// The sparse spec of strided slice does not correspond to the number of -// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2, -// 4, 8) would have dims = 2. -struct SparseSliceSpec { - int64_t dims; - int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask; - const ArrayRef &begin; - const ArrayRef &end; - const ArrayRef &strides; -}; - -// The dense spec of strided slice is the canonicalized version of sparse spec. -// The number of dimensions of dense spec correspond to the number of dimensions -// in operand tensor. -struct DenseSliceSpec { - int64_t dims; - int32_t begin_mask, end_mask, shrink_axis_mask; - SmallVectorImpl &begin; - SmallVectorImpl &end; - SmallVectorImpl &strides; -}; - -// Make a sparse spec into a dense index spec. -// The sparse spec does not correspond to the number of dimensions -// Make a dense spec that corresponds to the number of dimensions -// -// For example suppose foo[...,3:, 2] on foo.shape=(2,2,3,4) then -// we need to produce the missing begin_mask, end_mask for the first two -// dimensions i.e. foo[:, :, 3:, 2]. 
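The strided-slice helpers above treat begin_mask, end_mask, and the other mask attributes as 32-bit sets of per-dimension flags, which is why the verifier limits these ops to fewer than 32 elements. A tiny self-contained demonstration of the bit helpers (re-declared locally so the snippet compiles on its own):

#include <cstdint>
#include <iostream>

template <class T>
constexpr bool IsSet(const T &val, unsigned index) {
  return (val & (1 << index)) != 0;
}
template <class T>
constexpr void Set(T &val, unsigned index) {
  val |= (1 << index);
}

int main() {
  uint32_t begin_mask = 0;
  Set(begin_mask, 0);  // dimension 0: ignore the explicit begin index
  Set(begin_mask, 2);  // dimension 2: likewise
  for (unsigned dim = 0; dim < 4; ++dim)
    std::cout << "dim " << dim
              << (IsSet(begin_mask, dim) ? ": full-range begin\n"
                                         : ": explicit begin\n");
}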
-static void BuildDenseSliceSpec(const SparseSliceSpec &sparse, - DenseSliceSpec *dense) { - // Build expanded dense begin, end, strides, begin_mask, end_mask, and - // shrink_axis_mask. - dense->begin.resize(dense->dims); - dense->end.resize(dense->dims); - dense->strides.resize(dense->dims); - dense->begin_mask = 0; - dense->end_mask = 0; - dense->shrink_axis_mask = 0; - - // Count number of new_axis after ellipsis. This helps in calculating the - // number of dimensions ellipsis represents in the sparse spec. - bool ellipsis_seen = false; - int num_new_axis_after_ellipsis = 0; - for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) { - if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index)) - num_new_axis_after_ellipsis++; - if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true; - } - - int dense_index = 0; - for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) { - if (IsSet(sparse.new_axis_mask, sparse_index)) continue; - if (IsSet(sparse.ellipsis_mask, sparse_index)) { - auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) + - 1 + num_new_axis_after_ellipsis, - dense->dims); - // Expand ellipsis into the appropriate dense indices. From current index - // until next_index, all dimensions would have begin and end masks set and - // stride 1, i.e., get all elements in those dimensions. - for (; dense_index < next_index; ++dense_index) { - dense->begin[dense_index] = dense->end[dense_index] = 0; - dense->strides[dense_index] = 1; - Set(dense->begin_mask, dense_index); - Set(dense->end_mask, dense_index); - } - continue; - } - assert(dense_index < dense->dims); - // Copy over the sparse indices to dense indices if ellipsis_mask and - // new_axis_mask are not set. - dense->begin[dense_index] = sparse.begin[sparse_index]; - dense->end[dense_index] = sparse.end[sparse_index]; - dense->strides[dense_index] = sparse.strides[sparse_index]; - CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index); - CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index); - CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask, - dense_index); - dense_index++; - } -} - -// For the given `input_shape`, calculates the sliced shape using the given -// `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and -// `shrink_axis_mask` masks. Updates the result back to `input_shape`. If -// `shrink_axis_mask` is not zero, this function will not drop the corresponding -// dimensions in `input_shape`; it will turn them into 1s. At the same time, -// canonicalizes `begin`, `end`, and `strides. The calculation follows -// tf.StridedSlice op semantics. -static void CalculateSlicedShapeFromDenseIndices( - MutableArrayRef input_shape, int32_t begin_mask, int32_t end_mask, - int32_t shrink_axis_mask, MutableArrayRef begin, - MutableArrayRef end, MutableArrayRef stride) { - assert(input_shape.size() <= 32); // Only 32-bit masks are supported. - - // Make sure ranges' ranks are consistent with the input. 
- assert(input_shape.size() == begin.size()); - assert(input_shape.size() == end.size()); - assert(input_shape.size() == stride.size()); - - for (int i = 0, e = input_shape.size(); i < e; ++i) { - if (ShapedType::isDynamic(input_shape[i])) continue; - - int64_t dim_i = input_shape[i]; - int64_t begin_i = begin[i]; - int64_t end_i = end[i]; - int64_t stride_i = stride[i]; - - // [0]: mask for begin, [1]: mask for end - int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)}; - // [0]: bound for begin, [1]: bound for end - int64_t bounds[] = {stride_i > 0 ? 0 : -1, - stride_i > 0 ? dim_i : dim_i - 1}; - - // Canonicalizes the given range `point` (begin/end) according to the - // current dimension. `c` means case: 0 for begin, 1 for end. - auto canonicalize = [&](int64_t point, int c) { - if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1]; - - // Add dim as offset to negative range point. - point = point < 0 ? dim_i + point : point; - return Clamp(point, bounds[0], bounds[1]); - }; - - begin_i = canonicalize(begin_i, 0); - end_i = canonicalize(end_i, 1); - - int64_t interval_len = end_i - begin_i; - int64_t size_i = 0; - // If internal length is zero or has different sign from stride, it's a - // degenerated case: we are slicing nothing. Otherwise, calculate the sliced - // size. - if (interval_len != 0 && (interval_len < 0) == (stride_i < 0)) - size_i = (interval_len / stride_i) + (interval_len % stride_i != 0); - - begin[i] = begin_i; - if (IsSet(shrink_axis_mask, i)) { - // Shrink this dimension. It means we only take the element at begin_i. - input_shape[i] = 1; - end[i] = begin_i + 1; - stride[i] = 1; - } else { - input_shape[i] = size_i; - end[i] = end_i; - stride[i] = stride_i; - } - } -} - -// For the given `input_shape`, calculates the sliced shape using the given -// `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`, -// `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks. -// Updates the result back to `input_shape`. -static void CalculateSlicedShapeFromSparseIndices( - MutableArrayRef input_shape, ArrayRef sparse_begin, - ArrayRef sparse_end, ArrayRef sparse_strides, - int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, - int32_t new_axis_mask, int32_t shrink_axis_mask, - SmallVectorImpl *begin, SmallVectorImpl *end, - SmallVectorImpl *stride) { - int64_t num_sparse_indices = sparse_begin.size(); - SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask, - ellipsis_mask, new_axis_mask, shrink_axis_mask, - sparse_begin, sparse_end, sparse_strides}; - - // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is - // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields - // a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...]. 
- if (sparse.ellipsis_mask == 0) { - Set(sparse.ellipsis_mask, sparse.dims); - sparse.dims++; - } - - int64_t dims = input_shape.size(); - DenseSliceSpec dense = {dims, - /*begin_mask = */ 0, - /*end_mask = */ 0, - /*shrink_axis_mask = */ 0, - *begin, - *end, - *stride}; - - BuildDenseSliceSpec(sparse, &dense); - CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask, - dense.end_mask, dense.shrink_axis_mask, - *begin, *end, *stride); -} - -bool StridedSliceOp::GetSlicedBoundRanges( - SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, - SmallVectorImpl *slice_stride) { - // TODO(hinsu): Support lowering for ops with dynamic begin and end values - // when it is possible to derive indices based on mask attributes. - DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; - if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) || - !matchPattern(end(), m_Constant(&sparse_end_attr)) || - !matchPattern(strides(), m_Constant(&sparse_strides_attr))) - return false; - - auto input_ty = this->input().getType().dyn_cast(); - if (!input_ty || !input_ty.hasStaticShape()) return false; - auto input_shape = llvm::to_vector<4>(input_ty.getShape()); - - SmallVector sparse_begin, sparse_end, sparse_strides; - - for (const APInt &index : sparse_begin_attr) - sparse_begin.push_back(index.getSExtValue()); - for (const APInt &index : sparse_end_attr) - sparse_end.push_back(index.getSExtValue()); - for (const APInt &stride : sparse_strides_attr) - sparse_strides.push_back(stride.getSExtValue()); - - CalculateSlicedShapeFromSparseIndices( - input_shape, sparse_begin, sparse_end, sparse_strides, - begin_mask().getZExtValue(), end_mask().getZExtValue(), - ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), - shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); - return true; -} - -//===----------------------------------------------------------------------===// -// StridedSliceGradOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(StridedSliceGradOp op) { - auto shape_type = op.shape().getType().dyn_cast(); - if (shape_type && shape_type.getRank() != 1) - return op.emitOpError("'shape' operand must be 1D tensor, but got ") - << shape_type.getRank() << "D tensor"; - - if (failed(VerifyStridedSliceBase(op))) return failure(); - - // TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with - // the sliced type from StridedSlice. 
- - return success(); -} - -bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( - SmallVectorImpl *input_shape, - SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, - SmallVectorImpl *slice_stride) { - DenseIntElementsAttr shape_attr; - DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; - if (!matchPattern(shape(), m_Constant(&shape_attr)) || - !matchPattern(begin(), m_Constant(&sparse_begin_attr)) || - !matchPattern(end(), m_Constant(&sparse_end_attr)) || - !matchPattern(strides(), m_Constant(&sparse_strides_attr))) - return false; - - int rank = std::distance(shape_attr.begin(), shape_attr.end()); - - input_shape->clear(); - input_shape->reserve(rank); - for (const APInt &dim : shape_attr) - input_shape->push_back(dim.getSExtValue()); - - SmallVector sparse_begin, sparse_end, sparse_strides; - - for (const APInt &index : sparse_begin_attr) - sparse_begin.push_back(index.getSExtValue()); - for (const APInt &index : sparse_end_attr) - sparse_end.push_back(index.getSExtValue()); - for (const APInt &stride : sparse_strides_attr) - sparse_strides.push_back(stride.getSExtValue()); - - CalculateSlicedShapeFromSparseIndices( - *input_shape, sparse_begin, sparse_end, sparse_strides, - begin_mask().getZExtValue(), end_mask().getZExtValue(), - ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), - shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); - return true; -} - -//===----------------------------------------------------------------------===// -// TensorListReserveOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TensorListReserveOp op) { - if (!IsOfRankOrUnranked(op.element_shape(), 0) && - !IsOfRankOrUnranked(op.element_shape(), 1)) { - return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); - } - - if (!IsOfRankOrUnranked(op.num_elements(), 0)) { - return op.emitOpError("requires num_elements operand to be 0D tensor"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// TensorListElementShapeOp -//===----------------------------------------------------------------------===// - -OpFoldResult TensorListElementShapeOp::fold(ArrayRef operands) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - auto variant_type = - getElementTypeOrSelf(getOperand().getType()).cast(); - if (variant_type.getSubtypes().empty()) return {}; - return ConvertShapeToAttr(variant_type.getSubtypes()[0], width); -} - -//===----------------------------------------------------------------------===// -// TensorListStackOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TensorListStackOp op) { - if (!IsOfRankOrUnranked(op.element_shape(), 0) && - !IsOfRankOrUnranked(op.element_shape(), 1)) { - return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// TensorScatterUpdateOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TensorScatterUpdateOp op) { - if (!HasRankAtLeast(op.tensor(), 1)) - return op.emitOpError( - "requires tensor operand to have at least 1 dimension"); - if (!HasRankAtLeast(op.indices(), 1)) - return op.emitOpError( - "requires indices operand to have at least 1 dimension"); - if (!HasRankAtLeast(op.updates(), 
1)) - return op.emitOpError( - "requires updates operand to have at least 1 dimension"); - - auto tensor_ty = op.tensor().getType().dyn_cast(); - auto indices_ty = op.indices().getType().dyn_cast(); - if (!tensor_ty || !indices_ty) return success(); - - int64_t num_index_dims = indices_ty.getShape().back(); - if (ShapedType::isDynamic(num_index_dims)) return success(); - - if (num_index_dims > tensor_ty.getRank()) - return op.emitOpError( - "requires tensor operand with rank greater than or equal to the " - "indices operand's last dimensions"); - return success(); -} - -//===----------------------------------------------------------------------===// -// TopKV2Op -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TopKV2Op op) { - if (!HasRankAtLeast(op.input(), 1)) - return op.emitOpError( - "requires input operand to have at least 1 dimension"); - - if (!IsOfRankOrUnranked(op.k(), 0)) - return op.emitOpError("requires k operand to be 0D tensor"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// ToBoolOp -//===----------------------------------------------------------------------===// - -namespace { -// If the input to ToBoolOp is a `tensor`, then the ToBoolOp is an identity -// function and can be removed. -class ToBoolOfZeroDBoolTensor : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ToBoolOp op, - PatternRewriter &rewriter) const override { - if (auto type = op.getOperand().getType().dyn_cast()) { - if (type.getRank() == 0 && type.getElementType().isInteger(1)) { - rewriter.replaceOp(op, op.getOperand()); - return success(); - } - } - return failure(); - } -}; -} // namespace - -void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// TransposeOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TransposeOp op) { - // TODO(hinsu): Verify using a custom verifier that, - // * Transpose permutation is 1-D of size equal to the rank of the first - // input, if the shapes are partially known. Requires use of a more - // restrictive type than TF_Tensor. - // * Result shape dimensions are possible based on the input shape. - return success(); -} - -// TODO(jpienaar): perm could be optional too. -void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, - Value perm) { - auto x_type = x.getType().cast(); - // If value is unranked, then so is results. - if (!x_type.hasRank()) - return TransposeOp::build(builder, result, - UnrankedTensorType::get(x_type.getElementType()), - x, perm); - - // TODO(jpienaar): Handle unknown perm case. - - // TODO(jpienaar): Extract utility function. 
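The TransposeOp builder below derives the result shape by indexing the input shape with the constant permutation. A minimal standalone sketch of that shape permutation on plain vectors; the helper name is an assumption for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

// Permutes `shape` according to `perm`, the same derivation the builder below
// performs for a constant permutation operand.
std::vector<int64_t> PermuteShape(const std::vector<int64_t> &shape,
                                  const std::vector<int64_t> &perm) {
  std::vector<int64_t> result;
  result.reserve(perm.size());
  for (int64_t p : perm) result.push_back(shape[p]);
  return result;
}

int main() {
  // Transposing a [2, 3, 4] tensor with permutation [2, 0, 1] yields [4, 2, 3].
  for (int64_t d : PermuteShape({2, 3, 4}, {2, 0, 1})) std::cout << d << ' ';
  std::cout << '\n';
}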
- auto etype = x_type.cast().getElementType(); - DenseIntElementsAttr attr_shape; - if (matchPattern(perm, m_Constant(&attr_shape))) { - llvm::SmallVector const_shape; - if (attr_shape.isSplat()) { - const_shape.assign( - attr_shape.getNumElements(), - x_type.getDimSize((*attr_shape.begin()).getSExtValue())); - } else { - const_shape.reserve(attr_shape.getNumElements()); - for (const auto &dim : attr_shape) - const_shape.push_back(x_type.getDimSize(dim.getSExtValue())); - } - return TransposeOp::build( - builder, result, RankedTensorType::get(const_shape, etype), x, perm); - } - return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x, - perm); -} - -namespace { - -OpFoldResult FoldIdentityTranspose(TransposeOp op) { - auto const_perm = dyn_cast_or_null(op.perm().getDefiningOp()); - if (!const_perm) return {}; - - auto const_value = const_perm.value(); - const auto elements = const_value.getValues(); - - for (auto it : llvm::enumerate(elements)) { - if (it.index() != it.value()) return {}; - } - - // TODO(jpienaar): Remove if/when we handle this more generally. - if (op.getType() != op.x().getType()) { - // If the types don't match then only fold if all the operands are in the TF - // dialect. - for (auto user : op.getOperation()->getUsers()) - if (user->getDialect() != op.getDialect()) return {}; - } - - return op.x(); -} - -OpFoldResult FoldCancellableTranspose(TransposeOp op) { - // Operand is a TransposeOp. - auto transpose = dyn_cast_or_null(op.x().getDefiningOp()); - if (!transpose) return {}; - - // Permutations defined by constant operations. - auto perm0 = dyn_cast_or_null(op.perm().getDefiningOp()); - auto perm1 = dyn_cast_or_null(transpose.perm().getDefiningOp()); - if (!perm0 || !perm1) return {}; - - // With permutation indices that cancel each other - auto perm0_value = perm0.value().cast(); - auto perm1_value = perm1.value().cast(); - if (!AreCancellablePermutations(perm0_value, perm1_value)) return {}; - - return transpose.x(); -} - -} // namespace - -OpFoldResult TransposeOp::fold(ArrayRef operands) { - if (auto folded = FoldIdentityTranspose(*this)) return folded; - if (auto folded = FoldCancellableTranspose(*this)) return folded; - return {}; -} - -//===----------------------------------------------------------------------===// -// TruncateDivOp -//===----------------------------------------------------------------------===// - -void TruncateDivOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// UnpackOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(UnpackOp op) { - auto value_type = op.value().getType().dyn_cast(); - if (!value_type) return success(); - - int64_t value_rank = value_type.getRank(); - int64_t axis = op.axis().getSExtValue(); - if (axis < -value_rank || axis >= value_rank) - return op.emitOpError("axis attribute must be in the range of [-") - << value_rank << ", " << value_rank << ')'; - - axis = GetDimForAxis(axis, value_rank); - int64_t dim_size = value_type.getDimSize(axis); - if (ShapedType::isDynamic(dim_size)) return success(); - - if (dim_size != op.getNumResults()) - return op.emitOpError("result count must be equal to ") << dim_size; - - return success(); -} - -//===----------------------------------------------------------------------===// -// Unsorted segment reduction ops 
-//===----------------------------------------------------------------------===// - -template -static LogicalResult VerifyUnsortedSegmentReduction(Op op) { - if (!HasRankAtMost(op.num_segments(), 0)) - return op.emitOpError("number of segments should be a 0-D tensor"); - - auto data_type = op.data().getType().template dyn_cast(); - auto segment_ids_type = - op.segment_ids().getType().template dyn_cast(); - if (data_type && segment_ids_type) { - if (data_type.getRank() < segment_ids_type.getRank()) - return op.emitOpError( - "requires segment ids rank to be less than or equal to data's rank"); - - int index = 0; - for (auto shape_pair : - llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) { - int64_t segment_id_dim = std::get<0>(shape_pair); - int64_t data_dim = std::get<1>(shape_pair); - if (!ShapedType::isDynamic(segment_id_dim) && - !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim) - return op.emitOpError( - "requires segment ids shape to be a prefix of data shape, " - "but dimension #") - << index << " differs: " << segment_id_dim << " vs. " - << data_dim; - ++index; - } - } - - DenseIntElementsAttr num_segments_attr; - if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) { - int64_t num_segments = (*num_segments_attr.begin()).getSExtValue(); - if (num_segments < 0) - return op.emitOpError("num of segments cannot be negative"); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// VariableShapeOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(VariableShapeOp op) { - auto input_type = op.input().getType().cast(); - if (input_type.hasStaticShape() && input_type.getNumElements() != 1) - return op.emitOpError("requires input to have one resource"); - - auto resource_type = input_type.getElementType().cast(); - auto subtypes = resource_type.getSubtypes(); - switch (subtypes.size()) { - case 1: - return VerifyShapeOperandAndResult( - op, resource_type.getSubtypes().front(), op.getType()); - case 0: - return VerifyShapeOperandAndResult(op, Type(), op.getType()); - default: - return op.emitOpError( - "requires resource input type to have at most 1 subtype"); - } -} - -OpFoldResult VariableShapeOp::fold(ArrayRef operands) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - auto resource_type = - getElementTypeOrSelf(getOperand().getType()).cast(); - if (resource_type.getSubtypes().empty()) return {}; - return ConvertShapeToAttr(resource_type.getSubtypes()[0], width); -} - -//===----------------------------------------------------------------------===// -// WhileOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(WhileOp op) { - auto module = op.getParentOfType(); - auto cond_fn = module.lookupSymbol(op.cond()); - auto body_fn = module.lookupSymbol(op.body()); - if (!cond_fn) { - return op.emitOpError("cond refers to an undefined function : ") - << op.cond(); - } - if (!body_fn) { - return op.emitOpError("body refers to an undefined function : ") - << op.body(); - } - - auto cond_fn_type = cond_fn.getType(); - auto body_fn_type = body_fn.getType(); - - // Verify that the cond function has exactly one result. 
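VerifyUnsortedSegmentReduction above requires the segment_ids shape to be a prefix of the data shape, with dynamic dimensions treated as compatible. A standalone sketch of that prefix check; the -1 sentinel and helper name are illustrative.

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamic = -1;

// Returns true when `prefix` is a leading prefix of `shape`, ignoring
// dimensions that are unknown on either side.
bool IsShapePrefix(const std::vector<int64_t> &prefix,
                   const std::vector<int64_t> &shape) {
  if (prefix.size() > shape.size()) return false;
  for (size_t i = 0; i < prefix.size(); ++i) {
    if (prefix[i] == kDynamic || shape[i] == kDynamic) continue;
    if (prefix[i] != shape[i]) return false;
  }
  return true;
}

int main() {
  std::cout << IsShapePrefix({4, kDynamic}, {4, 8, 16}) << '\n';  // 1
  std::cout << IsShapePrefix({5}, {4, 8}) << '\n';                // 0
}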
- if (cond_fn_type.getNumResults() != 1) - return op.emitOpError("requires cond function to have exactly one result"); - - SmallVector operands(op.getOperandTypes()); - - // Collect all the type lists for the op so that different pairs of type lists - // can be compared for the compatibility. - constexpr int kNumTypeLists = 5; - const std::array>, kNumTypeLists> - type_lists = {{ - {"operand", operands}, - {"body function result", body_fn_type.getResults()}, - {"result", op.getResultTypes()}, - {"cond function input", cond_fn_type.getInputs()}, - {"body function input", body_fn_type.getInputs()}, - }}; - - // A pair of type lists should be cast compatible with each other if one is - // converted to the another for a function call or assignment or there is a - // common source of inputs for both. Therefore, the While op requires the - // following pairs of type lists to be cast compatible for the tensor_cast - // operation: - // - // * Operands and cond inputs to call the cond function before the - // first iteration. - // * Operands and body inputs to call the body function for the first - // iteration if the cond functions returns True or equivalent result. - // * Operands and results to assign cond function arguments to op results if - // the cond function returns False or equivalent result. - // * All three pairs using cond inputs, body inputs and results as operand is - // a common source for all three. - // * Body result and cond inputs to call the cond function for the subsequent - // iterations. Similarly, Body result should be compatible with body inputs - // and op results. - // - // Note that the operands and body results need not be compatible as they are - // never converted from one to the another nor there is a common source - // tensors. Compatibility requirement is not transitive. - - for (int i = 0; i < kNumTypeLists; ++i) { - // Skip the first pair as the While op operands and body function results - // does not need to be compatible with each other. - for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) { - auto &a = type_lists[i]; - auto &b = type_lists[j]; - - int a_size = a.second.size(); - if (a_size != b.second.size()) - return op.emitOpError( - llvm::formatv("requires the number of {0}s to be equal to the " - "number of {1}s. Found {2} and {3}, respectively", - a.first, b.first, a_size, b.second.size())); - - for (int idx = 0; idx < a_size; ++idx) { - auto a_type = a.second[idx]; - auto b_type = b.second[idx]; - - if (!AreCastCompatible({a_type, b_type})) - return op.emitError(llvm::formatv( - "{0} type {1} is incompatible with {2} type {3} at index {4}", - a.first, a_type, b.first, b_type, idx)); - } - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// WhileRegionOp -//===----------------------------------------------------------------------===// -static LogicalResult Verify(WhileRegionOp op) { - // Verify that the condition generates a single tensor result. - YieldOp yield = cast(op.cond().front().getTerminator()); - if (yield.getNumOperands() != 1) - return op.emitOpError() - << "condition should have a single tensor result"; - - auto cond_type = yield.getOperand(0).getType().dyn_cast(); - if (!cond_type || !cond_type.getShape().equals({}) || - !cond_type.getElementType().isInteger(/*width=*/1)) - return op.emitOpError() - << "condition should have a single tensor result"; - - // The body result types should match while op result types. 
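The pairwise loop in Verify(WhileOp) above compares five type lists but deliberately starts the inner index at max(2, i + 1), so the (operand, body-result) pair is never checked; those two lists need not be cast-compatible. A standalone sketch of that pairing logic, using strings as stand-in "types" and plain equality as the compatibility predicate.

#include <algorithm>
#include <array>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  using TypeList = std::pair<std::string, std::vector<std::string>>;
  const std::array<TypeList, 5> lists = {{
      {"operand", {"tensor<i32>"}},
      {"body result", {"tensor<i32>"}},
      {"result", {"tensor<i32>"}},
      {"cond input", {"tensor<i32>"}},
      {"body input", {"tensor<i32>"}},
  }};
  for (int i = 0; i < 5; ++i) {
    // Inner index starts at max(2, i + 1): pair (0, 1) is skipped.
    for (int j = std::max(2, i + 1); j < 5; ++j)
      std::cout << "compare " << lists[i].first << " with " << lists[j].first
                << (lists[i].second == lists[j].second ? ": compatible\n"
                                                       : ": mismatch\n");
  }
}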
- if (failed(VerifyRegionResults(op, op.body(), "body"))) return failure(); - - // Both condition and body should have same number and type of operands as - // the WhileRegion inputs. - const int num_inputs = op.getNumOperands(); - auto block_inputs_match_op_inputs = [&](Region ®ion, - StringRef name) -> LogicalResult { - Block &block = region.front(); - if (block.getNumArguments() != num_inputs) - return op.emitOpError() - << name << " should have same number of inputs (" << num_inputs - << ") as " << WhileRegionOp::getOperationName() << " but has " - << block.getNumArguments() << " inputs"; - - for (auto types_idx : llvm::enumerate( - llvm::zip(op.getOperandTypes(), block.getArgumentTypes()))) { - auto op_input_type = std::get<0>(types_idx.value()); - auto block_input_type = std::get<1>(types_idx.value()); - if (!AreCastCompatible({block_input_type, op_input_type})) - return op.emitOpError(llvm::formatv( - "{0} input type {1} is incompatible with {2} " - "input type {3} at index {4}", - name, block_input_type, WhileRegionOp::getOperationName(), - op_input_type, types_idx.index())); - } - return success(); - }; - - if (failed(block_inputs_match_op_inputs(op.cond(), "condition")) || - failed(block_inputs_match_op_inputs(op.body(), "body"))) - return failure(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// WhileRegionOp LoopLikeOpInterface -//===----------------------------------------------------------------------===// - -Region &WhileRegionOp::getLoopBody() { return body(); } - -bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) { - // If the Op defining the value exists and the defining op is outside the - // scope of this WhileRegion, then we can infer that its defined outside. - // The defining Op is outside the scope of this WhileRegion if this - // WhileRegionOp is not an ancestor of the defining op in the parent chain. - Operation *def_op = value.getDefiningOp(); - return def_op && !getOperation()->isAncestor(def_op); -} - -LogicalResult WhileRegionOp::moveOutOfLoop( - llvm::ArrayRef ops) { - // Move the hoisted value to just before the while. - Operation *while_op = this->getOperation(); - for (auto op : ops) op->moveBefore(while_op); - return success(); -} - -//===----------------------------------------------------------------------===// -// WhileRegionOp canonicalization -//===----------------------------------------------------------------------===// -namespace { -// Eliminate values that pass through the WhileRegionOp body. -struct WhileRegionEliminatePassThrough - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileRegionOp while_op, - PatternRewriter &rewriter) const override { - // Replace values that simply passthrough the body with extern values. The - // block arguments of body and while match and so the corresponding cond - // argument can be easily found. - int old_num_operands = while_op.getNumOperands(); - int new_num_operands = old_num_operands; - auto &body_block = while_op.body().front(); - auto &cond_block = while_op.cond().front(); - auto &yield = *body_block.getTerminator(); - - // Bit mask indicating which operands will be removed. 
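The canonicalization below marks while operands that merely pass through the body (and are otherwise unused) and rebuilds the op without them. A standalone sketch of the final compression step, keeping only operands whose "removed" bit is false; names and the use of strings for SSA values are illustrative assumptions.

#include <iostream>
#include <string>
#include <vector>

// Keeps only the operands not marked for removal, the same compression the
// rewrite below applies to operands, region arguments, and results.
std::vector<std::string> CompressOperands(const std::vector<std::string> &operands,
                                          const std::vector<bool> &removed) {
  std::vector<std::string> kept;
  for (size_t i = 0; i < operands.size(); ++i)
    if (!removed[i]) kept.push_back(operands[i]);
  return kept;
}

int main() {
  // Operand %b is a pure passthrough, so the rewritten while drops it.
  for (const auto &v : CompressOperands({"%a", "%b", "%c"}, {false, true, false}))
    std::cout << v << ' ';
  std::cout << '\n';
}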
- SmallVector removed_operand(old_num_operands, false); - - for (int op_idx : llvm::seq(0, old_num_operands)) { - auto body_arg = body_block.getArgument(op_idx); - if (body_arg == yield.getOperand(op_idx)) { - // Replace the use of the passthrough value with the while operand - // in the body and condition regions, as well as the while output (if - // type match) - // TODO(jurahul): Use PatternRewriter API for IR modification. - auto value = while_op.getOperand(op_idx); - if (body_arg.getType() == value.getType()) - body_arg.replaceAllUsesWith(value); - - auto cond_arg = cond_block.getArgument(op_idx); - if (cond_arg.getType() == value.getType()) - cond_arg.replaceAllUsesWith(value); - - auto result = while_op.getResult(op_idx); - if (result.getType() == value.getType()) - result.replaceAllUsesWith(value); - } - - // Now check if the operand is unused in both regions as well as the - // result. If so, mark it for removal. - if (body_block.getArgument(op_idx).use_empty() && - cond_block.getArgument(op_idx).use_empty() && - while_op.getResult(op_idx).use_empty()) { - removed_operand[op_idx] = true; - new_num_operands--; - } - } - - if (new_num_operands == old_num_operands) return failure(); - - // Compress the operands, region arguments, and outputs. - SmallVector new_while_operands; - SmallVector new_result_types; - new_while_operands.reserve(new_num_operands); - new_result_types.reserve(new_num_operands); - - // Build new operands and result type. - int next_idx = 0; - for (int op_idx : llvm::seq(0, old_num_operands)) { - if (removed_operand[op_idx]) continue; - new_while_operands.push_back(while_op.getOperand(op_idx)); - new_result_types.push_back(while_op.getResult(op_idx).getType()); - next_idx++; - } - - // Create the new while operation. - auto new_while_op = - rewriter.create(while_op.getLoc(), new_result_types, - new_while_operands, while_op.getAttrs()); - - // Move region bodies to the new while. - rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(), - new_while_op.cond().end()); - rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(), - new_while_op.body().end()); - - auto &new_cond_block = new_while_op.cond().front(); - auto &new_body_block = new_while_op.body().front(); - auto &new_yield = *new_body_block.getTerminator(); - - // Build a vector of new results. Also patch up the region bodies and yield. 
- SmallVector new_results; - next_idx = 0; - for (int op_idx : llvm::seq(0, old_num_operands)) { - if (removed_operand[op_idx]) { - new_cond_block.eraseArgument(next_idx); - new_body_block.eraseArgument(next_idx); - new_yield.eraseOperand(next_idx); - new_results.push_back(nullptr); - } else { - new_results.push_back(new_while_op.getResult(next_idx++)); - } - } - - rewriter.replaceOp(while_op, new_results); - return success(); - } -}; - -} // anonymous namespace - -void WhileRegionOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// XdivyOp -//===----------------------------------------------------------------------===// - -void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// TableGen'd op method definitions -//===----------------------------------------------------------------------===// - -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc" - //===----------------------------------------------------------------------===// // TF Dialect Interfaces //===----------------------------------------------------------------------===// @@ -4601,8 +124,6 @@ struct TFInlinerInterface : public DialectInlinerInterface { // TF Dialect //===----------------------------------------------------------------------===// -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc" - std::vector *TensorFlowDialect::additional_operation_hooks_ = new std::vector(); @@ -4611,7 +132,7 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context) : Dialect(/*name=*/"tf", context) { addOperations< #define GET_OP_LIST -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc" >(); addTypes< #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index f37b71575f6..d06dce81e09 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -35,6 +35,9 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -112,17 +115,6 @@ class TensorFlowDialect : public Dialect { static std::vector *additional_operation_hooks_; }; -// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose -// purpose is to catch bug on `tensorflow::mutex_lock`. We don't use -// `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and -// `tf.ConsumeMutexLock`) with getter methods named as `mutex_lock()`. Need to -// undefine here to avoid expanding the getter symbol as macro when including -// both mutex.h and this header file. 
-#undef mutex_lock - -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc" - } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc new file mode 100644 index 00000000000..af7a16ba127 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -0,0 +1,1807 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace mlir { +namespace TF { + +namespace { +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" +#include 
"tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" +} // namespace + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// AddNOp +//===----------------------------------------------------------------------===// + +OpFoldResult AddNOp::fold(ArrayRef operands) { + if (operands.size() == 1) return *inputs().begin(); + return {}; +} + +//===----------------------------------------------------------------------===// +// AddV2Op +//===----------------------------------------------------------------------===// + +void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult AddV2Op::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + +//===----------------------------------------------------------------------===// +// AllOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(AllOp op) { + return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), + op.getLoc()); +} + +//===----------------------------------------------------------------------===// +// AnyOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(AnyOp op) { + return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), + op.getLoc()); +} + +//===----------------------------------------------------------------------===// +// AssertOp +//===----------------------------------------------------------------------===// + +namespace { + +// Removes Assert with constant true predicate. 
+struct AssertWithTrue : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AssertOp op, + PatternRewriter &rewriter) const override { + ElementsAttr cst; + if (matchPattern(op.condition(), m_Constant(&cst))) { + if (cst.getValue({}).getValue()) { + rewriter.eraseOp(op); + return success(); + } + } + return failure(); + } +}; +} // namespace + +void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BatchMatMulOp +//===----------------------------------------------------------------------===// + +void BatchMatMulOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BatchMatMulV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(BatchMatMulV2Op op) { + if (!HasRankAtLeast(op.x(), 2)) { + return op.emitOpError("requires lhs operand to have rank at least two"); + } + if (!HasRankAtLeast(op.y(), 2)) { + return op.emitOpError("requires rhs operand to have rank at least two"); + } + return success(); +} + +void BatchMatMulV2Op::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BatchToSpaceOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(BatchToSpaceOp op) { + // Op already has a constraint that block_size >= 2. 
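+ //
+ // For statically known dimensions, the checks below enforce:
+ //   output_batch   == input_batch / (block_size * block_size)
+ //   output_spatial == input_spatial * block_size - crops  (per spatial dim)
+ //   output_depth   == input_depth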
+ int64_t block_size = op.block_size().getSExtValue(); + + llvm::SmallVector input_shape(4, ShapedType::kDynamicSize); + auto input_type = op.input().getType().cast(); + if (input_type.hasRank()) { + if (input_type.getRank() != 4) + return op.emitOpError() + << "requires input to be a 4D tensor, but got " << input_type; + + int64_t input_batch = input_type.getDimSize(0); + if (input_batch != ShapedType::kDynamicSize && + input_batch % (block_size * block_size) != 0) { + return op.emitOpError() + << "requires input batch (dimension 0) to be evenly divisible " + "by (block_size * block_size), but got input batch " + << input_batch << " and block_size " << block_size; + } + + input_shape.assign(input_type.getShape().begin(), + input_type.getShape().end()); + } + + auto crops_type = op.crops().getType().cast(); + if (crops_type.hasRank()) { + if (crops_type.getRank() != 2) + return op.emitOpError() + << "requires crops to be a 2D tensor, but got " << crops_type; + + auto dim_of_size = [&](int64_t dim, int64_t size) { + if (crops_type.isDynamicDim(dim)) return true; + return crops_type.getDimSize(dim) == size; + }; + if (!dim_of_size(0, 2) || !dim_of_size(1, 2)) + return op.emitOpError() + << "requires crops to be a tensor<2x2>, but got " << crops_type; + } + + DenseIntElementsAttr crops_attr; + // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]], + // and flattened as [crop_top, crop_bottom, crop_left, crop_right] + llvm::SmallVector crops_values; + if (matchPattern(op.crops(), m_Constant(&crops_attr))) { + assert(crops_attr.getNumElements() == 4 && + "tf.BatchToSpace crops must have 4 elements"); + + auto crops_range = crops_attr.getIntValues(); + for (const auto &crops_value : crops_range) { + int64_t crops_value_int = crops_value.getSExtValue(); + if (crops_value_int < 0) + return op.emitOpError() + << "requires all crop values to be nonnegative, but got " + << crops_attr; + + crops_values.push_back(crops_value_int); + } + } + + auto output_type = op.output().getType().cast(); + if (output_type.hasRank()) { + if (output_type.getRank() != 4) + return op.emitOpError() + << "requires output to be a 4D tensor, but got " << output_type; + + auto static_dims = [](int64_t dim_a, int64_t dim_b) { + return dim_a != ShapedType::kDynamicSize && + dim_b != ShapedType::kDynamicSize; + }; + + auto output_shape = output_type.getShape(); + + // output batch = input batch / (block_size * block_size). + int64_t input_batch = input_shape[0]; + int64_t output_batch = output_shape[0]; + if (static_dims(input_batch, output_batch) && + (output_batch * block_size * block_size) != input_batch) + return op.emitOpError() + << "requires output batch (dimension 0) to be equal to input " + "batch (dimension 0) / (block_size * block_size), but got " + "output batch " + << output_batch << ", input batch " << input_batch + << ", and block_size " << block_size; + + auto check_spatial_dim = [&](int64_t spatial_dim_index, + llvm::StringRef dim_name, + llvm::StringRef crop_a_name, + llvm::StringRef crop_b_name) -> LogicalResult { + int64_t input_dim = input_shape[spatial_dim_index]; + int64_t output_dim = output_shape[spatial_dim_index]; + if (!static_dims(input_dim, output_dim)) return success(); + + int64_t input_dim_pad = input_dim * block_size; + // If crops are unknown, the maximum output spatial dim size is input + // spatial dim size * block_size, as crops can be minimum 0. 
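+ // For example, with block_size = 2 and a static input height of 5, an
+ // output height of up to 10 is accepted when the crops are unknown; if the
+ // crops are known to be [1, 2], the output height must be exactly
+ // 5 * 2 - 1 - 2 = 7.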
+ if (crops_values.empty() && output_dim > input_dim * block_size) + return op.emitOpError() + << "requires output " << dim_name << " (dimension " + << spatial_dim_index << ") to be less than or equal to input " + << dim_name << " (dimension " << spatial_dim_index + << ") * block_size, but got output " << dim_name << " " + << output_dim << ", input " << dim_name << " " << input_dim + << ", and block_size " << block_size; + + if (!crops_values.empty()) { + // output spatial dim = input spatial dim * block_size - crops. + int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)]; + int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1]; + if (output_dim != input_dim_pad - crop_a - crop_b) + return op.emitOpError() + << "requires output " << dim_name << " (dimension " + << spatial_dim_index << ") to be equal to input " << dim_name + << " (dimension " << spatial_dim_index << ") * block_size - " + << crop_a_name << " - " << crop_b_name << ", but got output " + << dim_name << " " << output_dim << ", input " << dim_name + << " " << input_dim << ", " << crop_a_name << " " << crop_a + << ", " << crop_b_name << " " << crop_b << ", and block_size " + << block_size; + } + + return success(); + }; + + if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) || + failed(check_spatial_dim(2, "width", "crop_left", "crop_right"))) + return failure(); + + int64_t input_depth = input_shape[3]; + int64_t output_depth = output_shape[3]; + if (static_dims(input_depth, output_depth) && output_depth != input_depth) + return op.emitOpError() + << "requires output depth (dimension 3) to be equal to input " + "depth (dimension 3), but got output depth " + << output_depth << " and input depth " << input_depth; + } + + return success(); +} + +void BatchToSpaceOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BiasAddOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// * the value and bias operands have valid ranks or are unranked. +// * Channel dimension of the value operand and length of bias matches if they +// are not unknown. +// +static LogicalResult Verify(BiasAddOp op) { + StringRef format = op.data_format(); + if (format == "NHWC") { + if (!HasRankAtLeast(op.value(), 2)) + return op.emitOpError( + "requires value operand to have rank at least two with `NHWC` data " + "format"); + } else { + // Op definition requires data_format to be either NHWC or NCHW. + DCHECK_EQ(format.str(), "NCHW"); + if (!HasRankAtLeast(op.value(), 3)) + return op.emitOpError( + "requires value operand to have rank at least three with `NCHW` data " + "format"); + } + + if (!IsOfRankOrUnranked(op.bias(), 1)) + return op.emitOpError("requires bias operand to have rank exactly one"); + + RankedTensorType value_ty = op.value().getType().dyn_cast(); + RankedTensorType bias_ty = op.bias().getType().dyn_cast(); + if (!bias_ty || !value_ty) return success(); + + // TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute + // dimension indices based on format. + int64_t feature_dim_idx = format == "NHWC" ? 
value_ty.getRank() - 1 : 1; + int64_t feature_dim = value_ty.getDimSize(feature_dim_idx); + int64_t bias_len = bias_ty.getDimSize(0); + if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) { + return op.emitOpError() + << "requires channel dimension and feature dimension to match; " + "found " + << feature_dim << " and " << bias_len << ", respectively"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// BiasAddGradOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// * the out_backprop operands have valid ranks or are unranked. +// +static LogicalResult Verify(BiasAddGradOp op) { + StringRef format = op.data_format(); + if (format == "NHWC") { + if (!HasRankAtLeast(op.out_backprop(), 2)) + return op.emitOpError( + "requires out_backprop operand to have rank at least two with `NHWC` " + "data format"); + } else { + // Op definition requires data_format to be either NHWC or NCHW. + DCHECK_EQ(format.str(), "NCHW"); + if (!HasRankAtLeast(op.out_backprop(), 3)) + return op.emitOpError( + "requires out_backprop operand to have rank at least three with " + "`NCHW` data format"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// BiasAddV1Op +//===----------------------------------------------------------------------===// + +void BiasAddV1Op::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BitcastOp +//===----------------------------------------------------------------------===// + +void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BroadcastToOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(BroadcastToOp op) { + // TODO(antiagainst): check that + // * The 'shape' input is an 1-D int tensor. + // * Each dimension pair of the source and target shapes are either equal + // or one of them is one. + return success(); +} + +//===----------------------------------------------------------------------===// +// CaseOp +//===----------------------------------------------------------------------===// + +class FoldConstantCaseOp : public OpRewritePattern { + public: + explicit FoldConstantCaseOp(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(TF::CaseOp op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult FoldConstantCaseOp::matchAndRewrite( + TF::CaseOp op, PatternRewriter &rewriter) const { + // Extract the constant cond value. + DenseIntElementsAttr branch; + if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure(); + + // Only attempt to fold scalar valued case statements. + // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. + if (!branch.getType().cast().getShape().empty()) + return failure(); + + int index = *branch.getValues().begin(); + // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. 
+ if (index >= op.branches().size()) return failure(); + + auto func = op.branches()[index].cast(); + auto empty = rewriter.getStringAttr(""); + auto call_op = rewriter.create( + op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func, + /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); + PropagateAttributes(op.getOperation(), call_op); + rewriter.replaceOp(op, call_op.getResults()); + return success(); +} + +void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +OpFoldResult CastOp::fold(ArrayRef operands) { + // Cast with the same type is a no-op. + Value operand = getOperand(); + if (getType() == operand.getType()) return operand; + return {}; +} + +//===----------------------------------------------------------------------===// +// ConcatOp and ConcatV2Op +//===----------------------------------------------------------------------===// + +template ::value>::type * = nullptr> +static LogicalResult Verify(OpT op) { + // TODO(hinsu): Convert variadic length attributes to derived attributes. + Operation::operand_range values = op.values(); + + int axis_idx = std::is_same() ? 0 : 1; + Value axis = *op.getODSOperands(axis_idx).begin(); + if (!HasRankAtMost(axis, 1)) { + return op.emitOpError( + "requires axis to be of scalar type (or vector type for older " + "versions)"); + } + + return VerifyTypesCompatibility(values, + /*mask_one_dim=*/true, op.getOperation()); +} + +void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// ConcatOffsetOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ConcatOffsetOp op) { + if (op.N() < 2) + return op.emitOpError() << "requires N to be at least 2, got " << op.N(); + + if (op.shape().size() != op.offset().size()) + return op.emitOpError() + << "requires sizes of shapes and offsets to be the same, got sizes " + << op.shape().size() << " and " << op.offset().size(); + + auto ranked_dim = op.concat_dim().getType().dyn_cast(); + if (ranked_dim && ranked_dim.getRank() != 0) + return op.emitOpError() + << "requires concat_dim to be a scalar, got tensor of rank " + << ranked_dim.getRank(); + + int64_t num_dims = -1; + for (auto shape_offset_idx : + llvm::enumerate(llvm::zip(op.shape(), op.offset()))) { + Value shape = std::get<0>(shape_offset_idx.value()); + Value offset = std::get<1>(shape_offset_idx.value()); + const size_t idx = shape_offset_idx.index(); + + if (failed(verifyCompatibleShape(shape.getType(), offset.getType()))) + return op.emitOpError() << "requires operand and result " << idx + << " to have compatible shapes"; + + auto ranked_shape = shape.getType().dyn_cast(); + if (!ranked_shape) continue; + + if (ranked_shape.getRank() != 1) + return op.emitOpError() << "requires shape tensor operand " << idx + << " to be of rank 1, got tensor of rank " + << ranked_shape.getRank(); + + if (!ranked_shape.hasStaticShape()) continue; + + int64_t ranked_shape_dim = ranked_shape.getDimSize(0); + if (num_dims == -1) + num_dims = ranked_shape_dim; + else if (ranked_shape_dim != num_dims) + return op.emitOpError() + << "requires shape tensor (rank 1) 
operand " << idx + << " to be of length " << num_dims + << ", got tensor (rank 1) of length " << ranked_shape_dim; + } + + return success(); +} + +LogicalResult ConcatOffsetOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + // ConcatOffset must have its first operand be concat_dim and at least two + // shape tensors in variadic shapes operand. + if (operands.size() < 3) return failure(); + + // Check concat_dim is a scalar. + auto concat_dim_attr = operands[0].dyn_cast_or_null(); + if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0) + return failure(); + + llvm::SmallVector shapes; + shapes.reserve(operands.size() - 1); + for (Attribute shape : llvm::drop_begin(operands, 1)) + if (auto shape_attr = shape.dyn_cast_or_null()) + shapes.push_back(shape_attr); + else + return failure(); + + // Check all shapes are vectors of the same length. + if (shapes.front().getType().getRank() != 1) return success(); + const int64_t num_dims = shapes.front().getNumElements(); + for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) + if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims) + return failure(); + + // Check concat_dim is within [-num_dims, num_dims). + int32_t concat_dim = (*concat_dim_attr.getValues().begin()); + if (concat_dim < 0) concat_dim += num_dims; + if (concat_dim >= num_dims || concat_dim < 0) return failure(); + + // Check all elements besides at concat_dim match across all shape tensors. + SmallVector shape0; + shape0.reserve(num_dims); + for (int32_t dim : shapes.front().getValues()) shape0.push_back(dim); + + for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) { + for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) { + if (dims_and_idx.index() == concat_dim) continue; + + if (std::get<0>(dims_and_idx.value()) != + std::get<1>(dims_and_idx.value()).getSExtValue()) + return failure(); + } + } + + // Compute an exclusive cumulative sum of elements at concat_dim. + results.reserve(shapes.size()); + SmallVector cumulative_sum(num_dims, 0); + RankedTensorType offset_type = + RankedTensorType::get({num_dims}, IntegerType::get(32, getContext())); + for (DenseIntElementsAttr shape : shapes) { + results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum)); + cumulative_sum[concat_dim] += shape.getValue(concat_dim); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ConjOp +//===----------------------------------------------------------------------===// + +void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// ConstOp +//===----------------------------------------------------------------------===// + +OpFoldResult ConstOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + + // Return the held attribute value. + return value(); +} + +// Builds a constant op with the specified attribute `value`. The result +// op's type is deduced from `value`; if `value` is of scalar type, +// wraps it up with a tensor type of empty shape. +// TODO(jpienaar): This one differs from the autogenerated one as it takes an +// attribute but always creates an ElementsAttr internally. 
+void ConstOp::build(OpBuilder &builder, OperationState &result, + Attribute value) { + ShapedType type; + if (auto elem_attr = value.dyn_cast()) { + return ConstOp::build(builder, result, elem_attr); + } else if (value.isa()) { + // All TensorFlow types must be tensor types. In the build() method, + // we want to provide more flexibility by allowing attributes of scalar + // types. But we need to wrap it up with ElementsAttr to construct + // valid TensorFlow constants. + type = RankedTensorType::get(/*shape=*/{}, value.getType()); + return ConstOp::build(builder, result, DenseElementsAttr::get(type, value)); + } + // TODO(jpienaar): support other TensorFlow specific types. + llvm_unreachable("unsupported attribute type for building tf.Const"); +} + +void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, + Attribute value) { + // Handle the case where the type and value are already tensors. + if (type.isa() && value.isa()) { + result.addTypes(type); + result.addAttribute("value", value); + return; + } + + // Otherwise, default to the attribute builder. + ConstOp::build(builder, result, value); + assert(type == result.types[0] && "type mismatch in construction"); +} + +LogicalResult ConstOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + auto value = attributes.get("value"); + if (!value) return emitOptionalError(location, "missing attribute 'value'"); + if (auto elem_attr = value.dyn_cast()) { + inferredReturnTypes.assign({elem_attr.getType()}); + return success(); + } + return emitOptionalError(location, + "attribute 'value' failed to satisfy constraint: " + "constant vector/tensor"); +} + +//===----------------------------------------------------------------------===// +// Conv2DOp and Conv3DOp +//===----------------------------------------------------------------------===// + +template +static LogicalResult VerifyConvOpAttributes(OpT op, int num_dims) { + if (!IsOfRankOrUnranked(op.getResult(), num_dims)) + return op.emitOpError() + << "requires result to be " << num_dims << "D tensor"; + + auto is_not_positive = [](Attribute val) { + return val.cast().getValue().getSExtValue() <= 0; + }; + + int64_t strides_size = op.strides().size(); + if (strides_size != num_dims) + return op.emitOpError() << "requires strides attribute length to be " + << num_dims << "; actual length " << strides_size; + if (llvm::any_of(op.strides().getValue(), is_not_positive)) + return op.emitOpError("requires positive strides"); + + int64_t dilations_size = op.strides().size(); + if (op.dilations().size() != num_dims) + return op.emitOpError() << "requires dilations attribute length to be " + << num_dims << "; actual length " << dilations_size; + if (llvm::any_of(op.dilations().getValue(), is_not_positive)) + return op.emitOpError("requires positive dilations"); + + return success(); +} + +// Verifies that, +// * Ranks of operands and result are valid +// * Number of input channels is divisible by the number of filter input +// channels +// * Length of explicit_paddings attribute is valid and has non negative +// elements +// * strides and dilations attributes have positive elements +template ::value>::type * = nullptr> +static LogicalResult Verify(OpT op) { + int num_spatial_dims = std::is_same() ? 
2 : 3; + int num_dims = 2 + num_spatial_dims; + + if (!IsOfRankOrUnranked(op.input(), num_dims) || + !IsOfRankOrUnranked(op.filter(), num_dims)) + return op.emitOpError() + << "requires operands to be " << num_dims << "D tensor"; + + // EXPLICIT padding mode and the associated attribute is limited to Conv2D. + // So, fetch attribute by string instead of the op.explicit_paddings() + // attribute getter. + if (op.padding() == "EXPLICIT") { + auto paddings = op.template getAttrOfType("explicit_paddings"); + if (!paddings) + return op.emitOpError() << "requires attribute 'explicit_paddings' with " + "'EXPLICIT' padding mode"; + + int64_t paddings_size = paddings.size(); + int64_t expected_size = 2 * num_dims; + + if (paddings_size != expected_size) + return op.emitOpError() + << "requires explicit_paddings attribute length to be " + << expected_size << "; actual length " << paddings_size; + + auto is_negative = [](Attribute val) { + return val.cast().getValue().getSExtValue() < 0; + }; + if (llvm::any_of(paddings.getValue(), is_negative)) + return op.emitOpError("requires non negative explicit paddings"); + } + + LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims); + if (failed(verify_result)) { + return verify_result; + } + + int64_t input_channels = -1; + if (auto ty = op.input().getType().template dyn_cast()) { + std::string data_format = op.data_format().str(); + tensorflow::TensorFormat format; + auto is_valid = FormatFromString(data_format, &format); + DCHECK(is_valid) << data_format; + int idx = tensorflow::GetTensorFeatureDimIndex(num_dims, format); + input_channels = ty.getDimSize(idx); + } + + int64_t filter_channels = -1; + if (auto ty = op.filter().getType().template dyn_cast()) { + int idx = tensorflow::GetFilterTensorInputChannelsDimIndex( + num_dims, tensorflow::FORMAT_HWIO); + filter_channels = ty.getDimSize(idx); + } + + if (input_channels != -1 && filter_channels != -1 && + input_channels % filter_channels != 0) + return op.emitOpError() + << "requires the number of input channels to be divisible by the " + "number of filter input channels; found " + << input_channels << " and " << filter_channels << ", respectively"; + + return success(); +} + +LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) { + auto perm = GetDataFormatPermutation(this->data_format(), data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); + + // Update convolution attributes. + setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + setAttr("strides", ShuffleArrayAttr(strides(), perm)); + setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + + return success(); +} + +StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Input must be a tensor. + auto input_ty = input().getType().dyn_cast(); + if (!input_ty) return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. 
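+ // (NHWC is the layout the Tensor Core convolution kernels consume natively,
+ // so choosing it avoids extra transposes; this reflects current NVIDIA GPU
+ // library behavior rather than anything enforced by this op.)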
+ const bool is_f16 = input_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // For f32/f16 data type decision depends on the filter size in spatial + // dimensions, for other data types we keep current data format. + if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16()) + return data_format(); + + // Keep current data format if filter rank is unknown or not equal to 4. + auto filter_ty = filter().getType().dyn_cast(); + if (!filter_ty || filter_ty.getRank() != 4) return data_format(); + + const int64_t d0 = filter_ty.getDimSize(0); + const int64_t d1 = filter_ty.getDimSize(1); + + auto all_ones = [](ArrayAttr arr) -> bool { + return llvm::all_of(arr, [](Attribute attr) -> bool { + return attr.cast().getInt() == 1; + }); + }; + + // Convolutions with 1x1 filter and with strides and dilations all ones, can + // be computed as a GEMM in NHWC data format, and can be up to ~2x times + // faster than convolution in NCHW. + const bool one_by_one = d0 == 1 && d1 == 1; + const bool trivial_strides = all_ones(strides()); + const bool trivial_dilations = all_ones(dilations()); + + // TODO(ezhulenev): This might lead to excessive transposes in the final IR, + // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1. + // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all + // users can efficiently use NHWC data format? + if (one_by_one && trivial_strides && trivial_dilations) { + return "NHWC"; + } + + // If filter spatial dimensions are unknown or not 1x1 we prefer NCHW, because + // it's the fastest option on NVIDIA GPUs with cuDNN library support. + return "NCHW"; +} + +//===----------------------------------------------------------------------===// +// Conv2dBackpropFilterOp +//===----------------------------------------------------------------------===// + +LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { + StringRef src_data_format = this->data_format(); + + auto perm = GetDataFormatPermutation(src_data_format, data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); + + // Update convolution attributes. + setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + setAttr("strides", ShuffleArrayAttr(strides(), perm)); + setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + + // Permute filter sizes operand. + OpBuilder builder(getOperation()); + auto filter_sizes_permuted = builder.create( + getLoc(), filter_sizes(), StringAttr::get(src_data_format, getContext()), + StringAttr::get(data_format, getContext())); + setOperand(1, filter_sizes_permuted); + + return success(); +} + +StringRef Conv2DBackpropFilterOp::GetOptimalLayout( + const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Input must be a tensor. + auto input_ty = input().getType().dyn_cast(); + if (!input_ty) return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + const bool is_f16 = input_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // Otherwise always use "NCHW". 
+ return "NCHW"; +} + +//===----------------------------------------------------------------------===// +// Conv2DBackpropInputOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(Conv2DBackpropInputOp op) { + int num_spatial_dims = 2; + int num_dims = 2 + num_spatial_dims; + + if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) || + !IsOfRankOrUnranked(op.filter(), num_dims)) + return op.emitOpError() + << "requires operands to be " << num_dims << "D tensor"; + + LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims); + if (failed(verify_result)) { + return verify_result; + } + + return success(); +} + +LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { + StringRef src_data_format = this->data_format(); + + auto perm = GetDataFormatPermutation(src_data_format, data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); + + // Update convolution attributes. + setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + setAttr("strides", ShuffleArrayAttr(strides(), perm)); + setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + + // Permute input sizes operand. + OpBuilder builder(getOperation()); + auto input_sizes_permuted = builder.create( + getLoc(), input_sizes(), StringAttr::get(src_data_format, getContext()), + StringAttr::get(data_format, getContext())); + setOperand(0, input_sizes_permuted); + + return success(); +} + +StringRef Conv2DBackpropInputOp::GetOptimalLayout( + const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Filter must be a tensor. + auto filter_ty = filter().getType().dyn_cast(); + if (!filter_ty) return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + const bool is_f16 = filter_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // Otherwise always use "NCHW". 
+ return "NCHW"; +} + +//===----------------------------------------------------------------------===// +// DataFormatVecPermuteOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DataFormatVecPermuteOp op) { + auto input_ty = op.x().getType().dyn_cast(); + if (!input_ty) return success(); + + int rank = input_ty.getRank(); + if (rank != 1 && rank != 2) + return op.emitOpError("requires input of rank 1 or 2"); + + if (rank == 1) { + int64_t dim0 = input_ty.getDimSize(0); + if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2) + return op.emitOpError("requires 1D input of size 4 or size 2"); + } + + if (rank == 2) { + int64_t dim0 = input_ty.getDimSize(0); + if (dim0 != ShapedType::kDynamicSize && dim0 != 4) + return op.emitOpError( + "requires first dimensions of 2D input to be of size 4"); + + int64_t dim1 = input_ty.getDimSize(1); + if (dim1 != ShapedType::kDynamicSize && dim1 != 2) + return op.emitOpError( + "requires second dimensions of 2D input to be of size 2"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// DivOp +//===----------------------------------------------------------------------===// + +void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult DivOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + +//===----------------------------------------------------------------------===// +// DynamicStitchOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DynamicStitchOp op) { + if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1"); + + if (RankedTensorType out_ty = op.getType().dyn_cast()) { + if (out_ty.getRank() == 0) { + return op.emitOpError("requires non scalar output"); + } + } + + llvm::SmallDenseSet index_values; + bool all_indices_const = true; + int32_t max_index = -1; + llvm::Optional> inferred_item_shape; + for (auto it : llvm::zip(op.indices(), op.data())) { + Value index = std::get<0>(it); + + DenseIntElementsAttr index_attr; + if (matchPattern(index, m_Constant(&index_attr))) { + for (int32_t index : index_attr.getValues()) { + if (index < 0) + return op.emitOpError() + << "requires non-negative index values; found " << index; + max_index = std::max(index, max_index); + index_values.insert(index); + } + } else { + all_indices_const = false; + } + + Value data = std::get<1>(it); + RankedTensorType index_ty = index.getType().dyn_cast(); + RankedTensorType data_ty = data.getType().dyn_cast(); + if (!index_ty || !data_ty) continue; + + int64_t index_rank = index_ty.getRank(); + ArrayRef data_shape = data_ty.getShape(); + ArrayRef index_shape = index_ty.getShape(); + if (failed(mlir::verifyCompatibleShape(index_shape, + data_shape.take_front(index_rank)))) + return op.emitOpError() << "requires shape of data with type " << data_ty + << " to have prefix matching with shape of the " + "corresponding index type " + << index_ty; + + ArrayRef item_shape = data_shape.drop_front(index_rank); + if (!inferred_item_shape) { + inferred_item_shape = llvm::to_vector<4>(item_shape); + continue; + } + + if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape))) + return op.emitOpError() << "has inconsistent shaped data and index " + "pairs; inferred item shapes [" + << llvm::makeArrayRef(*inferred_item_shape) + << "] 
and [" << item_shape << "] don't match"; + for (int i = 0, e = item_shape.size(); i < e; ++i) { + int64_t &inferred_dim = (*inferred_item_shape)[i]; + int64_t dim = item_shape[i]; + if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim; + } + } + + // If all indices are constants, then verify that they cover all indices in + // the range [0, max_index] and the output type is legal. + if (all_indices_const) { + for (int32_t i = 0; i <= max_index; i++) { + if (!index_values.count(i)) + return op.emitOpError() << "missing index " << i; + } + + if (inferred_item_shape) { + SmallVector expected_shape; + expected_shape.push_back(max_index + 1); + expected_shape.append(inferred_item_shape->begin(), + inferred_item_shape->end()); + + auto out_ty = op.getType().cast(); + auto expected_out_ty = + RankedTensorType::get(expected_shape, out_ty.getElementType()); + + if (!AreCastCompatible({out_ty, expected_out_ty})) { + return op.emitOpError() << "has invalid output type; should be " + "compatible with inferred type " + << expected_out_ty; + } + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// EinsumOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// * Arity of the op is at most two. +// +// TODO(hinsu): Verify einsum equation attribute. +static LogicalResult Verify(EinsumOp op) { + if (op.N() > 2) { + return op.emitOpError("supports at most two operands"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// EmptyOp +//===----------------------------------------------------------------------===// + +OpFoldResult EmptyOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "empty op has one operand"); + + Attribute attr = operands.front(); + if (!attr) return {}; + + auto int_attr = attr.cast(); + SmallVector out_shape; + for (const auto val : int_attr.getValues()) { + out_shape.push_back(val); + } + + auto type = getResult().getType().cast(); + auto etype = type.getElementType(); + + // We can not fold if the result is not static. + if (!type.hasStaticShape()) return {}; + + if (auto float_type = etype.dyn_cast()) { + auto out_type = RankedTensorType::get(out_shape, float_type); + return DenseElementsAttr::get(out_type, + {APFloat(float_type.getFloatSemantics())}); + } + + if (auto int_type = etype.dyn_cast()) { + auto out_type = RankedTensorType::get(out_shape, etype); + APInt val(int_type.getWidth(), 0, int_type.getSignedness()); + return DenseElementsAttr::get(out_type, val); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// EmptyTensorListOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(EmptyTensorListOp op) { + if (!IsOfRankOrUnranked(op.element_shape(), 0) && + !IsOfRankOrUnranked(op.element_shape(), 1)) { + return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); + } + + if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) { + return op.emitOpError("requires max_num_elements operand to be 0D tensor"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// EqualOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(EqualOp op) { + // If we allow inputs to have incompatible type, then nothing to do. 
+ if (!op.incompatible_shape_error()) return success(); + + // Otherwise, check inputs are broadcastable. + return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( + op.getOperation()); +} + +void EqualOp::build(OpBuilder &builder, OperationState &result, Value x, + Value y, BoolAttr incompatible_shape_error) { + auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, + incompatible_shape_error); + return build(builder, result, result_type, x, y, incompatible_shape_error); +} + +//===----------------------------------------------------------------------===// +// ExpandDimsOp +//===----------------------------------------------------------------------===// + +Type InferExpandDimsOpType(Value input, Value dim) { + Type element_ty = input.getType().cast().getElementType(); + auto unranked_ty = UnrankedTensorType::get(element_ty); + + auto input_ty = input.getType().dyn_cast(); + if (!input_ty) return unranked_ty; + + DenseIntElementsAttr dim_attr; + if (!matchPattern(dim, m_Constant(&dim_attr)) || + dim_attr.getNumElements() != 1) + return unranked_ty; + int64_t dim_val = (*dim_attr.begin()).getSExtValue(); + int64_t input_rank = input_ty.getRank(); + + if (dim_val < -input_rank - 1 || dim_val > input_rank + 1) return unranked_ty; + if (dim_val < 0) dim_val += input_rank + 1; + + SmallVector shape = llvm::to_vector<4>(input_ty.getShape()); + shape.insert(shape.begin() + dim_val, 1); + return RankedTensorType::get(shape, element_ty); +} + +void ExpandDimsOp::build(OpBuilder &builder, OperationState &result, + Value input, Value dim) { + return build(builder, result, InferExpandDimsOpType(input, dim), input, dim); +} + +//===----------------------------------------------------------------------===// +// FakeQuantWithMinMaxArgsOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) { + // TODO(fengliuai): moving the following to an utility method. + const llvm::fltSemantics &semantics = op.min().getSemantics(); + float rmin, rmax; + if (&semantics == &APFloat::IEEEsingle()) { + rmin = op.min().convertToFloat(); + rmax = op.max().convertToFloat(); + } else { + rmin = op.min().convertToDouble(); + rmax = op.max().convertToDouble(); + } + // Range boundaries must be valid. 
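+ // For example, [min, max] = [-6.0, 6.0] is accepted, while [1.0, 1.0] and
+ // [1.0, -1.0] are both rejected below.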
+ if (rmin >= rmax) { + return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) + + "," + Twine(std::to_string(rmax)) + "]"); + } + int64_t num_bits = op.num_bits().getSExtValue(); + if (num_bits < 2 || num_bits > 16) { + return op.emitOpError( + "requires num_bits to be between 2 and 16, inclusive"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// FakeQuantWithMinMaxVarsOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) { + auto min = GetRankedTensorTypeForOperand(op.min()); + if (min && !IsOfRankedFloatTensorType(min, 0)) + return op.emitOpError("requires min to be a 0d float tensor"); + + auto max = GetRankedTensorTypeForOperand(op.max()); + if (max && !IsOfRankedFloatTensorType(max, 0)) + return op.emitOpError("requires max to be a 0d float tensor"); + + int64_t num_bits = op.num_bits().getSExtValue(); + if (num_bits < 2 || num_bits > 16) { + return op.emitOpError( + "requires num_bits to be between 2 and 16, inclusive"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// FakeQuantWithMinMaxVarsPerChannelOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) { + auto min = GetRankedTensorTypeForOperand(op.min()); + if (min && !IsOfRankedFloatTensorType(min, 1)) + return op.emitOpError("requires min to be a 1d float tensor"); + + auto max = GetRankedTensorTypeForOperand(op.max()); + if (max && !IsOfRankedFloatTensorType(max, 1)) + return op.emitOpError("requires max to be a 1d float tensor"); + + Value inputs = op.inputs(); + if (!HasRankAtLeast(inputs, 1)) + return op.emitError("requires inputs to be at least 1d float tensor"); + + int64_t num_bits = op.num_bits().getSExtValue(); + if (num_bits < 2 || num_bits > 16) { + return op.emitOpError( + "requires num_bits to be between 2 and 16, inclusive"); + } + + auto inputs_type = inputs.getType().dyn_cast(); + if (!inputs_type) return success(); + int depth = inputs_type.getDimSize(inputs_type.getRank() - 1); + if ((min && min.getDimSize(0) != depth) || + (max && max.getDimSize(0) != depth)) { + return op.emitOpError( + "requires min and max to have same size as last dimension of inputs"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// FillOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(FillOp op) { + if (!IsOfRankOrUnranked(op.dims(), 1)) + return op.emitOpError() << "requires dims to be a 1D tensor"; + if (!IsOfRankOrUnranked(op.value(), 0)) + return op.emitOpError() << "requires value to be a scalar"; + + return success(); +} + +static ShapedType InferFillOpType(Value dims, Value value) { + Type etype = value.getType().cast().getElementType(); + + DenseIntElementsAttr dims_attr; + if (!matchPattern(dims, m_Constant(&dims_attr))) { + return UnrankedTensorType::get(etype); + } + + llvm::SmallVector shape; + shape.reserve(dims_attr.getNumElements()); + for (const APInt dim : dims_attr.getValues()) { + shape.push_back(dim.getSExtValue()); + } + return RankedTensorType::get(shape, etype); +} + +void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, + Value value) { + FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); +} + 
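+// When the result shape and the fill value are statically known, tf.Fill
+// folds to a splat constant. For example (sketch), dims = [2, 3] and
+// value = 9.0 : f32 fold to a DenseElementsAttr splat of type tensor<2x3xf32>
+// holding 9.0 in every element.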
+OpFoldResult FillOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "fill op has two operand"); + + auto type = getType().cast(); + // DenseElementsAttr that is used in this folder only supports int and float + // types. + // TODO(hinsu): Handle complex types once there is a attribute kind for + // complex. + if (!type.getElementType().isIntOrFloat()) return {}; + + auto value = operands[1].dyn_cast_or_null(); + if (!value) return {}; + + if (type.hasStaticShape()) + return DenseElementsAttr::get(type, value.getValue({})); + + auto dims = operands[0].dyn_cast_or_null(); + if (!dims) return {}; + + llvm::SmallVector shape; + shape.reserve(dims.getNumElements()); + for (const APInt dim : dims.getValues()) { + shape.push_back(dim.getSExtValue()); + } + type = RankedTensorType::get(shape, type.getElementType()); + + return DenseElementsAttr::get(type, value.getValue({})); +} + +//===----------------------------------------------------------------------===// +// FusedBatchNormGradOp +//===----------------------------------------------------------------------===// + +// TODO(b/150954845): Add benchmarks to verify that layout preference didn't +// change in the latest GPU generations. + +LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) { + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + +StringRef FusedBatchNormGradV3Op::GetOptimalLayout( + const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + auto x_ty = x().getType().cast(); + const bool is_f16 = x_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // For all other data types prefer NCHW. 
+ return "NCHW"; +} + +//===----------------------------------------------------------------------===// +// FusedBatchNormOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(FusedBatchNormOp op) { + auto x = GetRankedTensorTypeForOperand(op.x()); + if (x && !IsOfRankedFloatTensorType(x, 4)) + return op.emitOpError("requires x to be a 4D float tensor"); + + auto scale = GetRankedTensorTypeForOperand(op.scale()); + if (scale && !IsOfRankedFloatTensorType(scale, 1)) + return op.emitOpError("requires scale to be a 1D float tensor"); + + auto offset = GetRankedTensorTypeForOperand(op.offset()); + if (offset && !IsOfRankedFloatTensorType(offset, 1)) + return op.emitOpError("requires offset to be a 1D float tensor"); + + auto mean = GetRankedTensorTypeForOperand(op.mean()); + if (mean && !IsOfRankedFloatTensorType(mean, 1)) + return op.emitOpError("requires mean to be a 1D float tensor"); + + auto variance = GetRankedTensorTypeForOperand(op.variance()); + if (variance && !IsOfRankedFloatTensorType(variance, 1)) + return op.emitOpError("requires variance to be a 1D float tensor"); + + // TODO(antiagainst): check attributes + + return success(); +} + +//===----------------------------------------------------------------------===// +// FusedBatchNormV2Op / FusedBatchNormV3Op +//===----------------------------------------------------------------------===// + +template +static LogicalResult InferenceFoldOperandsPermutation( + ArrayRef permutation, Op *op) { + // FusedBatchNorm in training mode is a layout sentitive operation, and should + // have already assigned an optimal data format. + if (op->is_training()) return failure(); + return ::mlir::TF::FoldOperandsPermutation(permutation, op); +} + +template +static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) { + // In inference mode FusedBatchNorm is not sensitive to data layout. + if (!op->is_training()) return op->data_format(); + + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation())) + return op->data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + auto x_ty = op->x().getType().template cast(); + const bool is_f16 = x_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // For all other data types prefer NCHW. 
+ return "NCHW"; +} + +LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation( + ArrayRef permutation) { + return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this); +} + +LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) { + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + +StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) { + return ::mlir::TF::GetOptimalLayout(devices, this); +} + +LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation( + ArrayRef permutation) { + return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this); +} + +LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) { + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + +StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { + return ::mlir::TF::GetOptimalLayout(devices, this); +} + +//===----------------------------------------------------------------------===// +// GatherV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(GatherV2Op op) { + int64_t batch_dims = op.batch_dims().getSExtValue(); + if (auto ty = op.indices().getType().dyn_cast()) { + int64_t rank = ty.getRank(); + if (batch_dims > rank || batch_dims < -rank) + return op.emitOpError() + << "batch_dims (" << batch_dims << ") must be in range [" << -rank + << ", " << rank + 1 << ")"; + if (batch_dims < 0) batch_dims += rank; + } + + if (!HasRankAtMost(op.axis(), 1)) + return op.emitOpError("requires axis to have rank at most 1"); + + DenseIntElementsAttr axis_attr; + if (matchPattern(op.axis(), m_Constant(&axis_attr))) { + int64_t axis = (*axis_attr.begin()).getSExtValue(); + if (auto ty = op.params().getType().dyn_cast()) { + int64_t rank = ty.getRank(); + if (axis >= rank || axis < -rank) + return op.emitOpError() << "axis (" << axis << ") must be in range [" + << -rank << ", " << rank << ")"; + if (axis < 0) axis += rank; + } + + if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) { + return op.emitOpError() << "requires axis (" << axis + << ") to be greater than or equal to batch_dims (" + << batch_dims << ")"; + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// IfOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(IfOp op) { + auto module = op.getParentOfType(); + auto then_fn = module.lookupSymbol(op.then_branch()); + if (!then_fn) + return op.emitOpError("then_branch refers to an undefined function : ") + << op.then_branch(); + auto else_fn = module.lookupSymbol(op.else_branch()); + if (!else_fn) + return op.emitOpError("else_branch refers to an undefined function : ") + << op.else_branch(); + auto then_fn_type = then_fn.getType(); + auto else_fn_type = else_fn.getType(); + + // Non-conditional operands starting with the second operand are passed to + // branches and should be pair-wise compatible with branches' inputs. 
+ unsigned expected_num_inputs = op.getNumOperands() - 1; + if (then_fn_type.getNumInputs() != expected_num_inputs || + else_fn_type.getNumInputs() != expected_num_inputs) + return op.emitError("branches should have " + Twine(expected_num_inputs) + + " inputs"); + + for (unsigned i = 0; i < expected_num_inputs; ++i) { + auto operand_type = op.getOperand(i + 1).getType().cast(); + auto then_input_type = then_fn_type.getInput(i).cast(); + if (!AreCastCompatible({operand_type, then_input_type})) + return op.emitError( + llvm::formatv("then branch input type {0} is incompatible with " + "operand type {1} at index {2}", + then_input_type, operand_type, i)); + + auto else_input_type = else_fn_type.getInput(i).cast(); + if (!AreCastCompatible({operand_type, else_input_type})) + return op.emitError( + llvm::formatv("else branch input type {0} is incompatible with " + "operand type {1} at index {2}", + else_input_type, operand_type, i)); + + // If branches have incompatible input types that means that no tensor can + // serve as input to both the functions. Hence, the op is invalid. + if (!AreCastCompatible({then_input_type, else_input_type})) + return op.emitError(llvm::formatv( + "branches inputs have incompatible types {0} and {1} at index {2}", + then_input_type, else_input_type, i)); + } + + // Branches' results should be pair-wise compatible with the op results. + unsigned expected_num_results = op.getNumResults(); + if (then_fn_type.getNumResults() != expected_num_results || + else_fn_type.getNumResults() != expected_num_results) + return op.emitError("branches should have " + Twine(expected_num_results) + + " results"); + + for (unsigned i = 0; i < expected_num_results; ++i) { + auto result_type = op.getResult(i).getType().cast(); + auto then_result_type = then_fn_type.getResult(i).cast(); + if (!AreCastCompatible({then_result_type, result_type})) + return op.emitError( + llvm::formatv("then branch result type {0} is incompatible with op " + "result type {1} at index {2}", + then_result_type, result_type, i)); + + auto else_result_type = else_fn_type.getResult(i).cast(); + if (!AreCastCompatible({else_result_type, result_type})) + return op.emitError( + llvm::formatv("else branch result type {0} is incompatible with op " + "result type {1} at index {2}", + else_result_type, result_type, i)); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// IfRegionOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(IfRegionOp op) { + if (failed(VerifyRegionResults(op, op.then_branch(), "then"))) + return failure(); + if (failed(VerifyRegionResults(op, op.else_branch(), "else"))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// InvertOp +//===----------------------------------------------------------------------===// + +void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// InvertPermutationOp +//===----------------------------------------------------------------------===// + +// Verifies that the input is 1D. 
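A standalone sketch (not part of the patch) of the triple compatibility check performed by the IfOp verifier above: each forwarded operand must be cast-compatible with the corresponding then-branch input and else-branch input, and the two branch inputs must be compatible with each other. Shapes are modeled as plain integer vectors with -1 for a dynamic dimension; all names are hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Two shapes are "cast compatible" here if they have the same rank and every
// pair of static dimensions agrees (a -1 matches anything).
bool Compatible(const Shape &a, const Shape &b) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < a.size(); ++i)
    if (a[i] != -1 && b[i] != -1 && a[i] != b[i]) return false;
  return true;
}

bool VerifyBranchInputs(const std::vector<Shape> &operands,
                        const std::vector<Shape> &then_inputs,
                        const std::vector<Shape> &else_inputs) {
  if (then_inputs.size() != operands.size() ||
      else_inputs.size() != operands.size())
    return false;  // "branches should have N inputs"
  for (size_t i = 0; i < operands.size(); ++i) {
    if (!Compatible(operands[i], then_inputs[i])) return false;
    if (!Compatible(operands[i], else_inputs[i])) return false;
    // No tensor could feed both branches if their inputs disagree.
    if (!Compatible(then_inputs[i], else_inputs[i])) return false;
  }
  return true;
}

int main() {
  // Operand [? x 4] works with then-input [2 x 4] and else-input [? x 4].
  std::cout << VerifyBranchInputs({{-1, 4}}, {{2, 4}}, {{-1, 4}}) << "\n";  // 1
  // A rank mismatch in the else branch is rejected.
  std::cout << VerifyBranchInputs({{-1, 4}}, {{2, 4}}, {{4}}) << "\n";      // 0
}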
+static LogicalResult Verify(InvertPermutationOp op) { + auto x_type = op.x().getType().cast(); + if (!x_type.hasRank()) return success(); + if (x_type.getShape().size() != 1) + return op.emitOpError() << "requires input x to be 1-dimensional"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// LeakyReluOp +//===----------------------------------------------------------------------===// + +OpFoldResult LeakyReluOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "leaky relu has one operand"); + + // leaky_relu(x, alpha: 1) -> x + if (alpha().convertToFloat() == 1.0f) return getOperand(); + + auto calculate = [&](FloatAttr arg) { + APFloat val = arg.getValue(); + if (val.isNegative()) val = alpha() * val; + return FloatAttr::get(arg.getType(), val); + }; + + if (auto arg = operands[0].dyn_cast_or_null()) { + return calculate(arg); + } else if (auto arg = operands[0].dyn_cast_or_null()) { + if (auto elementAttr = arg.getSplatValue().dyn_cast()) + return DenseElementsAttr::get(arg.getType(), calculate(elementAttr)); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// LogOp +//===----------------------------------------------------------------------===// + +void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// LogicalNotOp +//===----------------------------------------------------------------------===// + +void LogicalNotOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// MatrixBandPartOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(MatrixBandPartOp op) { + if (!HasRankAtLeast(op.input(), 2)) { + return op.emitOpError() + << "requires `input` to have rank of at least 2, but found " + << op.input().getType(); + } + if (!IsOfRankOrUnranked(op.num_lower(), 0)) { + return op.emitOpError() + << "requires `num_lower` to have 0 dimensions, but found " + << op.num_lower().getType(); + } + if (!IsOfRankOrUnranked(op.num_upper(), 0)) { + return op.emitOpError() + << "requires `num_upper` to have 0 dimensions, but found " + << op.num_upper().getType(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// MaxOp +//===----------------------------------------------------------------------===// + +void MaxOp::build(OpBuilder &builder, OperationState &result, Value input, + Value reduction_indices, BoolAttr keep_dims) { + Type out_ty = + InferReductionOpType(input, reduction_indices, keep_dims, &builder); + build(builder, result, out_ty, input, reduction_indices, keep_dims); +} + +//===----------------------------------------------------------------------===// +// MaxPoolOp +//===----------------------------------------------------------------------===// + +LogicalResult MaxPoolOp::FoldOperandsPermutation( + ArrayRef permutation) { + return ::mlir::TF::FoldOperandsPermutation( + permutation, this, {{"strides", strides()}, {"ksize", ksize()}}); +} + +//===----------------------------------------------------------------------===// +// MaxPoolGradOp +//===----------------------------------------------------------------------===// + +static 
LogicalResult Verify(MaxPoolGradOp op) { + if (!IsOfRankOrUnranked(op.orig_input(), 4)) { + return op.emitOpError() << "requires orig_input to be rank 4"; + } + if (!IsOfRankOrUnranked(op.orig_output(), 4)) { + return op.emitOpError() << "requires orig_output to be rank 4"; + } + if (!IsOfRankOrUnranked(op.grad(), 4)) { + return op.emitOpError() << "requires grad to be rank 4"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// MeanOp +//===----------------------------------------------------------------------===// + +LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { + // Reduction indices must be defined by a constant operation. + auto reduction_op = + dyn_cast_or_null(reduction_indices().getDefiningOp()); + if (!reduction_op) return failure(); + + auto reductions_value = reduction_op.value().dyn_cast(); + if (!reductions_value) return failure(); + + // Prepare new reduction indices according to operand permutation. + SmallVector shuffled_reduction; + llvm::transform(reductions_value.getIntValues(), + std::back_inserter(shuffled_reduction), + [&](APInt idx) { return permutation[idx.getSExtValue()]; }); + + // Add constant operation with a new reduction indices. + OpBuilder builder(getOperation()); + auto type = mlir::RankedTensorType::get(shuffled_reduction.size(), + builder.getIntegerType(32)); + auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction); + auto shuffled_reduction_op = builder.create(getLoc(), values); + + // Use new reduction indices. + setOperand(1, shuffled_reduction_op); + + return success(); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +OpFoldResult MulOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc" + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h new file mode 100644 index 00000000000..b2b78da8993 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ + +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" + +namespace mlir { +namespace TF { + +class YieldOp; + +// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose +// purpose is to catch bug on `tensorflow::mutex_lock`. We don't use +// `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and +// `tf.ConsumeMutexLock`) with getter methods named as `mutex_lock()`. Need to +// undefine here to avoid expanding the getter symbol as macro when including +// both mutex.h and this header file. +#undef mutex_lock + +#define GET_OP_FWD_DEFINES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h.inc" + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc new file mode 100644 index 00000000000..cea2aa17d46 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc @@ -0,0 +1,580 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is a simple include file used to simplify the splitting of the +// tf_ops.cc file. The helpers in here should be refactored and moved to +// tf_verifiers or tf_ops. +// TODO(jpienaar): Remove this file post refactoring. + +// Propagates underscore and device attributes from src to dst. 
+// TODO(b/158769932): This should be a general feature instead post some policy +// discussion. +static void PropagateAttributes(Operation *src, Operation *dst) { + auto device = mlir::Identifier::get("device", src->getContext()); + for (auto named_attr : src->getAttrs()) { + if (*named_attr.first.begin() == '_' || named_attr.first == device) + dst->setAttr(named_attr.first, named_attr.second); + } +} + +//===----------------------------------------------------------------------===// +// TF op helper functions +//===----------------------------------------------------------------------===// + +// Returns the RankedTensorType for the given operand. TensorFlow constant ops +// may have non-static shape because the shape is not propagated during constant +// folding. If the defining op for the given operand is a constant op, this +// routine uses the constant op's attribute to get the actual shape. +static RankedTensorType GetRankedTensorTypeForOperand(Value operand) { + DenseElementsAttr attr; + if (matchPattern(operand, m_Constant(&attr))) { + return attr.getType().dyn_cast(); + } + return operand.getType().dyn_cast(); +} + +// Returns true if the given `value` is of ranked float tensor type with the +// given `rank`. +static inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) { + return type && type.getRank() == rank && + type.getElementType().isa(); +} + +// Returns true if the given `value` has the specified rank or has unranked +// type. +static inline bool IsOfRankOrUnranked(Value value, int64_t rank) { + RankedTensorType type = GetRankedTensorTypeForOperand(value); + return !type || type.getRank() == rank; +} + +// Returns true if the given `value` has at least the specified rank or has +// unranked type. +static inline bool HasRankAtLeast(Value value, int64_t rank) { + RankedTensorType type = GetRankedTensorTypeForOperand(value); + return !type || type.getRank() >= rank; +} + +// Returns true if the given `value` has at most the specified rank or has +// unranked type. +static inline bool HasRankAtMost(Value value, int64_t rank) { + RankedTensorType type = GetRankedTensorTypeForOperand(value); + return !type || type.getRank() <= rank; +} + +static bool IsUnknownDimOrRank(int64_t dim_or_rank) { + return dim_or_rank == -1; +} + +// Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If +// `incompatible_shape_error` is true, reports error if `x` and `y` has +// incompatible shapes. Otherwise, returns a tensor type with unknown rank. +static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, + Value y, BoolAttr incompatible_shape_error) { + auto result_type = + OpTrait::util::getBroadcastedType(x.getType(), y.getType()); + if (!result_type) { + if (incompatible_shape_error.getValue()) { + mlir::emitError(loc, "non-broadcastable operands"); + } else { + return UnrankedTensorType::get(builder->getI1Type()); + } + } + + auto ranked_type = result_type.dyn_cast(); + if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type()); + + return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type()); +} + +// Returns dimension index for the given TensorFlow axis that supports negative +// indexing. +static int64_t GetDimForAxis(int64_t axis, int64_t rank) { + return axis >= 0 ? axis : axis + rank; +} + +// Infers output type for reduction ops such as SumOp, MaxOp etc. +// TODO(b/e667204a): Move this logic to shape inference once it supports custom +// inference functions. 
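A standalone sketch (not part of the patch) of the shape arithmetic done by the reduction type inference implemented just below: once the reduction indices are known constants, every reduced dimension is dropped, or kept as size 1 when keep_dims is set, and negative axes wrap around. The helper names are hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> InferReducedShape(const std::vector<int64_t> &shape,
                                       std::vector<int64_t> axes,
                                       bool keep_dims) {
  int64_t rank = static_cast<int64_t>(shape.size());
  std::vector<bool> is_reduce_dim(shape.size(), false);
  for (int64_t axis : axes) {
    if (axis < 0) axis += rank;  // Negative axes count from the end.
    is_reduce_dim[axis] = true;
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < rank; ++i) {
    if (!is_reduce_dim[i])
      out.push_back(shape[i]);
    else if (keep_dims)
      out.push_back(1);
  }
  return out;
}

int main() {
  // Reducing a [2, 3, 4] tensor over axis -1 (i.e. axis 2).
  for (int64_t d : InferReducedShape({2, 3, 4}, {-1}, /*keep_dims=*/false))
    std::cout << d << " ";  // prints: 2 3
  std::cout << "\n";
  for (int64_t d : InferReducedShape({2, 3, 4}, {-1}, /*keep_dims=*/true))
    std::cout << d << " ";  // prints: 2 3 1
  std::cout << "\n";
}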
+static Type InferReductionOpType(Value input, Value reduction_indices, + BoolAttr keep_dims, Builder *builder) { + Type input_ty = input.getType(); + Type element_ty = getElementTypeOrSelf(input_ty); + + // Output type is unranked if input type is not ranked. + auto ranked_ty = input_ty.dyn_cast(); + if (!ranked_ty) return UnrankedTensorType::get(element_ty); + int64_t rank = ranked_ty.getRank(); + + DenseIntElementsAttr indices; + if (!matchPattern(reduction_indices, m_Constant(&indices))) { + // Output type is unranked if reduction indices are not constant and reduced + // dimensions are not kept. + if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty); + + // Otherwise, output type has same rank as the input. + return RankedTensorType::get(SmallVector(rank, -1), element_ty); + } + + int64_t num_reduce_dim = 0; + llvm::SmallVector is_reduce_dim(rank, false); + for (const APInt &index : indices.getValues()) { + int64_t dim = GetDimForAxis(index.getSExtValue(), rank); + // Invalid input. + if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty); + + if (!is_reduce_dim[dim]) { + is_reduce_dim[dim] = true; + num_reduce_dim++; + } + } + + ArrayRef shape = ranked_ty.getShape(); + SmallVector out_shape; + out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim)); + for (int64_t i = 0; i < rank; ++i) { + if (!is_reduce_dim[i]) + out_shape.push_back(shape[i]); + else if (keep_dims.getValue()) + out_shape.push_back(1); + } + return RankedTensorType::get(out_shape, element_ty); +} + +// Verifies that the given types are cast compatible. If not, emits appropriate +// error for the given op. If mask_one_dim is set to true, then the types are +// allowed to have one mismatching dimension. Masking one of the dimensions is +// useful for ops like Concat that requires all ranked inputs to have the same +// rank and match dimension sizes for all but one of the dimensions. +static LogicalResult VerifyTypesCompatibility( + Operation::operand_type_range types, bool mask_one_dim, Operation *op) { + constexpr int64_t kUninitialized = -1; + int64_t common_rank = kUninitialized; + llvm::SmallVector common_dims; + int64_t dim_to_mask = kUninitialized; + + // Initialize common_rank with rank of the first ranked type and verify that + // following ranked types have the same rank. + // Similarly, initialize each of the dimensions with the first type that has + // the dimension size available and verify that all following types have the + // same size for the dimension. However, if mask_one_dim is true, note down + // the dimension index on the first mismatch and ignore dimension at that + // index in following types. + for (Type ty : types) { + RankedTensorType ranked_ty = ty.dyn_cast(); + if (!ranked_ty) continue; + + int64_t rank = ranked_ty.getRank(); + if (common_rank == kUninitialized) { + common_rank = rank; + common_dims.resize(common_rank, kUninitialized); + } else if (common_rank != rank) { + return op->emitError() + << "operand type " << ranked_ty + << " is not compatible with preceding operands; expected rank: " + << common_rank; + } + + for (int64_t i = 0, e = common_rank; i != e; i++) { + if (i == dim_to_mask) continue; + + int64_t dim = ranked_ty.getDimSize(i); + if (dim == kUninitialized) continue; + + int64_t &common_dim = common_dims[i]; + if (common_dim == kUninitialized) { + common_dim = dim; + } else if (common_dim != dim) { + // If mask_one_dim is true, do not emit an error if this is the only + // dimension with mismatches. 
Note down the dimension to mask it from + // the following types. + if (mask_one_dim && dim_to_mask == kUninitialized) { + dim_to_mask = i; + continue; + } + + return op->emitError() << "operand type " << ranked_ty + << " is not compatible with preceding operands; " + "expected dimension at index " + << i << ": " << common_dim; + } + } + } + return success(); +} + +// This is a helper for the Select to SelectV2 canonicalization. The `data` rank +// refers to the rank of `t`/`e` (these two inputs have equal rank; this is +// checked in the verifier). +// +// In most cases, the predicate for Select can be used directly as the predicate +// for SelectV2. However, there is one case that varies, which is when the +// predicate is a tensor and the data is multidimensional. In this case, Select +// op semantics dictate that the predicate tensor length must match the size of +// the first data dimension. This varies from normal broadcasting semantics +// (which are used in SelectV2), so we must reshape the tensor in this case to +// be compatible. +static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc, + Value cond, int data_rank) { + auto cond_tensor = cond.getType().cast(); + // Reshape is only needed in the case that the cond rank is 1 (i.e. it is + // a vector) AND t/e rank is > 1. + if (cond_tensor.getRank() != 1 || data_rank <= 1) { + // No reshape necessary. Leave cond as it is. + return cond; + } + + // This is the case where a reshape is needed. We want to construct the + // shape [x,1,...1], where x is the value in the pred tensor and the + // length of the shape is equal to data_rank. + SmallVector shape(data_rank, 1); + shape[0] = cond_tensor.getShape().front(); + auto new_shape_type = + RankedTensorType::get({data_rank}, builder->getIntegerType(64)); + auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape); + auto new_shape = builder->create(loc, shape_attr); + return builder->create(loc, cond, new_shape); +} + +//===----------------------------------------------------------------------===// +// Helper functions detect device capabilities from RuntimeDevices. +//===----------------------------------------------------------------------===// + +namespace { +using DeviceNameUtils = ::tensorflow::DeviceNameUtils; +using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName; + +bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) { + return device.type == ::tensorflow::DEVICE_GPU; +} + +} // namespace + +// Returns true if at least one GPU device is available at runtime. +bool CanUseGpuDevice(const RuntimeDevices &devices) { + return llvm::any_of(devices.device_names(), IsGpuDevice); +} + +// Returns true if all of the GPUs available at runtime support TensorCores +// (NVIDIA compute capability >= 7.0). +bool CanUseTensorCores(const RuntimeDevices &devices) { + auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) { + auto md = devices.GetGpuDeviceMetadata(device); + return md ? md->cc_major().getInt() >= 7 : false; + }; + return llvm::all_of( + llvm::make_filter_range(devices.device_names(), IsGpuDevice), + has_tensor_cores); +} + +// Returns true if operation does not have explicit device placement that would +// prevent it from running on GPU device. 
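A standalone sketch (not part of the patch) of the shape produced by ReshapeSelectPredIfNecessary above: a length-x predicate vector is reshaped to [x, 1, ..., 1] with the same rank as the data, so that ordinary broadcasting reproduces Select's "predicate length matches the first data dimension" semantics when lowering Select to SelectV2. The helper name is hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> SelectPredBroadcastShape(int64_t pred_length,
                                              int64_t data_rank) {
  // Only a rank-1 predicate combined with rank > 1 data needs this reshape.
  std::vector<int64_t> shape(static_cast<size_t>(data_rank), 1);
  shape[0] = pred_length;
  return shape;
}

int main() {
  // A length-8 predicate selecting between [8, 224, 224, 3] tensors is
  // broadcast through shape [8, 1, 1, 1].
  for (int64_t d : SelectPredBroadcastShape(8, 4)) std::cout << d << " ";
  std::cout << "\n";  // prints: 8 1 1 1
}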
+bool CanUseGpuDevice(Operation *op) { + auto device_attr = op->getAttrOfType("device"); + if (!device_attr || device_attr.getValue().empty()) return true; + + DeviceNameUtils::ParsedName device; + if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device)) + return false; + + // We can't use GPU if operation explicitly placed on non-GPU device. + return !device.has_type || device.type == ::tensorflow::DEVICE_GPU; +} + +//===----------------------------------------------------------------------===// +// TF op helper functions to work with layout transformation. +//===----------------------------------------------------------------------===// + +SmallVector ReversePermutation(ArrayRef permutation) { + SmallVector reverse(permutation.size()); + for (size_t i = 0; i < permutation.size(); ++i) { + reverse[permutation[i]] = i; + } + return reverse; +} + +SmallVector GetDataFormatPermutation(StringRef from, StringRef to) { + if (from == "NHWC" && to == "NCHW") { + return {0, 3, 1, 2}; + } else if (from == "NCHW" && to == "NHWC") { + return {0, 2, 3, 1}; + } else { + return {}; + } +} + +// Shuffle elements in the `attr` according to the permutation. Optional +// `inner_size` allows to shuffle array attributes created from rank 2 tensors +// on outer dimension only. +ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef permutation, + int inner_size = 1) { + if (attr.size() == 0) return attr; + + assert(attr.size() % inner_size == 0); + assert(attr.size() / inner_size == permutation.size()); + + SmallVector values{attr.begin(), attr.end()}; + SmallVector shuffled(values.size()); + + for (size_t i = 0; i < permutation.size(); ++i) { + for (size_t j = 0; j < inner_size; ++j) { + shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j]; + } + } + + return ArrayAttr::get(shuffled, attr.getContext()); +} + +// Shuffle ranked tensor dimensions according to the permutation. +Type ShuffleRankedTensorType(Type type, ArrayRef permutation) { + if (auto ranked_type = type.dyn_cast()) { + ArrayRef shape = ranked_type.getShape(); + assert(permutation.size() == shape.size()); + + SmallVector new_shape(permutation.size()); + for (size_t i = 0; i < permutation.size(); ++i) + new_shape[i] = shape[permutation[i]]; + + return RankedTensorType::get(new_shape, ranked_type.getElementType()); + } + + return type; +} + +static bool AreCancellablePermutations(DenseIntElementsAttr perm0, + DenseIntElementsAttr perm1) { + if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false; + if (perm0.getNumElements() != perm1.getNumElements()) return false; + + SmallVector perm0_values; + for (const auto &value : perm0.getIntValues()) + perm0_values.push_back(value.getSExtValue()); + + SmallVector perm1_values; + for (const auto &value : perm1.getIntValues()) + perm1_values.push_back(value.getSExtValue()); + + for (int i = 0; i < perm0_values.size(); ++i) { + if (perm0_values[perm1_values[i]] != i) return false; + } + + return true; +} + +// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for +// layout sensitive operations that do not have any additional layout dependent +// attributes besides `data_format` string. +template +LogicalResult UpdateDataFormat(StringRef data_format, Op *op) { + auto perm = GetDataFormatPermutation(op->data_format(), data_format); + if (perm.empty()) return failure(); + + // Update data format attribute. + op->setAttr("data_format", StringAttr::get(data_format, op->getContext())); + + // Update types for all layout sensitive results. 
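A standalone sketch (not part of the patch) of the permutation arithmetic used by the layout helpers above: NHWC to NCHW conversion is an index permutation, shapes (and array attributes) are shuffled with it, and the reverse permutation undoes it. All names are hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

using Perm = std::vector<int64_t>;

Perm ReverseOfPermutation(const Perm &p) {
  Perm reverse(p.size());
  for (size_t i = 0; i < p.size(); ++i) reverse[p[i]] = i;
  return reverse;
}

// Mirrors ShuffleRankedTensorType: new_shape[i] = shape[permutation[i]].
std::vector<int64_t> ShuffleShape(const std::vector<int64_t> &shape,
                                  const Perm &p) {
  std::vector<int64_t> out(p.size());
  for (size_t i = 0; i < p.size(); ++i) out[i] = shape[p[i]];
  return out;
}

int main() {
  const Perm nhwc_to_nchw = {0, 3, 1, 2};
  std::vector<int64_t> nhwc = {1, 224, 224, 3};  // N, H, W, C

  for (int64_t d : ShuffleShape(nhwc, nhwc_to_nchw)) std::cout << d << " ";
  std::cout << "\n";  // prints: 1 3 224 224

  // Applying the reverse permutation restores the original layout.
  for (int64_t d : ShuffleShape(ShuffleShape(nhwc, nhwc_to_nchw),
                                ReverseOfPermutation(nhwc_to_nchw)))
    std::cout << d << " ";
  std::cout << "\n";  // prints: 1 224 224 3
}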
+  auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
+  for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) {
+    OpResult result = op->getOperation()->getResult(idx);
+    result.setType(ShuffleRankedTensorType(result.getType(), perm));
+  }
+
+  return success();
+}
+
+// Default implementation for folding operand transpose into the operation.
+// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`.
+template <typename Op>
+LogicalResult FoldOperandsPermutation(
+    ArrayRef<int64_t> permutation, Op *op,
+    ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
+  MLIRContext *context = op->template getParentOfType<ModuleOp>().getContext();
+
+  // We only support NHWC <-> NCHW permutations.
+  static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
+  static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2};
+
+  // Operation data format after folding `permutation`.
+  StringRef target_data_format = [&]() -> StringRef {
+    if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) {
+      return "NCHW";  // cancel NCHW->NHWC operand permutation
+    } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) {
+      return "NHWC";  // cancel NHWC->NCHW operand permutation
+    } else {
+      return "";
+    }
+  }();
+  if (target_data_format.empty()) return failure();
+
+  // To fold the operand `permutation` into the `op` we need to shuffle all
+  // layout dependent attributes and types with a reverse permutation, and
+  // change the operation data format to `target_data_format`.
+  //
+  // Example:
+  //   %1 = SomeOp(...)    {data_format = NHWC}
+  //   %2 = Transpose(%1)  {permutation = NHWC->NCHW}
+  //   %3 = Op(%2)         {data_format = NCHW}
+  //
+  // To bypass %2 we have to shuffle the data format from NCHW back to NHWC,
+  // which is the reverse of the operand permutation (the function argument).
+  auto reverse_permutation =
+      GetDataFormatPermutation(op->data_format(), target_data_format);
+  if (reverse_permutation.empty()) return failure();
+
+  op->setAttr("data_format", StringAttr::get(target_data_format, context));
+
+  for (auto pair : shuffle_attrs) {
+    StringRef attr_name = pair.first;
+    ArrayAttr attr_value = pair.second;
+    op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
+  }
+
+  auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
+  for (unsigned idx : fold.GetLayoutDependentResults()) {
+    OpResult result = op->getOperation()->getResult(idx);
+    result.setType(
+        ShuffleRankedTensorType(result.getType(), reverse_permutation));
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Rewrite Pattern for removing trivial Arithmetic op.
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Fold Arithmetic Op if one of the operands is a constant known to be an
+// Identity (e.g. X+0, X*1, etc...). For commutative operations fold if
+// known identity value is either lhs or rhs.
+template <
+    typename OpT,
+    typename std::enable_if<llvm::is_one_of<
+        OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
+OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
+                                        ArrayRef<Attribute> operands) {
+  auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
+  auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
+  auto result_type =
+      arithmetic_op.getResult().getType().template cast<ShapedType>();
+
+  // We can fold the arithmetic operation only if we can prove that we will not
+  // accidentally hide a broadcasting error.
+  auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty,
+                                  ShapedType result_ty) -> bool {
+    // A scalar identity is broadcastable to any operand shape; we only need to
+    // check that the operand has the same shape as the result.
+    bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0;
+    if (scalar_identity) return operand_ty == result_ty;
+
+    // If the identity is not a scalar, we must verify that all shapes are
+    // equal and statically known.
+    //
+    // TODO(ezhulenev): Fold if the identity shape is statically known to be
+    // broadcastable to the operand shape.
+    return operand_ty == result_ty && identity_ty == result_ty &&
+           result_ty.hasStaticShape();
+  };
+
+  // Check that we have a constant operand on one side (candidate for identity).
+  const bool is_commutative =
+      (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
+  auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  if (!rhs_attr && !(is_commutative && lhs_attr)) return {};
+
+  // Mul and Div ops have identity value one while AddV2 and SubOp have
+  // identity value zero.
+  const int identity =
+      (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
+       std::is_same<OpT, RealDivOp>::value)
+          ? 1
+          : 0;
+
+  Type element_ty = lhs_type.getElementType();
+  Attribute identity_attr;
+  if (auto ty = element_ty.template dyn_cast<FloatType>()) {
+    identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
+  } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
+    identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
+  } else {
+    return {};
+  }
+
+  // Fold: Op(Operand, Identity) -> Operand.
+  if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) {
+    if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr)
+      return arithmetic_op.x();
+  }
+
+  // Fold: Op(Identity, Operand) -> Operand for commutative operations.
+  if (lhs_attr && is_commutative &&
+      is_valid_broadcasting(rhs_type, lhs_type, result_type)) {
+    if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr)
+      return arithmetic_op.y();
+  }
+
+  return {};
+}
+}  // namespace
+
+// Verifies a reduction op's `input` and reduction `dims`.
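A standalone sketch (not part of the patch) of the decision made by IdentityArithmeticOpFolder above: an Add/Sub with a splat constant 0, or a Mul/Div with a splat constant 1, folds to its other operand, provided the fold cannot hide a broadcasting error (the identity is a scalar or all shapes already agree). The sketch assumes the identity-side operand is a splat constant and uses hypothetical names.

#include <iostream>
#include <optional>

enum class Kind { kAdd, kSub, kMul, kDiv };

// Returns the 0-based operand index the op folds to (here always 0, i.e. the
// x operand), or nullopt if the constant is not the identity element.
std::optional<int> FoldWithConstantRhs(Kind kind, double constant) {
  const double identity = (kind == Kind::kMul || kind == Kind::kDiv) ? 1.0 : 0.0;
  if (constant == identity) return 0;  // x + 0, x - 0, x * 1, x / 1  ->  x
  return std::nullopt;
}

int main() {
  std::cout << FoldWithConstantRhs(Kind::kAdd, 0.0).has_value() << "\n";  // 1
  std::cout << FoldWithConstantRhs(Kind::kMul, 1.0).has_value() << "\n";  // 1
  std::cout << FoldWithConstantRhs(Kind::kMul, 0.0).has_value() << "\n";  // 0
}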
+static LogicalResult VerifyReductionInputAndDims(Value input, Value dims, + Location loc) { + auto dims_type = dims.getType().dyn_cast(); + if (!dims_type) return success(); + if (dims_type.getRank() > 1) + return emitError(loc, "dimensions can only be 0D or 1D tensor"); + + auto input_type = input.getType().dyn_cast(); + if (!input_type) return success(); + int64_t rank = input_type.getRank(); + + DenseIntElementsAttr dims_attr; + if (!matchPattern(dims, m_Constant(&dims_attr))) return success(); + for (const auto &dim_pair : llvm::enumerate(dims_attr)) { + int64_t cur_dim = dim_pair.value().getSExtValue(); + if (cur_dim < -rank || cur_dim >= rank) + return emitError(loc) + << dim_pair.index() << "-th dimension should be in the range of [-" + << rank << ", " << rank << ")"; + } + + return success(); +} + +LogicalResult VerifyRegionResults(Operation *op, Region ®ion, + StringRef region_name) { + auto op_name = op->getName().getStringRef(); + // verify that op outputs match yield inputs + YieldOp yield = cast(region.front().getTerminator()); + unsigned expected_num_results = op->getNumResults(); + if (yield.getNumOperands() != expected_num_results) + return op->emitOpError() + << region_name + " should have same number (" << expected_num_results + << ") of results as " << op_name << " but has " + << yield.getNumOperands() << " results"; + + for (int idx : llvm::seq(0, expected_num_results)) { + auto op_result_type = op->getResult(idx).getType().cast(); + auto region_result_type = + yield.getOperand(idx).getType().cast(); + if (!AreCastCompatible({region_result_type, op_result_type})) + return op->emitError(llvm::formatv( + "{0} result type {1} is incompatible with {2} " + "result type {3} at index {4}", + region_name, region_result_type, op_name, op_result_type, idx)); + } + return success(); +} diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc new file mode 100644 index 00000000000..c5c729a600e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -0,0 +1,2270 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace mlir { +namespace TF { + +namespace { +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" +} // namespace + +//===----------------------------------------------------------------------===// +// NegOp +//===----------------------------------------------------------------------===// + +void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// NotEqualOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(NotEqualOp op) { + // If we allow inputs to have incompatible type, then nothing to do. + if (!op.incompatible_shape_error()) return success(); + + // Otherwise, check inputs are broadcastable. 
+ return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( + op.getOperation()); +} + +void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x, + Value y, BoolAttr incompatible_shape_error) { + auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, + incompatible_shape_error); + return build(builder, result, result_type, x, y, incompatible_shape_error); +} + +//===----------------------------------------------------------------------===// +// OneHotOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(OneHotOp op) { + int64_t axis = op.axis().getSExtValue(); + + auto indices_ty = op.indices().getType().dyn_cast(); + if (indices_ty && + !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) { + return op.emitOpError() + << "expected axis (" << axis << ") to be -1 or between [0, " + << indices_ty.getShape().size() << "]"; + } + + if (axis < -1) { + return op.emitOpError() << "expected axis (" << axis + << ") to be -1 or between [0, rank(indices()))"; + } + + if (!IsOfRankOrUnranked(op.depth(), 0)) { + return op.emitOpError() << "requires depth to be a scalar"; + } + if (!IsOfRankOrUnranked(op.on_value(), 0)) { + return op.emitOpError() << "requires on_value to be a scalar"; + } + if (!IsOfRankOrUnranked(op.off_value(), 0)) { + return op.emitOpError() << "requires off_value to be a scalar"; + } + + DenseIntElementsAttr depth_attr; + if (matchPattern(op.depth(), m_Constant(&depth_attr))) { + if (depth_attr.getType().getRank() != 0) + return op.emitOpError() << "requires depth to be a scalar"; + int64_t depth = depth_attr.getValue({}).getSExtValue(); + if (depth < 0) { + return op.emitOpError() << "depth must be non-negative, got: " << depth; + } + } + + return success(); +} + +static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, + Value off_value, IntegerAttr axis) { + int64_t axis_val = axis.getInt(); + Type element_ty = on_value.getType().cast().getElementType(); + auto unranked_ty = UnrankedTensorType::get(element_ty); + if (axis_val < -1) return unranked_ty; + + auto indices_ty = indices.getType().dyn_cast(); + if (!indices_ty) return unranked_ty; + + auto shape = llvm::to_vector<2>(indices_ty.getShape()); + if (axis_val == -1) axis_val = shape.size(); + + int64_t depth_val = ShapedType::kDynamicSize; + DenseIntElementsAttr depth_attr; + if (matchPattern(depth, m_Constant(&depth_attr)) && + depth_attr.getNumElements() == 1) + depth_val = (*depth_attr.begin()).getSExtValue(); + shape.insert(shape.begin() + axis_val, depth_val); + return RankedTensorType::get(shape, element_ty); +} + +void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices, + Value depth, Value on_value, Value off_value, + IntegerAttr axis) { + build(builder, result, + InferOneHotOpType(indices, depth, on_value, off_value, axis), indices, + depth, on_value, off_value, axis); +} + +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(PackOp op) { + // TODO(hinsu): Convert variadic length attributes to derived attributes. 
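A standalone sketch (not part of the patch) of the shape computation in InferOneHotOpType above: the result shape is the indices shape with the (possibly unknown) depth inserted at `axis`, where axis == -1 means appending it at the end. The helper names are hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamic = -1;  // stand-in for an unknown dimension

std::vector<int64_t> InferOneHotShape(std::vector<int64_t> indices_shape,
                                      int64_t depth, int64_t axis) {
  if (axis == -1) axis = static_cast<int64_t>(indices_shape.size());
  indices_shape.insert(indices_shape.begin() + axis, depth);
  return indices_shape;
}

int main() {
  // indices: [2, 3], depth = 5, axis = -1  ->  [2, 3, 5]
  for (int64_t d : InferOneHotShape({2, 3}, 5, -1)) std::cout << d << " ";
  std::cout << "\n";
  // indices: [2, 3], unknown depth, axis = 0  ->  [?, 2, 3]
  for (int64_t d : InferOneHotShape({2, 3}, kDynamic, 0)) std::cout << d << " ";
  std::cout << "\n";
}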
+ Operation::operand_range values = op.values(); + + if (failed(VerifyTypesCompatibility(values, + /*mask_one_dim=*/false, + op.getOperation()))) { + return failure(); + } + + int64_t inputs_rank = -1; + for (Value value : values) { + if (auto ty = value.getType().dyn_cast()) { + // Exit early as input types are verified to be compatible so all ranked + // tensors have the same rank. + inputs_rank = ty.getRank(); + break; + } + } + if (inputs_rank == -1) return success(); + + // The values can be packed along any of the dimensions between 0 and + // inputs rank, inclusive. Also, as the negative axis values wrap around so + // the axis value range is [-(R+1), R+1). + int64_t range_begin = -inputs_rank - 1; // Inclusive + int64_t range_end = inputs_rank + 1; // Exclusive + int64_t axis = op.axis().getSExtValue(); + if (axis < range_begin || axis >= range_end) { + return op.emitError() << "attribute 'axis' should be within range [" + << range_begin << ", " << range_end + << "); actual value: " << axis; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// PadOp +//===----------------------------------------------------------------------===// + +LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { + // Paddings must be defined by a constant operation. + auto paddings_op = dyn_cast_or_null(paddings().getDefiningOp()); + if (!paddings_op) return failure(); + + auto paddings_value = paddings_op.value().dyn_cast(); + if (!paddings_value || + paddings_value.getNumElements() != permutation.size() * 2) + return failure(); + + SmallVector shuffled_paddings(paddings_value.getNumElements()); + for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) { + size_t outer_idx = index_pair.index() / 2; + size_t inner_idx = index_pair.index() % 2; + + shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] = + index_pair.value().getSExtValue(); + } + + // Add constant operation with a new paddings. + OpBuilder builder(getOperation()); + auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(), + builder.getIntegerType(32)); + auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings); + auto shuffled_paddings_op = builder.create(getLoc(), values); + + // Use new paddings. + setOperand(1, shuffled_paddings_op); + + // Change the result type. + getResult().setType(ShuffleRankedTensorType(getResult().getType(), + ReversePermutation(permutation))); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ParseExampleV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ParseExampleV2Op op) { + // NOTE(mrry): This validates properties of an op that would previously be + // validated by the TensorFlow OpDef type checker. In addition to these + // checks, the shape inference function for ParseExampleV2 validates the + // consistency of the argument and result types. + + // Validate dense variadic input and output lengths. + // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we + // do not need to validate dense_defaults. 
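A standalone sketch (not part of the patch) of the axis handling in the PackOp verifier above: packing rank-R inputs creates one new dimension, so a valid axis lies in [-(R+1), R+1), and negative values wrap around. The helper name is hypothetical.

#include <cstdint>
#include <iostream>
#include <optional>

std::optional<int64_t> NormalizePackAxis(int64_t axis, int64_t inputs_rank) {
  const int64_t range_begin = -inputs_rank - 1;  // inclusive
  const int64_t range_end = inputs_rank + 1;     // exclusive
  if (axis < range_begin || axis >= range_end) return std::nullopt;
  return axis < 0 ? axis + inputs_rank + 1 : axis;
}

int main() {
  // Packing rank-2 inputs: valid axes are -3..2.
  std::cout << NormalizePackAxis(-1, 2).value() << "\n";   // 2 (new last dim)
  std::cout << NormalizePackAxis(0, 2).value() << "\n";    // 0
  std::cout << NormalizePackAxis(3, 2).has_value() << "\n";  // 0 (out of range)
}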
+ auto dense_types_count = + std::distance(op.Tdense().begin(), op.Tdense().end()); + auto dense_values_count = + std::distance(op.dense_values().begin(), op.dense_values().end()); + if (dense_values_count != dense_types_count) { + return op.emitError() << "output 'dense_values' should have same length " + << "as attribute 'Tdense'"; + } + + // Validate sparse variadic output lengths. + // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we + // do not need to validate sparse_values. + auto sparse_types_count = + std::distance(op.sparse_types().begin(), op.sparse_types().end()); + if (op.num_sparse() != sparse_types_count) { + return op.emitError() << "attribute 'num_sparse' should be the same as " + << "the length of attribute 'sparse_types'"; + } + if (op.sparse_indices().size() != sparse_types_count) { + return op.emitError() << "output 'sparse_indices' should have same length " + << "as attribute 'sparse_types'"; + } + if (op.sparse_shapes().size() != sparse_types_count) { + return op.emitError() << "output 'sparse_shapes' should have same length " + << "as attribute 'sparse_types'"; + } + + // Validate ragged variadic output lengths. + auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(), + op.ragged_value_types().end()); + auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(), + op.ragged_split_types().end()); + if (ragged_value_types_count != ragged_split_types_count) { + return op.emitError() << "attribute 'ragged_value_types' should have same " + << "length as attribute 'ragged_split_types'"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// PartitionedCallOp +//===----------------------------------------------------------------------===// + +template +static LogicalResult VerifyPartitionedCall(OpClass op) { + auto module = op.template getParentOfType(); + SymbolRefAttr func = op.getAttr("f").template cast(); + + auto function = + dyn_cast_or_null(SymbolTable::lookupSymbolIn(module, func)); + + if (!function) { + return op.emitError("'f' attribute refers to an undefined function: ") + << func; + } + + FunctionType function_ty = function.getType(); + int func_arg_count = function_ty.getNumInputs(); + int arg_count = op.args().size(); + + if (arg_count != func_arg_count) { + return op.emitError() << "argument count mismatch: 'args' has " << arg_count + << " arguments, but '" << func << "' expects " + << func_arg_count; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// PowOp +//===----------------------------------------------------------------------===// + +OpFoldResult PowOp::fold(ArrayRef operands) { + auto constant_y = operands[1].dyn_cast_or_null(); + if (constant_y && constant_y.isSplat()) { + APFloat y_value = constant_y.getSplatValue(); + auto output_type = getType().cast(); + if (y_value.isZero() && output_type.hasStaticShape()) { + return DenseElementsAttr::get( + output_type, + FloatAttr::get(output_type.getElementType(), /*value=*/1.0)); + } + if (y_value.isExactlyValue(1.0)) { + return x(); + } + } + return {}; +} + +//===----------------------------------------------------------------------===// +// QrOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// * Input type, if ranked, must have at least 2 dimensions and at most +// INT32_MAX dimensions. 
+// +static LogicalResult Verify(QrOp op) { + auto ttype = op.input().getType().cast(); + if (!ttype.hasRank()) return success(); + if (!HasRankAtLeast(op.input(), 2)) + return op.emitOpError( + "requires ranked input tensor to be of rank 2 or more"); + if (!HasRankAtMost(op.input(), std::numeric_limits::max())) + return op.emitOpError( + "requires ranked input tensor to be of rank INT32_MAX or less"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ReadVariableOp +//===----------------------------------------------------------------------===// + +void ReadVariableOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// ReciprocalOp +//===----------------------------------------------------------------------===// + +void ReciprocalOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// RandomUniformOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(RandomUniformOp op) { + if (!IsOfRankOrUnranked(op.shape(), 1)) + return op.emitOpError("shape must be 1D tensor"); + return success(); +} + +//===----------------------------------------------------------------------===// +// RangeOp +//===----------------------------------------------------------------------===// + +void RangeOp::build(OpBuilder &builder, OperationState &result, Value start, + Value limit, Value delta) { + assert(start.getType() == limit.getType()); + assert(start.getType() == delta.getType()); + DenseIntElementsAttr start_val; + DenseIntElementsAttr limit_val; + DenseIntElementsAttr delta_val; + if (matchPattern(start, m_Constant(&start_val)) && + matchPattern(limit, m_Constant(&limit_val)) && + matchPattern(delta, m_Constant(&delta_val))) { + auto size = llvm::APIntOps::RoundingSDiv( + *limit_val.begin() - *start_val.begin(), *delta_val.begin(), + llvm::APInt::Rounding::DOWN); + return RangeOp::build( + builder, result, + RankedTensorType::get( + size.getSExtValue(), + start.getType().cast().getElementType()), + start, limit, delta); + } + return RangeOp::build( + builder, result, + RankedTensorType::get( + {-1}, start.getType().cast().getElementType()), + start, limit, delta); +} +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +void RankOp::build(OpBuilder &builder, OperationState &result, Value input) { + return RankOp::build(builder, result, + RankedTensorType::get({}, builder.getIntegerType(32)), + input); +} + +// This will create a constant value for RankOp of a ranked tensor. 
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+  auto type = input().getType();
+  auto ranked_type = type.dyn_cast<RankedTensorType>();
+  if (!ranked_type) return {};
+
+  auto output_type = getType().cast<ShapedType>();
+  int32_t rank = ranked_type.getRank();
+  return DenseIntElementsAttr::get(output_type, rank);
+}
+
+//===----------------------------------------------------------------------===//
+// RealDivOp
+//===----------------------------------------------------------------------===//
+
+void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert(context);
+}
+
+OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
+  return IdentityArithmeticOpFolder(*this, operands);
+}
+
+//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+// TODO(b/128020684): Verify the output type.
+static LogicalResult Verify(ReshapeOp op) {
+  auto shape_type = op.shape().getType().cast<TensorType>();
+  if (!shape_type.hasRank()) return success();
+  if (shape_type.getRank() != 1)
+    return op.emitOpError("shape must be 1D tensor");
+  auto rank_by_shape = shape_type.getShape()[0];
+  auto type_of_tensor = op.tensor().getType().cast<TensorType>();
+  // No compile-time verification for an unknown sized shape.
+  if (rank_by_shape == -1 || !type_of_tensor.hasStaticShape()) return success();
+  int64_t num_by_tensor = type_of_tensor.getNumElements();
+
+  auto out_ty = op.getType().dyn_cast<RankedTensorType>();
+  if (out_ty && out_ty.hasStaticShape()) {
+    int64_t num_output_elements = out_ty.getNumElements();
+    if (num_by_tensor != num_output_elements)
+      return op.emitOpError()
+             << "number of output elements (" << num_output_elements
+             << ") does not match expected number of elements ("
+             << num_by_tensor << ")";
+  }
+
+  // Check the values if the shape is constant. No compile-time verification
+  // for a non-constant shape.
+  auto *shape_op = op.shape().getDefiningOp();
+  if (!shape_op) return success();
+  Attribute shape_cst;
+  if (!matchPattern(shape_op, m_Constant(&shape_cst))) return success();
+  auto shape_cst_attr = shape_cst.dyn_cast<ElementsAttr>();
+  if (!shape_cst_attr) return op.emitOpError("shape must be a valid tensor");
+
+  if (auto opaque_attr = shape_cst_attr.dyn_cast<OpaqueElementsAttr>()) {
+    opaque_attr.decode(shape_cst_attr);
+  }
+
+  // We know the shape is a 1-D tensor; compute the number of elements it
+  // implies.
+  unsigned num_by_shape = 1;
+  unsigned unknown_dim_count = 0;
+  for (int i = 0, e = rank_by_shape; i != e; ++i) {
+    auto num = shape_cst_attr.getValue<IntegerAttr>(i).getInt();
+    // A dimension size can be -1, in which case the real size needs to be
+    // computed so that the total size remains constant. At most one
+    // component of the shape can be -1.
+    if (num == -1) {
+      if (++unknown_dim_count > 1) {
+        return op.emitOpError("more than one component of shape are -1");
+      }
+    } else {
+      num_by_shape *= num;
+    }
+  }
+  // If one component of the shape is -1, that dimension should be computed
+  // so that the total size remains constant.
+  if (unknown_dim_count == 1) {
+    if (num_by_tensor % num_by_shape != 0)
+      return op.emitOpError(
+          "one component of shape is -1 but couldn't infer the dimension");
+    return success();
+  }
+  // If the number of elements implied by the tensor and by the shape do not
+  // match, fail this static check.
+ if (num_by_tensor != num_by_shape) { + return op.emitOpError( + "mismatch in tensor elements and shape implied elements"); + } + return success(); +} + +void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, + Value shape) { + auto ttype = tensor.getType().cast(); + auto etype = ttype.getElementType(); + + auto unranked = [&builder, etype, &result, shape, tensor]() { + return ReshapeOp::build(builder, result, UnrankedTensorType::get(etype), + tensor, shape); + }; + + // If tensor is unranked then we have no info about output of shape. + if (!ttype.hasRank()) return unranked(); + + DenseIntElementsAttr attr_shape; + if (matchPattern(shape, m_Constant(&attr_shape))) { + llvm::SmallVector const_shape; + const_shape.reserve(attr_shape.getNumElements()); + + // Detect if reshape output shape is folded. + bool flatten = false; + int unknown_index = -1; + // The product of constant shape argument excluding unknown dimension. + int64_t product_cshape = 1; + for (auto e : llvm::enumerate(attr_shape)) { + int64_t val = e.value().getSExtValue(); + if (IsUnknownDimOrRank(val)) { + if (flatten) { + mlir::emitError(result.location) + << "only one unknown dimension allowed"; + return; + } + flatten = true; + unknown_index = e.index(); + } else { + product_cshape *= val; + } + const_shape.push_back(val); + } + + // Compute the value of the unknown dimension. + if (flatten) { + // Compute number of elements in tensor shape. + auto tshape = ttype.getShape(); + int64_t product_tshape = std::accumulate(tshape.begin(), tshape.end(), 1, + std::multiplies()); + // Set the unknown dimension such that total number of elements remain + // constant. + // Note: The case where the ratio is not integral, and so the total size + // of reshape not constant, is checked in verify function. + const_shape[unknown_index] = product_tshape / product_cshape; + } + return ReshapeOp::build(builder, result, + RankedTensorType::get(const_shape, etype), tensor, + shape); + } + return unranked(); +} + +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult ReshapeOp::fold(ArrayRef operands) { + Value tensor = this->tensor(); + Value shape = this->shape(); + + // Fold reshape if operand and result types are the same and all dimensions + // are statically known (no-op reshape). + // TODO(ezhulenev): Add the same folding for BroadcastToOp. + auto result_ty = getType().dyn_cast(); + if (result_ty && result_ty.hasStaticShape() && + result_ty == tensor.getType()) { + return tensor; + } + + // Fold reshape if the shape is computed from the input tensor: + // + // %shape = tf.Shape(%arg) // [? x ...] + // %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim value + // %new_shape = tf.Pack(dim0, ...) { axis = 0 } // [? x ...] + // %reshape = tf.Reshape(%arg, %new_shape) // this is no-op + // + // Where `...` are some statically known dimensions. In this case reshape is + // a no-op and can be replaced by %arg (assuming `...` are equal). + auto pack_op = dyn_cast_or_null(shape.getDefiningOp()); + if (!pack_op || pack_op.values().size() < 2) return {}; + + // Dimensions packed along axis = 0 (pack scalars into vector). + if (pack_op.axis().getSExtValue() != 0) return {}; + + // First packed value is defined by a strided slice operation. + auto slice_op = + dyn_cast_or_null(pack_op.values()[0].getDefiningOp()); + if (!slice_op) return {}; + + // Input to the slice op is defined by shape operation. 
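A standalone sketch (not part of the patch) of how ReshapeOp::build above resolves a single -1 entry in the requested shape: it is replaced by the total element count divided by the product of the known dimensions, so the element count is preserved (the non-integral case is rejected by the verifier). The helper names are hypothetical.

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

std::vector<int64_t> ResolveReshape(const std::vector<int64_t> &input_shape,
                                    std::vector<int64_t> requested) {
  const int64_t total = std::accumulate(input_shape.begin(), input_shape.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
  int64_t known_product = 1;
  int64_t unknown_index = -1;
  for (size_t i = 0; i < requested.size(); ++i) {
    if (requested[i] == -1)
      unknown_index = static_cast<int64_t>(i);
    else
      known_product *= requested[i];
  }
  if (unknown_index >= 0) requested[unknown_index] = total / known_product;
  return requested;
}

int main() {
  // Reshaping a [2, 3, 4] tensor to [4, -1] infers the -1 as 24 / 4 = 6.
  for (int64_t d : ResolveReshape({2, 3, 4}, {4, -1})) std::cout << d << " ";
  std::cout << "\n";  // prints: 4 6
}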
+ auto shape_op = dyn_cast_or_null(slice_op.input().getDefiningOp()); + if (!shape_op || shape_op.input() != tensor) return {}; + + // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing + // scalar value from input vector). + if (slice_op.begin_mask().getSExtValue() != 0 || + slice_op.ellipsis_mask().getSExtValue() != 0 || + slice_op.end_mask().getSExtValue() != 0 || + slice_op.new_axis_mask().getSExtValue() != 0 || + slice_op.shrink_axis_mask().getSExtValue() != 1) + return {}; + + // Returns a value if the `value` is defined by a ConstOp with a single + // integer element in it and has an expected rank. + auto get_value = [](Value value, int expected_rank) -> Optional { + auto const_op = dyn_cast_or_null(value.getDefiningOp()); + if (!const_op) return None; + + auto value_attr = const_op.value().dyn_cast(); + if (!value_attr || value_attr.getNumElements() != 1) return None; + + auto value_ty = value_attr.getType(); + if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None; + + auto splat = value_attr.getSplatValue(); + return splat.getValue().getSExtValue(); + }; + + // All other packed values are scalar constants. + SmallVector packed_dims; + packed_dims.reserve(pack_op.values().size() - 1); + for (Value operand : llvm::drop_begin(pack_op.values(), 1)) { + if (auto dim = get_value(operand, /*expected_rank=*/0)) { + packed_dims.push_back(*dim); + } else { + return {}; + } + } + + // Slice exactly the first shape dimension: + // begin = [0] end = [1], strides = [1] + auto begin = get_value(slice_op.begin(), /*expected_rank=*/1); + auto end = get_value(slice_op.end(), /*expected_rank=*/1); + auto strides = get_value(slice_op.strides(), /*expected_rank=*/1); + if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() || + *begin != 0 || *end != 1 || *strides != 1) + return {}; + + // First tensor dimension is dynamic. + auto arg_ty = tensor.getType().dyn_cast(); + if (!arg_ty || arg_ty.getNumDynamicDims() != 1 || !arg_ty.isDynamicDim(0)) + return {}; + + // Argument tensor rank is equal to the number of packed dimensions. + if (arg_ty.getRank() != pack_op.values().size()) return {}; + + // All other dimensions are statically known and equal to packed dims. + auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1); + if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin())) + return {}; + + return tensor; +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +// Verifies a few extra requirements on SelectOp: +// (1) `then` and `else` must have same shape +// (2) At least one of the following must be true: +// (a) `cond` has the same rank as `then` and `else` +// (b) `cond` is a scalar +// (c) `cond` is a vector AND `then` and `else` are non-scalar with their +// first dimension equal to `cond`. +static LogicalResult Verify(SelectOp op) { + auto then_tensor = op.t().getType().cast(); + auto else_tensor = op.e().getType().cast(); + // Check (1). + if (!AreCastCompatible({then_tensor, else_tensor})) + return op.emitOpError() << "requires t and e have compatible shapes"; + + // Get data rank (if exists). + int data_rank; + // If data is unranked or data_rank is 0, this will remain -2. Otherwise + // refers to first dimension of then and/or else. 
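The rank conditions (2a)-(2c) listed above can be summarized in a standalone predicate. This C++ sketch is illustrative only (the helper name is invented, all operands are assumed ranked, and -1 stands for an unknown dimension size):

    #include <cstdint>

    // True if a predicate of rank `cond_rank` (length `cond_dim` when it is a
    // vector) is acceptable for data of rank `data_rank` whose first dimension
    // is `data_first_dim`.
    bool PredCompatibleWithData(int64_t cond_rank, int64_t cond_dim,
                                int64_t data_rank, int64_t data_first_dim) {
      if (cond_rank == 0) return true;          // (2b) scalar predicate
      if (cond_rank == data_rank) return true;  // (2a) same rank as t and e
      if (cond_rank == 1) {                     // (2c) vector predicate
        if (data_rank == 0) return false;       // t and e must be non-scalar
        // Unknown sizes are accepted; mismatching known sizes are not.
        return cond_dim == -1 || data_first_dim == -1 ||
               cond_dim == data_first_dim;
      }
      return false;
    }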
+ int data_first_dim = -2; + bool then_has_rank = then_tensor.hasRank(); + bool else_has_rank = else_tensor.hasRank(); + if (then_has_rank && else_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + if (else_tensor.getRank() > 0) + data_first_dim = std::max( + static_cast(else_tensor.getShape().front()), data_first_dim); + } else if (then_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + } else if (else_has_rank) { + data_rank = else_tensor.getRank(); + if (else_tensor.getRank() > 0) + data_first_dim = else_tensor.getShape().front(); + } else { + // Neither has a rank. + return success(); + } + + auto cond_tensor = op.condition().getType().dyn_cast(); + if (!cond_tensor) return success(); + auto cond_rank = cond_tensor.getRank(); + // Check (2a) and (2b). + if (cond_rank == 0 || cond_rank == data_rank) return success(); + // Check (2c). + if (cond_rank == 1) { + auto cond_shape = cond_tensor.getShape().front(); + if (data_rank == 0) { + return op.emitOpError() + << "requires that t and e are nonscalar when pred is a vector"; + } + // We know `data` tensor has a rank of at least 1. + if (data_first_dim != -1 && cond_shape != -1 && + data_first_dim != cond_shape) { + return op.emitOpError() << "requires that, when pred is a vector, the " + "shape matches the first dimension of t and e"; + } + return success(); + } + // None of (2a,b,c) were true; fail. + return op.emitOpError() << "requires that pred is a scalar OR has the same " + "rank as t and e OR is a vector"; +} + +//===----------------------------------------------------------------------===// +// SelectV2Op +//===----------------------------------------------------------------------===// + +static Type InferSelectV2OpType(Value condition, Value e, Value t) { + Type element_ty = e.getType().cast().getElementType(); + auto unranked_ty = UnrankedTensorType::get(element_ty); + + Type broadcasted_ty = + OpTrait::util::getBroadcastedType(e.getType(), t.getType()); + if (!broadcasted_ty) return unranked_ty; + + auto cond_ranked_ty = condition.getType().dyn_cast(); + auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast(); + if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty; + + // Explicitly get broadcasted output type as element types of condition may + // not be same as the broadcated type's element type. + SmallVector result_shape; + if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(), + broadcasted_ranked_ty.getShape(), + result_shape)) + return unranked_ty; + return RankedTensorType::get(result_shape, element_ty); +} + +void SelectV2Op::build(OpBuilder &builder, OperationState &result, + Value condition, Value e, Value t) { + build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t); +} + +//===----------------------------------------------------------------------===// +// ShapeOp +//===----------------------------------------------------------------------===// + +namespace { +// Validates Shape/ShapeN/VariableShape operand and associated result types. +LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, + Type result_type, + int variadic_idx = -1) { + std::string variadic_idx_str = + variadic_idx < 0 ? 
"" : llvm::formatv(" #{0}", variadic_idx).str(); + + auto result_ranked_type = result_type.dyn_cast(); + if (!result_ranked_type) return success(); + if (result_ranked_type.getShape().size() != 1) + return op->emitOpError("requires 1D type for result") << variadic_idx_str; + + auto operand_ranked_type = operand_type.dyn_cast_or_null(); + if (operand_ranked_type) { + // The operand is a ranked tensor. + if (result_ranked_type.hasStaticShape() && + !operand_ranked_type.getShape().empty() && + result_ranked_type.getDimSize(0) != + operand_ranked_type.getShape().size()) + return op->emitOpError("requires dimension size of result") + << variadic_idx_str << " to match rank of operand" + << variadic_idx_str; + } else if (result_ranked_type.hasStaticShape()) { + // The operand is an unranked tensor, print a warning if the result + // is static. + // Note: We do not handle this situation as an error, this would be too + // restrictive due to incompleteness of shape inference at this point. + op->emitWarning("has static shape result") + << variadic_idx_str << " for unranked operand" << variadic_idx_str; + } + + Type element_type = result_ranked_type.getElementType(); + if (!element_type.isSignlessInteger(32) && + !element_type.isSignlessInteger(64)) + return op->emitOpError("requires int32 or int64 return type for result") + << variadic_idx_str; + + return success(); +} +} // anonymous namespace + +static LogicalResult Verify(ShapeOp op) { + return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType()); +} + +// Converts shape of the given type to attribute if it is of ranked tensor type. +// Returned attribute has integer elements of the given width. +static Attribute ConvertShapeToAttr(Type input_ty, int out_width) { + auto ranked_ty = input_ty.dyn_cast(); + if (!ranked_ty || !ranked_ty.hasStaticShape()) return {}; + + auto shape = ranked_ty.getShape(); + int rank = shape.size(); + + SmallVector dimensions; + dimensions.reserve(rank); + for (int i = 0; i < rank; ++i) + dimensions.push_back(APInt(out_width, shape[i])); + + auto result_type = RankedTensorType::get( + {rank}, IntegerType::get(out_width, input_ty.getContext())); + return DenseElementsAttr::get(result_type, dimensions); +} + +OpFoldResult ShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + return ConvertShapeToAttr(getOperand().getType(), width); +} + +void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input, + BoolAttr use32Bit) { + auto rankedTensorType = input.getType().dyn_cast(); + int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1; + auto out_type = use32Bit.getValue() ? 
builder.getIntegerType(32) + : builder.getIntegerType(64); + return ShapeOp::build(builder, result, + RankedTensorType::get({rank}, out_type), input); +} + +//===----------------------------------------------------------------------===// +// ShapeNOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ShapeNOp op) { + const size_t num_tensors = op.N(); + + if (op.getNumOperands() != num_tensors) + return op.emitOpError() << "requires " << num_tensors << " operand(s), got " + << op.getNumOperands() << " operand(s)"; + + if (op.getNumResults() != num_tensors) + return op.emitOpError() << "requires " << num_tensors << " result(s), got " + << op.getNumResults() << " result(s)"; + + for (auto i : llvm::seq(0, num_tensors)) { + auto verification = VerifyShapeOperandAndResult( + op, op.getOperand(i).getType(), op.getResult(i).getType(), i); + if (failed(verification)) return verification; + } + + return success(); +} + +LogicalResult ShapeNOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + if (getNumOperands() == 0) return success(); + int width = + getType(0).cast().getElementType().getIntOrFloatBitWidth(); + + for (Type input_ty : getOperandTypes()) { + OpFoldResult result = ConvertShapeToAttr(input_ty, width); + if (!result) return failure(); + + results.push_back(result); + } + return success(); +} + +// TODO(hinsu): Add canonicalization pattern for ShapeN ops that don't have all +// static input shapes. Replacing output values corresponding to static input +// types may enable optimizations in users of the values. + +//===----------------------------------------------------------------------===// +// SizeOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// * Input type, if is a ranked tensor, has at most INT32_MAX dimensions. +// +static LogicalResult Verify(SizeOp op) { + if (!HasRankAtMost(op.input(), std::numeric_limits::max())) + return op.emitOpError( + "requires ranked input tensor to be of rank INT32_MAX or less"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// SliceOp +//===----------------------------------------------------------------------===// + +// Verifies that: +// +// - operands begin and size are 1D with the same number of elements. +// - if the input is a ranked tensor, the rank of the input equals the number +// of elements in operands begin and size. 
+// - if begin are constants, that +// 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i] +// - if begins aren't constant but the input is a ranked tensor, that +// size[i] <= input_ty.getShape()[i] +// +static LogicalResult Verify(SliceOp op) { + RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin()); + if (begin_ty && begin_ty.getRank() != 1) { + return op.emitOpError() << "requires begin operand to be 1D tensor"; + } + + RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size()); + if (size_ty && size_ty.getRank() != 1) { + return op.emitOpError() << "requires size operand to be 1D tensor"; + } + + if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() || + !size_ty.hasStaticShape()) + return success(); + + if (begin_ty.getNumElements() != size_ty.getNumElements()) { + return op.emitOpError() << "requires begin and size operands to have the" + " same number of elements"; + } + + auto input_ty = op.input().getType().dyn_cast(); + if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) { + return op.emitOpError() << "requires number of elements in begin and size" + "are equal to input rank"; + } + + DenseIntElementsAttr begin_indices; + if (matchPattern(op.begin(), m_Constant(&begin_indices))) { + DenseIntElementsAttr slice_sizes; + bool constant_slice_sizes = + matchPattern(op.size(), m_Constant(&slice_sizes)); + int dim = 0; + for (const APInt &raw_begin_index : begin_indices.getValues()) { + int64_t begin_index = raw_begin_index.getSExtValue(); + int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1; + int64_t slice_size = constant_slice_sizes + ? slice_sizes.getValue(dim).getSExtValue() + : 0; + if (slice_size == -1 && input_size != -1) { + slice_size = input_size - begin_index; + } + if (begin_index < 0 || + (input_size != -1 && begin_index + slice_size > input_size)) { + return op.emitOpError() + << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di"; + } + ++dim; + } + } else if (input_ty) { + // If the inputs are ranked, we can do a few more sanity checks. + DenseIntElementsAttr slice_sizes; + if (matchPattern(op.size(), m_Constant(&slice_sizes))) { + auto input_shape = input_ty.getShape(); + for (int64_t i = 0; i < input_ty.getRank(); ++i) { + int64_t slice_size = slice_sizes.getValue(i).getInt(); + int64_t input_size = input_shape[i]; + if (slice_size != -1 && input_size != -1 && slice_size > input_size) { + return op.emitOpError() << "requires size[i] <= Di, even if begin[i] " + "is unknown at compile time"; + } + } + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// SoftmaxOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SoftmaxOp op) { + if (!HasRankAtLeast(op.logits(), 1)) { + return op.emitOpError("requires operand to have rank at least 1"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// SoftmaxCrossEntropyWithLogitsOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// * Input types are broadcast compatible and the broadcasted type has rank two. 
+// +static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { + auto broadcasted_ty = OpTrait::util::getBroadcastedType( + op.features().getType(), op.labels().getType()) + .dyn_cast_or_null(); + if (!broadcasted_ty || + (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2)) + return op.emitOpError( + "requires features and labels to be broadcast compatible to rank two"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// SparseSoftmaxCrossEntropyWithLogitsOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) { + if (!IsOfRankOrUnranked(op.features(), 2)) { + return op.emitOpError("requires features operand of rank two"); + } + if (!IsOfRankOrUnranked(op.labels(), 1)) { + return op.emitOpError("requires labels operand of rank one"); + } + auto features_ty = op.features().getType().dyn_cast(); + auto labels_ty = op.labels().getType().dyn_cast(); + if (features_ty && labels_ty) { + int64_t features_batches = features_ty.getDimSize(0); + int64_t labels_batches = labels_ty.getDimSize(0); + if (!ShapedType::isDynamic(features_batches) && + !ShapedType::isDynamic(labels_batches) && + features_batches != labels_batches) + return op.emitOpError( + "requires features and labels with matching first dimension"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// SplitOp +//===----------------------------------------------------------------------===// + +// Verifies the input and split dimension operands for tf.Split/tf.SplitV. +// Writes the split dimension's index (adjusted with input rank) via `dim_index` +// if it's a constant. +template +LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { + *dim_index = llvm::None; + + Value split_dim = op.split_dim(); + if (auto split_dim_type = split_dim.getType().dyn_cast()) + if (split_dim_type.getRank() != 0) + return op.emitOpError( + "split dimension should be an integer scalar tensor"); + + // We can perform further verification if the input tensor to be split has + // known rank and the split dimension tensor is a constant. 
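For reference, the split-dimension handling in VerifySplitInputAndSplitDim boils down to the standalone check below (hypothetical helper, not part of the patch): a constant split_dim must lie in [-rank, rank), and negative values wrap to the equivalent non-negative index.

    #include <cstdint>
    #include <optional>

    std::optional<int64_t> NormalizeSplitDim(int64_t split_dim,
                                             int64_t input_rank) {
      // Out of range: the verifier emits "split dimension must be in range".
      if (split_dim < -input_rank || split_dim >= input_rank)
        return std::nullopt;
      if (split_dim < 0) split_dim += input_rank;
      return split_dim;
    }

    // Example: NormalizeSplitDim(-1, 4) yields 3; NormalizeSplitDim(4, 4) fails.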
+ + auto input_type = op.value().getType().template dyn_cast(); + if (!input_type) return success(); + + int64_t input_rank = input_type.getRank(); + if (input_rank == 0) + return op.emitOpError("cannot split scalar input tensor"); + + DenseIntElementsAttr split_dim_attr; + if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success(); + + int64_t index = (*split_dim_attr.begin()).getSExtValue(); + + if (index + input_rank < 0 || index >= input_rank) { + return op.emitOpError("split dimension must be in range [-") + << input_rank << ", " << input_rank << ")"; + } + + if (index < 0) index += input_rank; + *dim_index = index; + + return success(); +} + +static LogicalResult Verify(SplitOp op) { + Optional dim_index; + if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); + if (!dim_index) return success(); + + int64_t input_dim_size = + op.value().getType().cast().getDimSize(*dim_index); + if (input_dim_size == ShapedType::kDynamicSize) return success(); + + if (input_dim_size % op.getNumResults() != 0) + return op.emitOpError("dimension #") + << *dim_index << " not divisible by the number of result tensors"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SplitVOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SplitVOp op) { + auto split_sizes_type = + op.size_splits().getType().dyn_cast(); + if (!split_sizes_type) return success(); + + if (split_sizes_type.getRank() != 1 || + split_sizes_type.getDimSize(0) != op.getNumResults()) + return op.emitOpError("split sizes should be a 1D tensor of ") + << op.getNumResults() << " elements"; + + Optional dim_index = 0; + if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); + if (!dim_index) return success(); + + int64_t input_dim_size = + op.value().getType().cast().getDimSize(*dim_index); + if (input_dim_size == ShapedType::kDynamicSize) return success(); + + // If split sizes come from a constant, they must sum to the dimension size + // along split_dim, and we can have no more than one dynamic dimension. + DenseIntElementsAttr split_sizes_attr; + if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr))) + return success(); + + int64_t total_dim_size = 0; // Total dimension size assigned to splits + llvm::Optional dynamic_dim_index; + + SmallVector split_sizes; + split_sizes.reserve( + split_sizes_attr.getType().cast().getNumElements()); + + for (auto dim : llvm::enumerate(split_sizes_attr)) { + int64_t dim_val = dim.value().getSExtValue(); + split_sizes.push_back(dim_val); + if (dim_val == ShapedType::kDynamicSize) { + // We cannot have more than one dynamic dimension. 
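The constant size_splits check being assembled here can be summarized standalone (hypothetical helper; -1 marks the single allowed dynamic entry, and the error strings are paraphrased):

    #include <cstdint>
    #include <string>
    #include <vector>

    std::string CheckSplitSizes(const std::vector<int64_t> &split_sizes,
                                int64_t dim_size) {
      int64_t total = 0;
      int dynamic_count = 0;
      for (int64_t size : split_sizes) {
        if (size == -1) {
          if (++dynamic_count > 1)
            return "cannot have more than one dynamic dimension in split sizes";
        } else {
          total += size;
        }
      }
      if (dynamic_count == 0 && total != dim_size)
        return "split sizes must sum to the dimension size along split dimension";
      if (dynamic_count == 1 && total > dim_size)
        return "split sizes must not exceed the dimension size along split dimension";
      return "";
    }

    // Example: CheckSplitSizes({2, -1, 3}, 10) passes; the -1 entry gets size 5.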
+ if (dynamic_dim_index) + return op.emitOpError( + "cannot have more than one dynamic dimension in split sizes"); + dynamic_dim_index = dim.index(); + } else { + total_dim_size += dim_val; + } + } + + if (!dynamic_dim_index && total_dim_size != input_dim_size) + return op.emitOpError( + "split sizes must sum up to the dimension size along split " + "dimension, found ") + << total_dim_size << " vs " << input_dim_size; + + if (dynamic_dim_index && total_dim_size > input_dim_size) + return op.emitOpError( + "split sizes must sum up to be less than or equal to the " + "dimension size along split dimension, found ") + << total_dim_size << " vs " << input_dim_size; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SquareOp +//===----------------------------------------------------------------------===// + +void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// SubOp +//===----------------------------------------------------------------------===// + +void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult SubOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + +//===----------------------------------------------------------------------===// +// SumOp +//===----------------------------------------------------------------------===// + +void SumOp::build(OpBuilder &builder, OperationState &result, Value input, + Value reduction_indices, BoolAttr keep_dims) { + Type out_ty = + InferReductionOpType(input, reduction_indices, keep_dims, &builder); + build(builder, result, out_ty, input, reduction_indices, keep_dims); +} + +//===----------------------------------------------------------------------===// +// StridedSliceOp +//===----------------------------------------------------------------------===// + +// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to +// tf.SliceOp if both of the following are true: +// - All strides have a known value equal to 1 +// - No masks are set (or masks can be applied by transforming the inputs to +// Slice) + +// Verifies that, +// +// - begin, end and strides operands are 1D and they have the same number of +// elements. Here, the number of elements should be less than 32 to support +// 32-bit mask attributes. +// - None of the strides values are zero. +// - Ellipsis mask can have at most one bit set. + +template +static LogicalResult VerifyStridedSliceBase(OpTy op) { + // Expected size for operands begin, end and strides vector operands. + int64_t expected_size = -1; + + for (Value val : {op.begin(), op.end(), op.strides()}) { + auto operand_ty = val.getType().dyn_cast(); + if (!operand_ty || !operand_ty.hasStaticShape()) { + // TensorFlow constant ops may have non-static shape because the shape is + // not propagated during constant folding. If the defining op for this + // operand is a constant op, use the constant op's attribute to get the + // actual shape. 
+ DenseIntElementsAttr attr; + if (!matchPattern(val, m_Constant(&attr))) continue; + operand_ty = attr.getType(); + } + + if (operand_ty.getRank() != 1) + return op.emitOpError() + << "requires begin, end and strides to be 1D tensors"; + + int64_t length = operand_ty.getDimSize(0); + if (length == -1) continue; + + if (expected_size == -1) { + // This op uses 32-bit masks. + if (length >= 32) + return op.emitOpError( + "requires begin, end and strides operands with less than 32 " + "elements"); + + expected_size = length; + } else if (length != expected_size) { + return op.emitOpError() << "requires begin, end and strides to have the " + "same number of elements"; + } + } + + // If strides are constants, verify that none of the element is zero. + DenseIntElementsAttr strides; + if (matchPattern(op.strides(), m_Constant(&strides))) { + if (llvm::is_contained(strides.getValues(), 0)) + return op.emitOpError("requires non-zero strides"); + } + + // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there + // exists only no more than one ellipsis. + uint32_t ellipsis_mask = op.ellipsis_mask().getZExtValue(); + if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask)) + return op.emitOpError("cannot have multiple ellipses"); + + return success(); +} + +// Clamps the given `val`: returns `low` if `val` is less than `low`; returns +// `high` if `high` is less than `val`; otherwise returns `val`. +template +constexpr const T &Clamp(const T &val, const T &low, const T &high) { + assert(!(high < low)); + return (val < low) ? low : (high < val) ? high : val; +} + +// Checks if the `index` bit of `val` is set. +template +constexpr bool IsSet(const T &val, unsigned index) { + return (val & (1 << index)) != 0; +} + +// Sets the `index` bit of `val`. +template +constexpr void Set(T &val, unsigned index) { + val |= (1 << index); +} + +// Unset the `index` bit of `val`. +template +constexpr void Unset(T &val, unsigned index) { + val &= ~(1 << index); +} + +// Copy the `src_index` bit of `src` to `dst_index` bit of `dst`. +template +constexpr void CopyBit(const T &src, unsigned src_index, T &dst, + unsigned dst_index) { + if (IsSet(src, src_index)) + Set(dst, dst_index); + else + Unset(dst, dst_index); +} + +// The sparse spec of strided slice does not correspond to the number of +// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2, +// 4, 8) would have dims = 2. +struct SparseSliceSpec { + int64_t dims; + int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask; + const ArrayRef &begin; + const ArrayRef &end; + const ArrayRef &strides; +}; + +// The dense spec of strided slice is the canonicalized version of sparse spec. +// The number of dimensions of dense spec correspond to the number of dimensions +// in operand tensor. +struct DenseSliceSpec { + int64_t dims; + int32_t begin_mask, end_mask, shrink_axis_mask; + SmallVectorImpl &begin; + SmallVectorImpl &end; + SmallVectorImpl &strides; +}; + +// Make a sparse spec into a dense index spec. +// The sparse spec does not correspond to the number of dimensions +// Make a dense spec that corresponds to the number of dimensions +// +// For example suppose foo[...,3:, 2] on foo.shape=(2,2,3,4) then +// we need to produce the missing begin_mask, end_mask for the first two +// dimensions i.e. foo[:, :, 3:, 2]. 
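A worked example of the expansion bound computed in BuildDenseSliceSpec below, using the foo[..., 3:, 2] case from the comment above (standalone sketch, not part of the patch):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int main() {
      int64_t dense_dims = 4;    // rank of the operand (2, 2, 3, 4)
      int64_t sparse_dims = 3;   // entries in the sparse spec: ..., 3:, 2
      int64_t sparse_index = 0;  // the ellipsis is the first sparse entry
      int64_t num_new_axis_after_ellipsis = 0;

      int64_t next_index =
          std::min(dense_dims - (sparse_dims - sparse_index) + 1 +
                       num_new_axis_after_ellipsis,
                   dense_dims);
      // Dense dimensions [0, next_index) come from the ellipsis: they get
      // stride 1 with begin/end masks set, i.e. foo[:, :, 3:, 2].
      assert(next_index == 2);
      return 0;
    }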
+static void BuildDenseSliceSpec(const SparseSliceSpec &sparse, + DenseSliceSpec *dense) { + // Build expanded dense begin, end, strides, begin_mask, end_mask, and + // shrink_axis_mask. + dense->begin.resize(dense->dims); + dense->end.resize(dense->dims); + dense->strides.resize(dense->dims); + dense->begin_mask = 0; + dense->end_mask = 0; + dense->shrink_axis_mask = 0; + + // Count number of new_axis after ellipsis. This helps in calculating the + // number of dimensions ellipsis represents in the sparse spec. + bool ellipsis_seen = false; + int num_new_axis_after_ellipsis = 0; + for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) { + if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index)) + num_new_axis_after_ellipsis++; + if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true; + } + + int dense_index = 0; + for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) { + if (IsSet(sparse.new_axis_mask, sparse_index)) continue; + if (IsSet(sparse.ellipsis_mask, sparse_index)) { + auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) + + 1 + num_new_axis_after_ellipsis, + dense->dims); + // Expand ellipsis into the appropriate dense indices. From current index + // until next_index, all dimensions would have begin and end masks set and + // stride 1, i.e., get all elements in those dimensions. + for (; dense_index < next_index; ++dense_index) { + dense->begin[dense_index] = dense->end[dense_index] = 0; + dense->strides[dense_index] = 1; + Set(dense->begin_mask, dense_index); + Set(dense->end_mask, dense_index); + } + continue; + } + assert(dense_index < dense->dims); + // Copy over the sparse indices to dense indices if ellipsis_mask and + // new_axis_mask are not set. + dense->begin[dense_index] = sparse.begin[sparse_index]; + dense->end[dense_index] = sparse.end[sparse_index]; + dense->strides[dense_index] = sparse.strides[sparse_index]; + CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index); + CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index); + CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask, + dense_index); + dense_index++; + } +} + +// For the given `input_shape`, calculates the sliced shape using the given +// `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and +// `shrink_axis_mask` masks. Updates the result back to `input_shape`. If +// `shrink_axis_mask` is not zero, this function will not drop the corresponding +// dimensions in `input_shape`; it will turn them into 1s. At the same time, +// canonicalizes `begin`, `end`, and `strides. The calculation follows +// tf.StridedSlice op semantics. +static void CalculateSlicedShapeFromDenseIndices( + MutableArrayRef input_shape, int32_t begin_mask, int32_t end_mask, + int32_t shrink_axis_mask, MutableArrayRef begin, + MutableArrayRef end, MutableArrayRef stride) { + assert(input_shape.size() <= 32); // Only 32-bit masks are supported. + + // Make sure ranges' ranks are consistent with the input. 
+ assert(input_shape.size() == begin.size()); + assert(input_shape.size() == end.size()); + assert(input_shape.size() == stride.size()); + + for (int i = 0, e = input_shape.size(); i < e; ++i) { + if (ShapedType::isDynamic(input_shape[i])) continue; + + int64_t dim_i = input_shape[i]; + int64_t begin_i = begin[i]; + int64_t end_i = end[i]; + int64_t stride_i = stride[i]; + + // [0]: mask for begin, [1]: mask for end + int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)}; + // [0]: bound for begin, [1]: bound for end + int64_t bounds[] = {stride_i > 0 ? 0 : -1, + stride_i > 0 ? dim_i : dim_i - 1}; + + // Canonicalizes the given range `point` (begin/end) according to the + // current dimension. `c` means case: 0 for begin, 1 for end. + auto canonicalize = [&](int64_t point, int c) { + if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1]; + + // Add dim as offset to negative range point. + point = point < 0 ? dim_i + point : point; + return Clamp(point, bounds[0], bounds[1]); + }; + + begin_i = canonicalize(begin_i, 0); + end_i = canonicalize(end_i, 1); + + int64_t interval_len = end_i - begin_i; + int64_t size_i = 0; + // If internal length is zero or has different sign from stride, it's a + // degenerated case: we are slicing nothing. Otherwise, calculate the sliced + // size. + if (interval_len != 0 && (interval_len < 0) == (stride_i < 0)) + size_i = (interval_len / stride_i) + (interval_len % stride_i != 0); + + begin[i] = begin_i; + if (IsSet(shrink_axis_mask, i)) { + // Shrink this dimension. It means we only take the element at begin_i. + input_shape[i] = 1; + end[i] = begin_i + 1; + stride[i] = 1; + } else { + input_shape[i] = size_i; + end[i] = end_i; + stride[i] = stride_i; + } + } +} + +// For the given `input_shape`, calculates the sliced shape using the given +// `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`, +// `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks. +// Updates the result back to `input_shape`. +static void CalculateSlicedShapeFromSparseIndices( + MutableArrayRef input_shape, ArrayRef sparse_begin, + ArrayRef sparse_end, ArrayRef sparse_strides, + int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, + int32_t new_axis_mask, int32_t shrink_axis_mask, + SmallVectorImpl *begin, SmallVectorImpl *end, + SmallVectorImpl *stride) { + int64_t num_sparse_indices = sparse_begin.size(); + SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask, + ellipsis_mask, new_axis_mask, shrink_axis_mask, + sparse_begin, sparse_end, sparse_strides}; + + // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is + // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields + // a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...]. 
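The per-dimension size computation in CalculateSlicedShapeFromDenseIndices above can be exercised in isolation. The standalone sketch below ignores the begin/end/shrink-axis masks and assumes a non-zero stride (the helper name is invented):

    #include <cassert>
    #include <cstdint>

    int64_t SlicedDimSize(int64_t dim, int64_t begin, int64_t end,
                          int64_t stride) {
      auto clamp = [](int64_t v, int64_t lo, int64_t hi) {
        return v < lo ? lo : (v > hi ? hi : v);
      };
      int64_t lo = stride > 0 ? 0 : -1;
      int64_t hi = stride > 0 ? dim : dim - 1;
      begin = clamp(begin < 0 ? begin + dim : begin, lo, hi);
      end = clamp(end < 0 ? end + dim : end, lo, hi);
      int64_t len = end - begin;
      // Zero-length or wrong-sign interval means an empty (degenerate) slice.
      if (len == 0 || (len < 0) != (stride < 0)) return 0;
      return len / stride + (len % stride != 0 ? 1 : 0);
    }

    int main() {
      assert(SlicedDimSize(/*dim=*/8, /*begin=*/2, /*end=*/4, /*stride=*/1) == 2);
      assert(SlicedDimSize(8, 0, 8, 3) == 3);     // elements 0, 3, 6
      assert(SlicedDimSize(8, -1, -9, -1) == 8);  // reversed full range
      return 0;
    }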
+ if (sparse.ellipsis_mask == 0) { + Set(sparse.ellipsis_mask, sparse.dims); + sparse.dims++; + } + + int64_t dims = input_shape.size(); + DenseSliceSpec dense = {dims, + /*begin_mask = */ 0, + /*end_mask = */ 0, + /*shrink_axis_mask = */ 0, + *begin, + *end, + *stride}; + + BuildDenseSliceSpec(sparse, &dense); + CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask, + dense.end_mask, dense.shrink_axis_mask, + *begin, *end, *stride); +} + +bool StridedSliceOp::GetSlicedBoundRanges( + SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, + SmallVectorImpl *slice_stride) { + // TODO(hinsu): Support lowering for ops with dynamic begin and end values + // when it is possible to derive indices based on mask attributes. + DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; + if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(end(), m_Constant(&sparse_end_attr)) || + !matchPattern(strides(), m_Constant(&sparse_strides_attr))) + return false; + + auto input_ty = this->input().getType().dyn_cast(); + if (!input_ty || !input_ty.hasStaticShape()) return false; + auto input_shape = llvm::to_vector<4>(input_ty.getShape()); + + SmallVector sparse_begin, sparse_end, sparse_strides; + + for (const APInt &index : sparse_begin_attr) + sparse_begin.push_back(index.getSExtValue()); + for (const APInt &index : sparse_end_attr) + sparse_end.push_back(index.getSExtValue()); + for (const APInt &stride : sparse_strides_attr) + sparse_strides.push_back(stride.getSExtValue()); + + CalculateSlicedShapeFromSparseIndices( + input_shape, sparse_begin, sparse_end, sparse_strides, + begin_mask().getZExtValue(), end_mask().getZExtValue(), + ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), + shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); + return true; +} + +//===----------------------------------------------------------------------===// +// StridedSliceGradOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(StridedSliceGradOp op) { + auto shape_type = op.shape().getType().dyn_cast(); + if (shape_type && shape_type.getRank() != 1) + return op.emitOpError("'shape' operand must be 1D tensor, but got ") + << shape_type.getRank() << "D tensor"; + + if (failed(VerifyStridedSliceBase(op))) return failure(); + + // TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with + // the sliced type from StridedSlice. 
+ + return success(); +} + +bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( + SmallVectorImpl *input_shape, + SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, + SmallVectorImpl *slice_stride) { + DenseIntElementsAttr shape_attr; + DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; + if (!matchPattern(shape(), m_Constant(&shape_attr)) || + !matchPattern(begin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(end(), m_Constant(&sparse_end_attr)) || + !matchPattern(strides(), m_Constant(&sparse_strides_attr))) + return false; + + int rank = std::distance(shape_attr.begin(), shape_attr.end()); + + input_shape->clear(); + input_shape->reserve(rank); + for (const APInt &dim : shape_attr) + input_shape->push_back(dim.getSExtValue()); + + SmallVector sparse_begin, sparse_end, sparse_strides; + + for (const APInt &index : sparse_begin_attr) + sparse_begin.push_back(index.getSExtValue()); + for (const APInt &index : sparse_end_attr) + sparse_end.push_back(index.getSExtValue()); + for (const APInt &stride : sparse_strides_attr) + sparse_strides.push_back(stride.getSExtValue()); + + CalculateSlicedShapeFromSparseIndices( + *input_shape, sparse_begin, sparse_end, sparse_strides, + begin_mask().getZExtValue(), end_mask().getZExtValue(), + ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), + shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); + return true; +} + +//===----------------------------------------------------------------------===// +// TensorListReserveOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TensorListReserveOp op) { + if (!IsOfRankOrUnranked(op.element_shape(), 0) && + !IsOfRankOrUnranked(op.element_shape(), 1)) { + return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); + } + + if (!IsOfRankOrUnranked(op.num_elements(), 0)) { + return op.emitOpError("requires num_elements operand to be 0D tensor"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// TensorListElementShapeOp +//===----------------------------------------------------------------------===// + +OpFoldResult TensorListElementShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + auto variant_type = + getElementTypeOrSelf(getOperand().getType()).cast(); + if (variant_type.getSubtypes().empty()) return {}; + return ConvertShapeToAttr(variant_type.getSubtypes()[0], width); +} + +//===----------------------------------------------------------------------===// +// TensorListStackOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TensorListStackOp op) { + if (!IsOfRankOrUnranked(op.element_shape(), 0) && + !IsOfRankOrUnranked(op.element_shape(), 1)) { + return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// TensorScatterUpdateOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TensorScatterUpdateOp op) { + if (!HasRankAtLeast(op.tensor(), 1)) + return op.emitOpError( + "requires tensor operand to have at least 1 dimension"); + if (!HasRankAtLeast(op.indices(), 1)) + return op.emitOpError( + "requires indices operand to have at least 1 dimension"); + if (!HasRankAtLeast(op.updates(), 
1)) + return op.emitOpError( + "requires updates operand to have at least 1 dimension"); + + auto tensor_ty = op.tensor().getType().dyn_cast(); + auto indices_ty = op.indices().getType().dyn_cast(); + if (!tensor_ty || !indices_ty) return success(); + + int64_t num_index_dims = indices_ty.getShape().back(); + if (ShapedType::isDynamic(num_index_dims)) return success(); + + if (num_index_dims > tensor_ty.getRank()) + return op.emitOpError( + "requires tensor operand with rank greater than or equal to the " + "indices operand's last dimensions"); + return success(); +} + +//===----------------------------------------------------------------------===// +// TopKV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TopKV2Op op) { + if (!HasRankAtLeast(op.input(), 1)) + return op.emitOpError( + "requires input operand to have at least 1 dimension"); + + if (!IsOfRankOrUnranked(op.k(), 0)) + return op.emitOpError("requires k operand to be 0D tensor"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ToBoolOp +//===----------------------------------------------------------------------===// + +namespace { +// If the input to ToBoolOp is a `tensor`, then the ToBoolOp is an identity +// function and can be removed. +class ToBoolOfZeroDBoolTensor : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ToBoolOp op, + PatternRewriter &rewriter) const override { + if (auto type = op.getOperand().getType().dyn_cast()) { + if (type.getRank() == 0 && type.getElementType().isInteger(1)) { + rewriter.replaceOp(op, op.getOperand()); + return success(); + } + } + return failure(); + } +}; +} // namespace + +void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TransposeOp op) { + // TODO(hinsu): Verify using a custom verifier that, + // * Transpose permutation is 1-D of size equal to the rank of the first + // input, if the shapes are partially known. Requires use of a more + // restrictive type than TF_Tensor. + // * Result shape dimensions are possible based on the input shape. + return success(); +} + +// TODO(jpienaar): perm could be optional too. +void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, + Value perm) { + auto x_type = x.getType().cast(); + // If value is unranked, then so is results. + if (!x_type.hasRank()) + return TransposeOp::build(builder, result, + UnrankedTensorType::get(x_type.getElementType()), + x, perm); + + // TODO(jpienaar): Handle unknown perm case. + + // TODO(jpienaar): Extract utility function. 
+ auto etype = x_type.cast().getElementType(); + DenseIntElementsAttr attr_shape; + if (matchPattern(perm, m_Constant(&attr_shape))) { + llvm::SmallVector const_shape; + if (attr_shape.isSplat()) { + const_shape.assign( + attr_shape.getNumElements(), + x_type.getDimSize((*attr_shape.begin()).getSExtValue())); + } else { + const_shape.reserve(attr_shape.getNumElements()); + for (const auto &dim : attr_shape) + const_shape.push_back(x_type.getDimSize(dim.getSExtValue())); + } + return TransposeOp::build( + builder, result, RankedTensorType::get(const_shape, etype), x, perm); + } + return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x, + perm); +} + +namespace { + +OpFoldResult FoldIdentityTranspose(TransposeOp op) { + auto const_perm = dyn_cast_or_null(op.perm().getDefiningOp()); + if (!const_perm) return {}; + + auto const_value = const_perm.value(); + const auto elements = const_value.getValues(); + + for (auto it : llvm::enumerate(elements)) { + if (it.index() != it.value()) return {}; + } + + // TODO(jpienaar): Remove if/when we handle this more generally. + if (op.getType() != op.x().getType()) { + // If the types don't match then only fold if all the operands are in the TF + // dialect. + for (auto user : op.getOperation()->getUsers()) + if (user->getDialect() != op.getDialect()) return {}; + } + + return op.x(); +} + +OpFoldResult FoldCancellableTranspose(TransposeOp op) { + // Operand is a TransposeOp. + auto transpose = dyn_cast_or_null(op.x().getDefiningOp()); + if (!transpose) return {}; + + // Permutations defined by constant operations. + auto perm0 = dyn_cast_or_null(op.perm().getDefiningOp()); + auto perm1 = dyn_cast_or_null(transpose.perm().getDefiningOp()); + if (!perm0 || !perm1) return {}; + + // With permutation indices that cancel each other + auto perm0_value = perm0.value().cast(); + auto perm1_value = perm1.value().cast(); + if (!AreCancellablePermutations(perm0_value, perm1_value)) return {}; + + return transpose.x(); +} + +} // namespace + +OpFoldResult TransposeOp::fold(ArrayRef operands) { + if (auto folded = FoldIdentityTranspose(*this)) return folded; + if (auto folded = FoldCancellableTranspose(*this)) return folded; + return {}; +} + +//===----------------------------------------------------------------------===// +// TruncateDivOp +//===----------------------------------------------------------------------===// + +void TruncateDivOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// UnpackOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(UnpackOp op) { + auto value_type = op.value().getType().dyn_cast(); + if (!value_type) return success(); + + int64_t value_rank = value_type.getRank(); + int64_t axis = op.axis().getSExtValue(); + if (axis < -value_rank || axis >= value_rank) + return op.emitOpError("axis attribute must be in the range of [-") + << value_rank << ", " << value_rank << ')'; + + axis = GetDimForAxis(axis, value_rank); + int64_t dim_size = value_type.getDimSize(axis); + if (ShapedType::isDynamic(dim_size)) return success(); + + if (dim_size != op.getNumResults()) + return op.emitOpError("result count must be equal to ") << dim_size; + + return success(); +} + +//===----------------------------------------------------------------------===// +// Unsorted segment reduction ops 
+//===----------------------------------------------------------------------===// + +template +static LogicalResult VerifyUnsortedSegmentReduction(Op op) { + if (!HasRankAtMost(op.num_segments(), 0)) + return op.emitOpError("number of segments should be a 0-D tensor"); + + auto data_type = op.data().getType().template dyn_cast(); + auto segment_ids_type = + op.segment_ids().getType().template dyn_cast(); + if (data_type && segment_ids_type) { + if (data_type.getRank() < segment_ids_type.getRank()) + return op.emitOpError( + "requires segment ids rank to be less than or equal to data's rank"); + + int index = 0; + for (auto shape_pair : + llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) { + int64_t segment_id_dim = std::get<0>(shape_pair); + int64_t data_dim = std::get<1>(shape_pair); + if (!ShapedType::isDynamic(segment_id_dim) && + !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim) + return op.emitOpError( + "requires segment ids shape to be a prefix of data shape, " + "but dimension #") + << index << " differs: " << segment_id_dim << " vs. " + << data_dim; + ++index; + } + } + + DenseIntElementsAttr num_segments_attr; + if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) { + int64_t num_segments = (*num_segments_attr.begin()).getSExtValue(); + if (num_segments < 0) + return op.emitOpError("num of segments cannot be negative"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// VarIsInitializedOp +//===----------------------------------------------------------------------===// + +namespace { + +/// Erase VarIsInitializedOp operations with no uses. This op has side effect on +/// resources (read-only), but can still be deleted if it has zero uses. +struct EraseDeadVarIsInitializedOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(VarIsInitializedOp op, + PatternRewriter &rewriter) const override { + if (!op.use_empty()) return failure(); + rewriter.eraseOp(op); + return success(); + } +}; +} // end anonymous namespace. 
+ +void VarIsInitializedOp::getCanonicalizationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + patterns.insert(context); +} + +//===----------------------------------------------------------------------===// +// VariableShapeOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(VariableShapeOp op) { + auto input_type = op.input().getType().cast(); + if (input_type.hasStaticShape() && input_type.getNumElements() != 1) + return op.emitOpError("requires input to have one resource"); + + auto resource_type = input_type.getElementType().cast(); + auto subtypes = resource_type.getSubtypes(); + switch (subtypes.size()) { + case 1: + return VerifyShapeOperandAndResult( + op, resource_type.getSubtypes().front(), op.getType()); + case 0: + return VerifyShapeOperandAndResult(op, Type(), op.getType()); + default: + return op.emitOpError( + "requires resource input type to have at most 1 subtype"); + } +} + +OpFoldResult VariableShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + auto resource_type = + getElementTypeOrSelf(getOperand().getType()).cast(); + if (resource_type.getSubtypes().empty()) return {}; + return ConvertShapeToAttr(resource_type.getSubtypes()[0], width); +} + +//===----------------------------------------------------------------------===// +// WhileOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(WhileOp op) { + auto module = op.getParentOfType(); + auto cond_fn = module.lookupSymbol(op.cond()); + auto body_fn = module.lookupSymbol(op.body()); + if (!cond_fn) { + return op.emitOpError("cond refers to an undefined function : ") + << op.cond(); + } + if (!body_fn) { + return op.emitOpError("body refers to an undefined function : ") + << op.body(); + } + + auto cond_fn_type = cond_fn.getType(); + auto body_fn_type = body_fn.getType(); + + // Verify that the cond function has exactly one result. + if (cond_fn_type.getNumResults() != 1) + return op.emitOpError("requires cond function to have exactly one result"); + + SmallVector operands(op.getOperandTypes()); + + // Collect all the type lists for the op so that different pairs of type lists + // can be compared for the compatibility. + constexpr int kNumTypeLists = 5; + const std::array>, kNumTypeLists> + type_lists = {{ + {"operand", operands}, + {"body function result", body_fn_type.getResults()}, + {"result", op.getResultTypes()}, + {"cond function input", cond_fn_type.getInputs()}, + {"body function input", body_fn_type.getInputs()}, + }}; + + // A pair of type lists should be cast compatible with each other if one is + // converted to the another for a function call or assignment or there is a + // common source of inputs for both. Therefore, the While op requires the + // following pairs of type lists to be cast compatible for the tensor_cast + // operation: + // + // * Operands and cond inputs to call the cond function before the + // first iteration. + // * Operands and body inputs to call the body function for the first + // iteration if the cond functions returns True or equivalent result. + // * Operands and results to assign cond function arguments to op results if + // the cond function returns False or equivalent result. + // * All three pairs using cond inputs, body inputs and results as operand is + // a common source for all three. 
+ // * Body result and cond inputs to call the cond function for the subsequent + // iterations. Similarly, Body result should be compatible with body inputs + // and op results. + // + // Note that the operands and body results need not be compatible as they are + // never converted from one to the another nor there is a common source + // tensors. Compatibility requirement is not transitive. + + for (int i = 0; i < kNumTypeLists; ++i) { + // Skip the first pair as the While op operands and body function results + // does not need to be compatible with each other. + for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) { + auto &a = type_lists[i]; + auto &b = type_lists[j]; + + int a_size = a.second.size(); + if (a_size != b.second.size()) + return op.emitOpError( + llvm::formatv("requires the number of {0}s to be equal to the " + "number of {1}s. Found {2} and {3}, respectively", + a.first, b.first, a_size, b.second.size())); + + for (int idx = 0; idx < a_size; ++idx) { + auto a_type = a.second[idx]; + auto b_type = b.second[idx]; + + if (!AreCastCompatible({a_type, b_type})) + return op.emitError(llvm::formatv( + "{0} type {1} is incompatible with {2} type {3} at index {4}", + a.first, a_type, b.first, b_type, idx)); + } + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// WhileRegionOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(WhileRegionOp op) { + // Verify that the condition generates a single tensor result. + YieldOp yield = cast(op.cond().front().getTerminator()); + if (yield.getNumOperands() != 1) + return op.emitOpError() + << "condition should have a single tensor result"; + + auto cond_type = yield.getOperand(0).getType().dyn_cast(); + if (!cond_type || !cond_type.getShape().equals({}) || + !cond_type.getElementType().isInteger(/*width=*/1)) + return op.emitOpError() + << "condition should have a single tensor result"; + + // The body result types should match while op result types. + if (failed(VerifyRegionResults(op, op.body(), "body"))) return failure(); + + // Both condition and body should have same number and type of operands as + // the WhileRegion inputs. 
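To make the pairwise comparison above concrete, this standalone snippet prints exactly the nine list pairs the WhileOp verifier compares (indices and names mirror the type_lists array; operands vs. body function results is skipped on purpose):

    #include <algorithm>
    #include <cstdio>

    int main() {
      const char *names[] = {"operand", "body function result", "result",
                             "cond function input", "body function input"};
      for (int i = 0; i < 5; ++i)
        for (int j = std::max(2, i + 1); j < 5; ++j)
          std::printf("check cast compatibility: %s <-> %s\n", names[i],
                      names[j]);
      return 0;
    }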
+ const int num_inputs = op.getNumOperands(); + auto block_inputs_match_op_inputs = [&](Region ®ion, + StringRef name) -> LogicalResult { + Block &block = region.front(); + if (block.getNumArguments() != num_inputs) + return op.emitOpError() + << name << " should have same number of inputs (" << num_inputs + << ") as " << WhileRegionOp::getOperationName() << " but has " + << block.getNumArguments() << " inputs"; + + for (auto types_idx : llvm::enumerate( + llvm::zip(op.getOperandTypes(), block.getArgumentTypes()))) { + auto op_input_type = std::get<0>(types_idx.value()); + auto block_input_type = std::get<1>(types_idx.value()); + if (!AreCastCompatible({block_input_type, op_input_type})) + return op.emitOpError(llvm::formatv( + "{0} input type {1} is incompatible with {2} " + "input type {3} at index {4}", + name, block_input_type, WhileRegionOp::getOperationName(), + op_input_type, types_idx.index())); + } + return success(); + }; + + if (failed(block_inputs_match_op_inputs(op.cond(), "condition")) || + failed(block_inputs_match_op_inputs(op.body(), "body"))) + return failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// WhileRegionOp LoopLikeOpInterface +//===----------------------------------------------------------------------===// + +Region &WhileRegionOp::getLoopBody() { return body(); } + +bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) { + // If the Op defining the value exists and the defining op is outside the + // scope of this WhileRegion, then we can infer that its defined outside. + // The defining Op is outside the scope of this WhileRegion if this + // WhileRegionOp is not an ancestor of the defining op in the parent chain. + Operation *def_op = value.getDefiningOp(); + return def_op && !getOperation()->isAncestor(def_op); +} + +LogicalResult WhileRegionOp::moveOutOfLoop( + llvm::ArrayRef ops) { + // Move the hoisted value to just before the while. + Operation *while_op = this->getOperation(); + for (auto op : ops) op->moveBefore(while_op); + return success(); +} + +//===----------------------------------------------------------------------===// +// WhileRegionOp canonicalization +//===----------------------------------------------------------------------===// +namespace { +// Eliminate values that pass through the WhileRegionOp body. +struct WhileRegionEliminatePassThrough + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileRegionOp while_op, + PatternRewriter &rewriter) const override { + // Replace values that simply passthrough the body with extern values. The + // block arguments of body and while match and so the corresponding cond + // argument can be easily found. + int old_num_operands = while_op.getNumOperands(); + int new_num_operands = old_num_operands; + auto &body_block = while_op.body().front(); + auto &cond_block = while_op.cond().front(); + auto &yield = *body_block.getTerminator(); + + // Bit mask indicating which operands will be removed. + SmallVector removed_operand(old_num_operands, false); + + for (int op_idx : llvm::seq(0, old_num_operands)) { + auto body_arg = body_block.getArgument(op_idx); + if (body_arg == yield.getOperand(op_idx)) { + // Replace the use of the passthrough value with the while operand + // in the body and condition regions, as well as the while output (if + // type match) + // TODO(jurahul): Use PatternRewriter API for IR modification. 
+ auto value = while_op.getOperand(op_idx); + if (body_arg.getType() == value.getType()) + body_arg.replaceAllUsesWith(value); + + auto cond_arg = cond_block.getArgument(op_idx); + if (cond_arg.getType() == value.getType()) + cond_arg.replaceAllUsesWith(value); + + auto result = while_op.getResult(op_idx); + if (result.getType() == value.getType()) + result.replaceAllUsesWith(value); + } + + // Now check if the operand is unused in both regions as well as the + // result. If so, mark it for removal. + if (body_block.getArgument(op_idx).use_empty() && + cond_block.getArgument(op_idx).use_empty() && + while_op.getResult(op_idx).use_empty()) { + removed_operand[op_idx] = true; + new_num_operands--; + } + } + + if (new_num_operands == old_num_operands) return failure(); + + // Compress the operands, region arguments, and outputs. + SmallVector new_while_operands; + SmallVector new_result_types; + new_while_operands.reserve(new_num_operands); + new_result_types.reserve(new_num_operands); + + // Build new operands and result type. + int next_idx = 0; + for (int op_idx : llvm::seq(0, old_num_operands)) { + if (removed_operand[op_idx]) continue; + new_while_operands.push_back(while_op.getOperand(op_idx)); + new_result_types.push_back(while_op.getResult(op_idx).getType()); + next_idx++; + } + + // Create the new while operation. + auto new_while_op = + rewriter.create(while_op.getLoc(), new_result_types, + new_while_operands, while_op.getAttrs()); + + // Move region bodies to the new while. + rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(), + new_while_op.cond().end()); + rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(), + new_while_op.body().end()); + + auto &new_cond_block = new_while_op.cond().front(); + auto &new_body_block = new_while_op.body().front(); + auto &new_yield = *new_body_block.getTerminator(); + + // Build a vector of new results. Also patch up the region bodies and yield. + SmallVector new_results; + next_idx = 0; + for (int op_idx : llvm::seq(0, old_num_operands)) { + if (removed_operand[op_idx]) { + new_cond_block.eraseArgument(next_idx); + new_body_block.eraseArgument(next_idx); + new_yield.eraseOperand(next_idx); + new_results.push_back(nullptr); + } else { + new_results.push_back(new_while_op.getResult(next_idx++)); + } + } + + rewriter.replaceOp(while_op, new_results); + return success(); + } +}; + +} // anonymous namespace + +void WhileRegionOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// XdivyOp +//===----------------------------------------------------------------------===// + +void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc" + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h new file mode 100644 index 00000000000..b6e9222a370 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_
+
+#include "mlir/Dialect/Traits.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/OpImplementation.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
+#include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
+#include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
+#include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
+#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
+
+namespace mlir {
+namespace TF {
+
+#define GET_OP_FWD_DEFINES
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc"
+#define GET_OP_CLASSES
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h.inc"
+
+}  // namespace TF
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc
new file mode 100644
index 00000000000..e87cc494a4a
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc
@@ -0,0 +1,87 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/Dialect/Traits.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/Identifier.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/IR/OpImplementation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Parser.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace mlir {
+namespace TF {
+
+namespace {
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
+}  // namespace
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc"
+
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h
new file mode 100644
index 00000000000..8586515edee
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h
@@ -0,0 +1,50 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_
+
+#include "mlir/Dialect/Traits.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/IR/OpImplementation.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
+#include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
+#include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
+#include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
+#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
+
+namespace mlir {
+namespace TF {
+
+#define GET_OP_FWD_DEFINES
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc"
+#define GET_OP_CLASSES
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h.inc"
+
+}  // namespace TF
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_
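
Illustration (not part of the patch): the WhileRegionOp canonicalization pattern introduced above is normally picked up by the generic -canonicalize pass, but the same patterns can be applied directly through the greedy rewrite driver. The sketch below is a rough example under that assumption; the helper name CleanupWhileRegions is hypothetical, and the header declaring applyPatternsAndFoldGreedily has moved between MLIR revisions (mlir/IR/PatternMatch.h in older trees, mlir/Transforms/GreedyPatternRewriteDriver.h in newer ones), so the include set is approximate.

// Illustrative sketch only; assumes the same MLIR revision/API used elsewhere
// in this change (FuncOp, OwningRewritePatternList).
#include "mlir/IR/Function.h"      // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TF {

// Hypothetical helper: gather the tf.WhileRegion canonicalization patterns
// (including WhileRegionEliminatePassThrough) and apply them greedily to a
// function, much as the generic -canonicalize pass would for all ops.
void CleanupWhileRegions(FuncOp func) {
  OwningRewritePatternList patterns;
  WhileRegionOp::getCanonicalizationPatterns(patterns, func.getContext());
  // Greedy driver; ignore the convergence result for this sketch.
  (void)applyPatternsAndFoldGreedily(func, patterns);
}

}  // namespace TF
}  // namespace mlir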