Move unique id generation for tf.VarHandleOp to an op interface.
Other ops can create resource handles, and some parts for determining a unique resource handle can be reused. PiperOrigin-RevId: 336658650 Change-Id: Icbb08f28692c8d27e28f731df44c01894cc6f307
This commit is contained in:
parent
cb42c4dea1
commit
aa6ff638f6
@ -110,6 +110,8 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":tensorflow_op_interfaces_inc_gen",
|
":tensorflow_op_interfaces_inc_gen",
|
||||||
":tensorflow_structs",
|
":tensorflow_structs",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
],
|
],
|
||||||
@ -810,8 +812,8 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":tensorflow",
|
":tensorflow",
|
||||||
|
":tensorflow_op_interfaces",
|
||||||
":tensorflow_types",
|
":tensorflow_types",
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Analysis",
|
"@llvm-project//mlir:Analysis",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/SCCIterator.h"
|
#include "llvm/ADT/SCCIterator.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
#include "llvm/ADT/iterator_range.h"
|
#include "llvm/ADT/iterator_range.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "mlir/Analysis/CallGraph.h" // from @llvm-project
|
#include "mlir/Analysis/CallGraph.h" // from @llvm-project
|
||||||
@ -40,8 +39,8 @@ limitations under the License.
|
|||||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace TF {
|
namespace TF {
|
||||||
@ -228,51 +227,16 @@ BacktrackAnalysisInfo::BacktrackAnalysisInfo(
|
|||||||
backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
|
backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ResourceAliasAnalysisInfo helper functions.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
|
|
||||||
|
|
||||||
// Returns if a VarHandleOp is anonymous, which means it always creates a new
|
|
||||||
// variable.
|
|
||||||
bool IsResourceHandleAnonymous(VarHandleOp handle) {
|
|
||||||
return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a string unique identifier for a non-anonymous VarHandleOp.
|
|
||||||
std::string GetVarHandleStringId(VarHandleOp handle) {
|
|
||||||
auto device = handle.getAttrOfType<StringAttr>("device");
|
|
||||||
return llvm::join(
|
|
||||||
llvm::ArrayRef<llvm::StringRef>{
|
|
||||||
handle.container(), handle.shared_name(),
|
|
||||||
device ? device.getValue() : llvm::StringRef()},
|
|
||||||
"/");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always
|
|
||||||
// creates a new ID; otherwise, tries to reuse the existing ID for the
|
|
||||||
// referenced variable if it exists, or creates a new one if not.
|
|
||||||
int64_t GetOrCreateIdForVarHandle(VarHandleOp handle, int64_t& next_id,
|
|
||||||
llvm::StringMap<int64_t>& name_id_map) {
|
|
||||||
// Always create a new ID for anonymous handle.
|
|
||||||
if (IsResourceHandleAnonymous(handle)) return next_id++;
|
|
||||||
|
|
||||||
auto name = GetVarHandleStringId(handle);
|
|
||||||
auto emplace_res = name_id_map.try_emplace(name, next_id);
|
|
||||||
// New ID created, increment next_id.
|
|
||||||
if (emplace_res.second) ++next_id;
|
|
||||||
return emplace_res.first->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ResourceAliasAnalysisInfo
|
// ResourceAliasAnalysisInfo
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId;
|
constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId;
|
||||||
|
|
||||||
// Constructs the analysis info by analyzing the given function.
|
// Constructs the analysis info by analyzing the given function.
|
||||||
@ -338,13 +302,13 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
llvm::StringMap<int64_t> var_handle_name_id_map;
|
llvm::SmallDenseMap<ResourceHandle, int64_t> resource_handle_id_map;
|
||||||
func_op.walk([&](Operation* op) {
|
func_op.walk([&](Operation* op) {
|
||||||
if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
|
if (auto resource_alloc = dyn_cast<ResourceHandleAllocatorInterface>(op)) {
|
||||||
AddValueUniqueIDMapping(
|
ResourceHandleValueAndId resource =
|
||||||
var_handle.resource(),
|
resource_alloc.GetResourceHandleValueAndId(resource_handle_id_map,
|
||||||
GetOrCreateIdForVarHandle(var_handle, next_unique_id,
|
next_unique_id);
|
||||||
var_handle_name_id_map));
|
AddValueUniqueIDMapping(resource.value, resource.id);
|
||||||
} else if (llvm::isa<IdentityNOp, IdentityOp>(op)) {
|
} else if (llvm::isa<IdentityNOp, IdentityOp>(op)) {
|
||||||
for (auto result : filter_resources(op->getResults()))
|
for (auto result : filter_resources(op->getResults()))
|
||||||
PropagateInputToOutput(op->getOperand(result.getResultNumber()),
|
PropagateInputToOutput(op->getOperand(result.getResultNumber()),
|
||||||
|
@ -16,10 +16,17 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMapInfo.h"
|
||||||
|
#include "llvm/ADT/Hashing.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
|
||||||
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace TF {
|
namespace TF {
|
||||||
@ -49,8 +56,80 @@ struct ContractionFusion {
|
|||||||
SmallVector<NamedAttribute, 4> additional_attributes;
|
SmallVector<NamedAttribute, 4> additional_attributes;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TensorFlow Resource Handles.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
inline bool IsResourceHandleAnonymous(StringRef name) {
|
||||||
|
return name == tensorflow::ResourceHandle::ANONYMOUS_NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper struct representing an identifier for a resource handle. For resource
|
||||||
|
// handles created explicitly and shared across resource allocator ops,
|
||||||
|
// `container`, `name`, and `device` can be set. If an resource handle is tied
|
||||||
|
// to an instance of an operation (e.g. TensorFlow runtime operation caching),
|
||||||
|
// `op` can be set instead.
|
||||||
|
struct ResourceHandle {
|
||||||
|
ResourceHandle(StringRef container, StringRef name, StringRef device,
|
||||||
|
Operation* op)
|
||||||
|
: container(container), name(name), device(device), op(op) {}
|
||||||
|
|
||||||
|
bool operator==(const ResourceHandle& rhs) const {
|
||||||
|
return container == rhs.container && name == rhs.name &&
|
||||||
|
device == rhs.device && op == rhs.op;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make ResourceHandle hashable.
|
||||||
|
friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle);
|
||||||
|
|
||||||
|
std::string container;
|
||||||
|
std::string name;
|
||||||
|
std::string device;
|
||||||
|
Operation* op = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Make ResourceHandle hashable.
|
||||||
|
inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) {
|
||||||
|
return ::llvm::hash_combine(resource_handle.container, resource_handle.name,
|
||||||
|
resource_handle.device, resource_handle.op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper struct holding a resource handle value and unique id associated to the
|
||||||
|
// resource handle.
|
||||||
|
struct ResourceHandleValueAndId {
|
||||||
|
ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {}
|
||||||
|
|
||||||
|
Value value;
|
||||||
|
int64_t id = -1;
|
||||||
|
};
|
||||||
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
template <>
|
||||||
|
struct DenseMapInfo<mlir::TF::ResourceHandle> {
|
||||||
|
static mlir::TF::ResourceHandle getEmptyKey() {
|
||||||
|
return {/*container=*/"", /*name=*/"", /*device=*/"", /*op=*/nullptr};
|
||||||
|
}
|
||||||
|
|
||||||
|
static mlir::TF::ResourceHandle getTombstoneKey() {
|
||||||
|
return {/*container=*/"",
|
||||||
|
/*name=*/tensorflow::ResourceHandle::ANONYMOUS_NAME, /*device=*/"",
|
||||||
|
/*op=*/nullptr};
|
||||||
|
}
|
||||||
|
|
||||||
|
static unsigned getHashValue(
|
||||||
|
const mlir::TF::ResourceHandle& resource_handle) {
|
||||||
|
return mlir::TF::hash_value(resource_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const mlir::TF::ResourceHandle& lhs,
|
||||||
|
const mlir::TF::ResourceHandle& rhs) {
|
||||||
|
return lhs == rhs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace llvm
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
|
||||||
|
@ -125,4 +125,27 @@ def TF_ContractionFusableInterface : OpInterface<"ContractionFusableInterface">
|
|||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TensorFlow Resource Handle Interfaces.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def TF_ResourceHandleAllocatorInterface : OpInterface<"ResourceHandleAllocatorInterface"> {
|
||||||
|
let description = [{
|
||||||
|
A resource handle allocator operation is one that creates a resource handle,
|
||||||
|
or looks up and reuses an existing resource handle.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{Returns the resource handle value and unique id associated with
|
||||||
|
the resource handle. If a resource handle is reused, then an
|
||||||
|
existing id will be returned.}],
|
||||||
|
/*retTy=*/"ResourceHandleValueAndId",
|
||||||
|
/*methodName=*/"GetResourceHandleValueAndId",
|
||||||
|
/*args=*/(ins "llvm::SmallDenseMap<ResourceHandle, int64_t>&":$resource_handle_id_map,
|
||||||
|
"int64_t&":$next_id)
|
||||||
|
>,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TF_OP_INTERFACES
|
#endif // TF_OP_INTERFACES
|
||||||
|
@ -787,7 +787,7 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co
|
|||||||
let results = (outs);
|
let results = (outs);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_VarHandleOp : TF_Op<"VarHandleOp", []> {
|
def TF_VarHandleOp : TF_Op<"VarHandleOp", [TF_ResourceHandleAllocatorInterface]> {
|
||||||
let summary = "Creates a handle to a Variable resource from its name.";
|
let summary = "Creates a handle to a Variable resource from its name.";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -816,6 +816,13 @@ Example:
|
|||||||
TF_DerivedOperandOrResultHandleTypeAttr<"resource">;
|
TF_DerivedOperandOrResultHandleTypeAttr<"resource">;
|
||||||
TF_DerivedOperandOrResultHandleShapeAttr shape =
|
TF_DerivedOperandOrResultHandleShapeAttr shape =
|
||||||
TF_DerivedOperandOrResultHandleShapeAttr<"resource">;
|
TF_DerivedOperandOrResultHandleShapeAttr<"resource">;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// TF_ResourceHandleAllocatorInterface:
|
||||||
|
ResourceHandleValueAndId GetResourceHandleValueAndId(
|
||||||
|
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
|
||||||
|
int64_t &next_id);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Multiple variadic operands with different sizes are not supported by the
|
// Multiple variadic operands with different sizes are not supported by the
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/APFloat.h"
|
#include "llvm/ADT/APFloat.h"
|
||||||
#include "llvm/ADT/APInt.h"
|
#include "llvm/ADT/APInt.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/Optional.h"
|
#include "llvm/ADT/Optional.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/Sequence.h"
|
#include "llvm/ADT/Sequence.h"
|
||||||
@ -2388,6 +2389,27 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// VarHandleOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId(
|
||||||
|
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
|
||||||
|
int64_t &next_id) {
|
||||||
|
// Always create a new ID for anonymous handle.
|
||||||
|
if (IsResourceHandleAnonymous(shared_name())) return {resource(), next_id++};
|
||||||
|
|
||||||
|
llvm::StringRef device;
|
||||||
|
if (auto device_attr = getAttrOfType<StringAttr>("device"))
|
||||||
|
device = device_attr.getValue();
|
||||||
|
|
||||||
|
ResourceHandle handle(container(), shared_name(), device, /*op=*/nullptr);
|
||||||
|
auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id);
|
||||||
|
// New ID created, increment next_id.
|
||||||
|
if (emplace_res.second) ++next_id;
|
||||||
|
return {resource(), emplace_res.first->second};
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// VarIsInitializedOp
|
// VarIsInitializedOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
Loading…
Reference in New Issue
Block a user