diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 199a9c0939c..1c740731acd 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -110,6 +110,8 @@ cc_library( deps = [ ":tensorflow_op_interfaces_inc_gen", ":tensorflow_structs", + "//tensorflow/core:framework", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], @@ -810,8 +812,8 @@ cc_library( ], deps = [ ":tensorflow", + ":tensorflow_op_interfaces", ":tensorflow_types", - "//tensorflow/core:framework", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc index 93a55cd9289..cdc9e33e368 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc @@ -24,7 +24,6 @@ limitations under the License. #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "mlir/Analysis/CallGraph.h" // from @llvm-project @@ -40,8 +39,8 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/core/framework/resource_mgr.h" namespace mlir { namespace TF { @@ -228,51 +227,16 @@ BacktrackAnalysisInfo::BacktrackAnalysisInfo( backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result)); } -namespace { - -//===----------------------------------------------------------------------===// -// ResourceAliasAnalysisInfo helper functions. -//===----------------------------------------------------------------------===// - -constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; - -// Returns if a VarHandleOp is anonymous, which means it always creates a new -// variable. -bool IsResourceHandleAnonymous(VarHandleOp handle) { - return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME; -} - -// Returns a string unique identifier for a non-anonymous VarHandleOp. -std::string GetVarHandleStringId(VarHandleOp handle) { - auto device = handle.getAttrOfType("device"); - return llvm::join( - llvm::ArrayRef{ - handle.container(), handle.shared_name(), - device ? device.getValue() : llvm::StringRef()}, - "/"); -} - -// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always -// creates a new ID; otherwise, tries to reuse the existing ID for the -// referenced variable if it exists, or creates a new one if not. -int64_t GetOrCreateIdForVarHandle(VarHandleOp handle, int64_t& next_id, - llvm::StringMap& name_id_map) { - // Always create a new ID for anonymous handle. - if (IsResourceHandleAnonymous(handle)) return next_id++; - - auto name = GetVarHandleStringId(handle); - auto emplace_res = name_id_map.try_emplace(name, next_id); - // New ID created, increment next_id. - if (emplace_res.second) ++next_id; - return emplace_res.first->second; -} - -} // namespace - //===----------------------------------------------------------------------===// // ResourceAliasAnalysisInfo //===----------------------------------------------------------------------===// +namespace { + +constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; + +} // namespace + constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId; // Constructs the analysis info by analyzing the given function. @@ -338,13 +302,13 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo( } }); - llvm::StringMap var_handle_name_id_map; + llvm::SmallDenseMap resource_handle_id_map; func_op.walk([&](Operation* op) { - if (auto var_handle = dyn_cast(op)) { - AddValueUniqueIDMapping( - var_handle.resource(), - GetOrCreateIdForVarHandle(var_handle, next_unique_id, - var_handle_name_id_map)); + if (auto resource_alloc = dyn_cast(op)) { + ResourceHandleValueAndId resource = + resource_alloc.GetResourceHandleValueAndId(resource_handle_id_map, + next_unique_id); + AddValueUniqueIDMapping(resource.value, resource.id); } else if (llvm::isa(op)) { for (auto result : filter_resources(op->getResults())) PropagateInputToOutput(op->getOperand(result.getResultNumber()), diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h index 1eb5c89f0fc..17b52c03b6b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h @@ -16,10 +16,17 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ +#include + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" +#include "tensorflow/core/framework/resource_mgr.h" namespace mlir { namespace TF { @@ -49,8 +56,80 @@ struct ContractionFusion { SmallVector additional_attributes; }; +//===----------------------------------------------------------------------===// +// TensorFlow Resource Handles. +//===----------------------------------------------------------------------===// + +inline bool IsResourceHandleAnonymous(StringRef name) { + return name == tensorflow::ResourceHandle::ANONYMOUS_NAME; +} + +// Helper struct representing an identifier for a resource handle. For resource +// handles created explicitly and shared across resource allocator ops, +// `container`, `name`, and `device` can be set. If an resource handle is tied +// to an instance of an operation (e.g. TensorFlow runtime operation caching), +// `op` can be set instead. +struct ResourceHandle { + ResourceHandle(StringRef container, StringRef name, StringRef device, + Operation* op) + : container(container), name(name), device(device), op(op) {} + + bool operator==(const ResourceHandle& rhs) const { + return container == rhs.container && name == rhs.name && + device == rhs.device && op == rhs.op; + } + + // Make ResourceHandle hashable. + friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle); + + std::string container; + std::string name; + std::string device; + Operation* op = nullptr; +}; + +// Make ResourceHandle hashable. +inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) { + return ::llvm::hash_combine(resource_handle.container, resource_handle.name, + resource_handle.device, resource_handle.op); +} + +// Helper struct holding a resource handle value and unique id associated to the +// resource handle. +struct ResourceHandleValueAndId { + ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {} + + Value value; + int64_t id = -1; +}; + #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc" } // namespace TF } // namespace mlir +namespace llvm { +template <> +struct DenseMapInfo { + static mlir::TF::ResourceHandle getEmptyKey() { + return {/*container=*/"", /*name=*/"", /*device=*/"", /*op=*/nullptr}; + } + + static mlir::TF::ResourceHandle getTombstoneKey() { + return {/*container=*/"", + /*name=*/tensorflow::ResourceHandle::ANONYMOUS_NAME, /*device=*/"", + /*op=*/nullptr}; + } + + static unsigned getHashValue( + const mlir::TF::ResourceHandle& resource_handle) { + return mlir::TF::hash_value(resource_handle); + } + + static bool isEqual(const mlir::TF::ResourceHandle& lhs, + const mlir::TF::ResourceHandle& rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td index 3c41c04a0d6..1ed30c89a77 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td @@ -125,4 +125,27 @@ def TF_ContractionFusableInterface : OpInterface<"ContractionFusableInterface"> ]; } +//===----------------------------------------------------------------------===// +// TensorFlow Resource Handle Interfaces. +//===----------------------------------------------------------------------===// + +def TF_ResourceHandleAllocatorInterface : OpInterface<"ResourceHandleAllocatorInterface"> { + let description = [{ + A resource handle allocator operation is one that creates a resource handle, + or looks up and reuses an existing resource handle. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{Returns the resource handle value and unique id associated with + the resource handle. If a resource handle is reused, then an + existing id will be returned.}], + /*retTy=*/"ResourceHandleValueAndId", + /*methodName=*/"GetResourceHandleValueAndId", + /*args=*/(ins "llvm::SmallDenseMap&":$resource_handle_id_map, + "int64_t&":$next_id) + >, + ]; +} + #endif // TF_OP_INTERFACES diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 544f07f7075..67ad0fc4e70 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -787,7 +787,7 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co let results = (outs); } -def TF_VarHandleOp : TF_Op<"VarHandleOp", []> { +def TF_VarHandleOp : TF_Op<"VarHandleOp", [TF_ResourceHandleAllocatorInterface]> { let summary = "Creates a handle to a Variable resource from its name."; let description = [{ @@ -816,6 +816,13 @@ Example: TF_DerivedOperandOrResultHandleTypeAttr<"resource">; TF_DerivedOperandOrResultHandleShapeAttr shape = TF_DerivedOperandOrResultHandleShapeAttr<"resource">; + + let extraClassDeclaration = [{ + // TF_ResourceHandleAllocatorInterface: + ResourceHandleValueAndId GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id); + }]; } // Multiple variadic operands with different sizes are not supported by the diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 519f7e9fcaf..8742a0e2b71 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -2388,6 +2389,27 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { return success(); } +//===----------------------------------------------------------------------===// +// VarHandleOp +//===----------------------------------------------------------------------===// + +ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id) { + // Always create a new ID for anonymous handle. + if (IsResourceHandleAnonymous(shared_name())) return {resource(), next_id++}; + + llvm::StringRef device; + if (auto device_attr = getAttrOfType("device")) + device = device_attr.getValue(); + + ResourceHandle handle(container(), shared_name(), device, /*op=*/nullptr); + auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id); + // New ID created, increment next_id. + if (emplace_res.second) ++next_id; + return {resource(), emplace_res.first->second}; +} + //===----------------------------------------------------------------------===// // VarIsInitializedOp //===----------------------------------------------------------------------===//