Move unique id generation for tf.VarHandleOp to an op interface.
Other ops can create resource handles, and some parts for determining a unique resource handle can be reused. PiperOrigin-RevId: 336658650 Change-Id: Icbb08f28692c8d27e28f731df44c01894cc6f307
This commit is contained in:
parent
cb42c4dea1
commit
aa6ff638f6
@ -110,6 +110,8 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_op_interfaces_inc_gen",
|
||||
":tensorflow_structs",
|
||||
"//tensorflow/core:framework",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
@ -810,8 +812,8 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":tensorflow_op_interfaces",
|
||||
":tensorflow_types",
|
||||
"//tensorflow/core:framework",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "llvm/ADT/SCCIterator.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Analysis/CallGraph.h" // from @llvm-project
|
||||
@ -40,8 +39,8 @@ limitations under the License.
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
@ -228,51 +227,16 @@ BacktrackAnalysisInfo::BacktrackAnalysisInfo(
|
||||
backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceAliasAnalysisInfo helper functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
|
||||
|
||||
// Returns if a VarHandleOp is anonymous, which means it always creates a new
|
||||
// variable.
|
||||
bool IsResourceHandleAnonymous(VarHandleOp handle) {
|
||||
return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME;
|
||||
}
|
||||
|
||||
// Returns a string unique identifier for a non-anonymous VarHandleOp.
|
||||
std::string GetVarHandleStringId(VarHandleOp handle) {
|
||||
auto device = handle.getAttrOfType<StringAttr>("device");
|
||||
return llvm::join(
|
||||
llvm::ArrayRef<llvm::StringRef>{
|
||||
handle.container(), handle.shared_name(),
|
||||
device ? device.getValue() : llvm::StringRef()},
|
||||
"/");
|
||||
}
|
||||
|
||||
// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always
|
||||
// creates a new ID; otherwise, tries to reuse the existing ID for the
|
||||
// referenced variable if it exists, or creates a new one if not.
|
||||
int64_t GetOrCreateIdForVarHandle(VarHandleOp handle, int64_t& next_id,
|
||||
llvm::StringMap<int64_t>& name_id_map) {
|
||||
// Always create a new ID for anonymous handle.
|
||||
if (IsResourceHandleAnonymous(handle)) return next_id++;
|
||||
|
||||
auto name = GetVarHandleStringId(handle);
|
||||
auto emplace_res = name_id_map.try_emplace(name, next_id);
|
||||
// New ID created, increment next_id.
|
||||
if (emplace_res.second) ++next_id;
|
||||
return emplace_res.first->second;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceAliasAnalysisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
|
||||
|
||||
} // namespace
|
||||
|
||||
constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId;
|
||||
|
||||
// Constructs the analysis info by analyzing the given function.
|
||||
@ -338,13 +302,13 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
|
||||
}
|
||||
});
|
||||
|
||||
llvm::StringMap<int64_t> var_handle_name_id_map;
|
||||
llvm::SmallDenseMap<ResourceHandle, int64_t> resource_handle_id_map;
|
||||
func_op.walk([&](Operation* op) {
|
||||
if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
|
||||
AddValueUniqueIDMapping(
|
||||
var_handle.resource(),
|
||||
GetOrCreateIdForVarHandle(var_handle, next_unique_id,
|
||||
var_handle_name_id_map));
|
||||
if (auto resource_alloc = dyn_cast<ResourceHandleAllocatorInterface>(op)) {
|
||||
ResourceHandleValueAndId resource =
|
||||
resource_alloc.GetResourceHandleValueAndId(resource_handle_id_map,
|
||||
next_unique_id);
|
||||
AddValueUniqueIDMapping(resource.value, resource.id);
|
||||
} else if (llvm::isa<IdentityNOp, IdentityOp>(op)) {
|
||||
for (auto result : filter_resources(op->getResults()))
|
||||
PropagateInputToOutput(op->getOperand(result.getResultNumber()),
|
||||
|
@ -16,10 +16,17 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
@ -49,8 +56,80 @@ struct ContractionFusion {
|
||||
SmallVector<NamedAttribute, 4> additional_attributes;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow Resource Handles.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
inline bool IsResourceHandleAnonymous(StringRef name) {
|
||||
return name == tensorflow::ResourceHandle::ANONYMOUS_NAME;
|
||||
}
|
||||
|
||||
// Helper struct representing an identifier for a resource handle. For resource
|
||||
// handles created explicitly and shared across resource allocator ops,
|
||||
// `container`, `name`, and `device` can be set. If an resource handle is tied
|
||||
// to an instance of an operation (e.g. TensorFlow runtime operation caching),
|
||||
// `op` can be set instead.
|
||||
struct ResourceHandle {
|
||||
ResourceHandle(StringRef container, StringRef name, StringRef device,
|
||||
Operation* op)
|
||||
: container(container), name(name), device(device), op(op) {}
|
||||
|
||||
bool operator==(const ResourceHandle& rhs) const {
|
||||
return container == rhs.container && name == rhs.name &&
|
||||
device == rhs.device && op == rhs.op;
|
||||
}
|
||||
|
||||
// Make ResourceHandle hashable.
|
||||
friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle);
|
||||
|
||||
std::string container;
|
||||
std::string name;
|
||||
std::string device;
|
||||
Operation* op = nullptr;
|
||||
};
|
||||
|
||||
// Make ResourceHandle hashable.
|
||||
inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) {
|
||||
return ::llvm::hash_combine(resource_handle.container, resource_handle.name,
|
||||
resource_handle.device, resource_handle.op);
|
||||
}
|
||||
|
||||
// Helper struct holding a resource handle value and unique id associated to the
|
||||
// resource handle.
|
||||
struct ResourceHandleValueAndId {
|
||||
ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {}
|
||||
|
||||
Value value;
|
||||
int64_t id = -1;
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::TF::ResourceHandle> {
|
||||
static mlir::TF::ResourceHandle getEmptyKey() {
|
||||
return {/*container=*/"", /*name=*/"", /*device=*/"", /*op=*/nullptr};
|
||||
}
|
||||
|
||||
static mlir::TF::ResourceHandle getTombstoneKey() {
|
||||
return {/*container=*/"",
|
||||
/*name=*/tensorflow::ResourceHandle::ANONYMOUS_NAME, /*device=*/"",
|
||||
/*op=*/nullptr};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(
|
||||
const mlir::TF::ResourceHandle& resource_handle) {
|
||||
return mlir::TF::hash_value(resource_handle);
|
||||
}
|
||||
|
||||
static bool isEqual(const mlir::TF::ResourceHandle& lhs,
|
||||
const mlir::TF::ResourceHandle& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
||||
|
@ -125,4 +125,27 @@ def TF_ContractionFusableInterface : OpInterface<"ContractionFusableInterface">
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow Resource Handle Interfaces.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TF_ResourceHandleAllocatorInterface : OpInterface<"ResourceHandleAllocatorInterface"> {
|
||||
let description = [{
|
||||
A resource handle allocator operation is one that creates a resource handle,
|
||||
or looks up and reuses an existing resource handle.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{Returns the resource handle value and unique id associated with
|
||||
the resource handle. If a resource handle is reused, then an
|
||||
existing id will be returned.}],
|
||||
/*retTy=*/"ResourceHandleValueAndId",
|
||||
/*methodName=*/"GetResourceHandleValueAndId",
|
||||
/*args=*/(ins "llvm::SmallDenseMap<ResourceHandle, int64_t>&":$resource_handle_id_map,
|
||||
"int64_t&":$next_id)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
#endif // TF_OP_INTERFACES
|
||||
|
@ -787,7 +787,7 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_VarHandleOp : TF_Op<"VarHandleOp", []> {
|
||||
def TF_VarHandleOp : TF_Op<"VarHandleOp", [TF_ResourceHandleAllocatorInterface]> {
|
||||
let summary = "Creates a handle to a Variable resource from its name.";
|
||||
|
||||
let description = [{
|
||||
@ -816,6 +816,13 @@ Example:
|
||||
TF_DerivedOperandOrResultHandleTypeAttr<"resource">;
|
||||
TF_DerivedOperandOrResultHandleShapeAttr shape =
|
||||
TF_DerivedOperandOrResultHandleShapeAttr<"resource">;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_ResourceHandleAllocatorInterface:
|
||||
ResourceHandleValueAndId GetResourceHandleValueAndId(
|
||||
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
|
||||
int64_t &next_id);
|
||||
}];
|
||||
}
|
||||
|
||||
// Multiple variadic operands with different sizes are not supported by the
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
@ -2388,6 +2389,27 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VarHandleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId(
|
||||
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
|
||||
int64_t &next_id) {
|
||||
// Always create a new ID for anonymous handle.
|
||||
if (IsResourceHandleAnonymous(shared_name())) return {resource(), next_id++};
|
||||
|
||||
llvm::StringRef device;
|
||||
if (auto device_attr = getAttrOfType<StringAttr>("device"))
|
||||
device = device_attr.getValue();
|
||||
|
||||
ResourceHandle handle(container(), shared_name(), device, /*op=*/nullptr);
|
||||
auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id);
|
||||
// New ID created, increment next_id.
|
||||
if (emplace_res.second) ++next_id;
|
||||
return {resource(), emplace_res.first->second};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VarIsInitializedOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
Reference in New Issue
Block a user