Add TF_ResourceHandleAllocatorInterface op interface to tf.SummaryWriter.

tf.VarHandleOp resource handle id assignment logic is refactored and reused.

PiperOrigin-RevId: 337541804
Change-Id: Ia2949a6417ccd584e3ba3c20e3c0446bc4a9dd6e
This commit is contained in:
Andy Ly 2020-10-16 11:13:17 -07:00 committed by TensorFlower Gardener
parent 52473a84f4
commit 1eb833a909
3 changed files with 53 additions and 13 deletions

View File

@ -1635,7 +1635,7 @@ event: A string containing a binary-encoded tf.Event proto.
let results = (outs);
}
def TF_SummaryWriterOp : TF_Op<"SummaryWriter", []> {
def TF_SummaryWriterOp : TF_Op<"SummaryWriter", [TF_ResourceHandleAllocatorInterface]> {
let summary = "Returns a handle to be used to access a summary writer.";
let description = [{
@ -1653,6 +1653,13 @@ writer: the summary writer resource. Scalar handle.
let results = (outs
Res<TF_ResourceTensor, "", [TF_SummaryAlloc]>:$writer
);
let extraClassDeclaration = [{
// TF_ResourceHandleAllocatorInterface:
ResourceHandleValueAndId GetResourceHandleValueAndId(
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id);
}];
}
def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> {

View File

@ -587,3 +587,31 @@ struct DropAttributes : public OpRewritePattern<Op> {
}
};
//===----------------------------------------------------------------------===//
// TF op helper functions for handling resource handles and ids.
//===----------------------------------------------------------------------===//
// Returns device of op if present. If op has no device set, an empty string ref
// is returned instead.
llvm::StringRef GetDeviceOrEmpty(Operation *op) {
if (auto device_attr = op->getAttrOfType<StringAttr>("device"))
return device_attr.getValue();
return llvm::StringRef();
}
// Returns resource handle value and id for resource op based on attributes. If
// a resource handle is anonymous, a new id is always returned.
ResourceHandleValueAndId GetResourceHandleValueAndIdBase(
llvm::StringRef container, llvm::StringRef shared_name,
llvm::StringRef device, Value resource,
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id) {
// Always create a new ID for anonymous handle.
if (IsResourceHandleAnonymous(shared_name)) return {resource, next_id++};
ResourceHandle handle(container, shared_name, device, /*op=*/nullptr);
auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id);
// New ID created, increment next_id.
if (emplace_res.second) ++next_id;
return {resource, emplace_res.first->second};
}

View File

@ -1929,6 +1929,19 @@ bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
return true;
}
//===----------------------------------------------------------------------===//
// SummaryWriterOp
//===----------------------------------------------------------------------===//
ResourceHandleValueAndId SummaryWriterOp::GetResourceHandleValueAndId(
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id) {
llvm::StringRef device = GetDeviceOrEmpty(getOperation());
return GetResourceHandleValueAndIdBase(container(), shared_name(), device,
writer(), resource_handle_id_map,
next_id);
}
//===----------------------------------------------------------------------===//
// TensorListReserveOp
//===----------------------------------------------------------------------===//
@ -2445,18 +2458,10 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId(
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
int64_t &next_id) {
// Always create a new ID for anonymous handle.
if (IsResourceHandleAnonymous(shared_name())) return {resource(), next_id++};
llvm::StringRef device;
if (auto device_attr = getAttrOfType<StringAttr>("device"))
device = device_attr.getValue();
ResourceHandle handle(container(), shared_name(), device, /*op=*/nullptr);
auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id);
// New ID created, increment next_id.
if (emplace_res.second) ++next_id;
return {resource(), emplace_res.first->second};
llvm::StringRef device = GetDeviceOrEmpty(getOperation());
return GetResourceHandleValueAndIdBase(container(), shared_name(), device,
resource(), resource_handle_id_map,
next_id);
}
//===----------------------------------------------------------------------===//