Move unique id generation for tf.VarHandleOp to an op interface.

Other ops can create resource handles, and some parts for determining a unique resource handle can be reused.

PiperOrigin-RevId: 336658650
Change-Id: Icbb08f28692c8d27e28f731df44c01894cc6f307
This commit is contained in:
Andy Ly 2020-10-12 07:23:32 -07:00 committed by TensorFlower Gardener
parent cb42c4dea1
commit aa6ff638f6
6 changed files with 148 additions and 51 deletions

View File

@ -110,6 +110,8 @@ cc_library(
deps = [
":tensorflow_op_interfaces_inc_gen",
":tensorflow_structs",
"//tensorflow/core:framework",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
@ -810,8 +812,8 @@ cc_library(
],
deps = [
":tensorflow",
":tensorflow_op_interfaces",
":tensorflow_types",
"//tensorflow/core:framework",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/CallGraph.h" // from @llvm-project
@ -40,8 +39,8 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/framework/resource_mgr.h"
namespace mlir {
namespace TF {
@ -228,51 +227,16 @@ BacktrackAnalysisInfo::BacktrackAnalysisInfo(
backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
}
namespace {
//===----------------------------------------------------------------------===//
// ResourceAliasAnalysisInfo helper functions.
//===----------------------------------------------------------------------===//
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
// Returns if a VarHandleOp is anonymous, which means it always creates a new
// variable.
bool IsResourceHandleAnonymous(VarHandleOp handle) {
return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME;
}
// Returns a string unique identifier for a non-anonymous VarHandleOp.
std::string GetVarHandleStringId(VarHandleOp handle) {
auto device = handle.getAttrOfType<StringAttr>("device");
return llvm::join(
llvm::ArrayRef<llvm::StringRef>{
handle.container(), handle.shared_name(),
device ? device.getValue() : llvm::StringRef()},
"/");
}
// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always
// creates a new ID; otherwise, tries to reuse the existing ID for the
// referenced variable if it exists, or creates a new one if not.
int64_t GetOrCreateIdForVarHandle(VarHandleOp handle, int64_t& next_id,
llvm::StringMap<int64_t>& name_id_map) {
// Always create a new ID for anonymous handle.
if (IsResourceHandleAnonymous(handle)) return next_id++;
auto name = GetVarHandleStringId(handle);
auto emplace_res = name_id_map.try_emplace(name, next_id);
// New ID created, increment next_id.
if (emplace_res.second) ++next_id;
return emplace_res.first->second;
}
} // namespace
//===----------------------------------------------------------------------===//
// ResourceAliasAnalysisInfo
//===----------------------------------------------------------------------===//
namespace {
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
} // namespace
constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId;
// Constructs the analysis info by analyzing the given function.
@ -338,13 +302,13 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
}
});
llvm::StringMap<int64_t> var_handle_name_id_map;
llvm::SmallDenseMap<ResourceHandle, int64_t> resource_handle_id_map;
func_op.walk([&](Operation* op) {
if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
AddValueUniqueIDMapping(
var_handle.resource(),
GetOrCreateIdForVarHandle(var_handle, next_unique_id,
var_handle_name_id_map));
if (auto resource_alloc = dyn_cast<ResourceHandleAllocatorInterface>(op)) {
ResourceHandleValueAndId resource =
resource_alloc.GetResourceHandleValueAndId(resource_handle_id_map,
next_unique_id);
AddValueUniqueIDMapping(resource.value, resource.id);
} else if (llvm::isa<IdentityNOp, IdentityOp>(op)) {
for (auto result : filter_resources(op->getResults()))
PropagateInputToOutput(op->getOperand(result.getResultNumber()),

View File

@ -16,10 +16,17 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
#include <string>
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
#include "tensorflow/core/framework/resource_mgr.h"
namespace mlir {
namespace TF {
@ -49,8 +56,80 @@ struct ContractionFusion {
SmallVector<NamedAttribute, 4> additional_attributes;
};
//===----------------------------------------------------------------------===//
// TensorFlow Resource Handles.
//===----------------------------------------------------------------------===//
inline bool IsResourceHandleAnonymous(StringRef name) {
return name == tensorflow::ResourceHandle::ANONYMOUS_NAME;
}
// Helper struct representing an identifier for a resource handle. For resource
// handles created explicitly and shared across resource allocator ops,
// `container`, `name`, and `device` can be set. If an resource handle is tied
// to an instance of an operation (e.g. TensorFlow runtime operation caching),
// `op` can be set instead.
struct ResourceHandle {
ResourceHandle(StringRef container, StringRef name, StringRef device,
Operation* op)
: container(container), name(name), device(device), op(op) {}
bool operator==(const ResourceHandle& rhs) const {
return container == rhs.container && name == rhs.name &&
device == rhs.device && op == rhs.op;
}
// Make ResourceHandle hashable.
friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle);
std::string container;
std::string name;
std::string device;
Operation* op = nullptr;
};
// Make ResourceHandle hashable.
inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) {
return ::llvm::hash_combine(resource_handle.container, resource_handle.name,
resource_handle.device, resource_handle.op);
}
// Helper struct holding a resource handle value and unique id associated to the
// resource handle.
struct ResourceHandleValueAndId {
ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {}
Value value;
int64_t id = -1;
};
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
} // namespace TF
} // namespace mlir
namespace llvm {
template <>
struct DenseMapInfo<mlir::TF::ResourceHandle> {
static mlir::TF::ResourceHandle getEmptyKey() {
return {/*container=*/"", /*name=*/"", /*device=*/"", /*op=*/nullptr};
}
static mlir::TF::ResourceHandle getTombstoneKey() {
return {/*container=*/"",
/*name=*/tensorflow::ResourceHandle::ANONYMOUS_NAME, /*device=*/"",
/*op=*/nullptr};
}
static unsigned getHashValue(
const mlir::TF::ResourceHandle& resource_handle) {
return mlir::TF::hash_value(resource_handle);
}
static bool isEqual(const mlir::TF::ResourceHandle& lhs,
const mlir::TF::ResourceHandle& rhs) {
return lhs == rhs;
}
};
} // namespace llvm
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_

View File

@ -125,4 +125,27 @@ def TF_ContractionFusableInterface : OpInterface<"ContractionFusableInterface">
];
}
//===----------------------------------------------------------------------===//
// TensorFlow Resource Handle Interfaces.
//===----------------------------------------------------------------------===//
def TF_ResourceHandleAllocatorInterface : OpInterface<"ResourceHandleAllocatorInterface"> {
let description = [{
A resource handle allocator operation is one that creates a resource handle,
or looks up and reuses an existing resource handle.
}];
let methods = [
InterfaceMethod<
/*desc=*/[{Returns the resource handle value and unique id associated with
the resource handle. If a resource handle is reused, then an
existing id will be returned.}],
/*retTy=*/"ResourceHandleValueAndId",
/*methodName=*/"GetResourceHandleValueAndId",
/*args=*/(ins "llvm::SmallDenseMap<ResourceHandle, int64_t>&":$resource_handle_id_map,
"int64_t&":$next_id)
>,
];
}
#endif // TF_OP_INTERFACES

View File

@ -787,7 +787,7 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co
let results = (outs);
}
def TF_VarHandleOp : TF_Op<"VarHandleOp", []> {
def TF_VarHandleOp : TF_Op<"VarHandleOp", [TF_ResourceHandleAllocatorInterface]> {
let summary = "Creates a handle to a Variable resource from its name.";
let description = [{
@ -816,6 +816,13 @@ Example:
TF_DerivedOperandOrResultHandleTypeAttr<"resource">;
TF_DerivedOperandOrResultHandleShapeAttr shape =
TF_DerivedOperandOrResultHandleShapeAttr<"resource">;
let extraClassDeclaration = [{
// TF_ResourceHandleAllocatorInterface:
ResourceHandleValueAndId GetResourceHandleValueAndId(
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id);
}];
}
// Multiple variadic operands with different sizes are not supported by the

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
@ -2388,6 +2389,27 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
return success();
}
//===----------------------------------------------------------------------===//
// VarHandleOp
//===----------------------------------------------------------------------===//
ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId(
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id) {
// Always create a new ID for anonymous handle.
if (IsResourceHandleAnonymous(shared_name())) return {resource(), next_id++};
llvm::StringRef device;
if (auto device_attr = getAttrOfType<StringAttr>("device"))
device = device_attr.getValue();
ResourceHandle handle(container(), shared_name(), device, /*op=*/nullptr);
auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id);
// New ID created, increment next_id.
if (emplace_res.second) ++next_id;
return {resource(), emplace_res.first->second};
}
//===----------------------------------------------------------------------===//
// VarIsInitializedOp
//===----------------------------------------------------------------------===//