Add TF_ResourceHandleAllocatorInterface op interface to tf.SummaryWriter.
tf.VarHandleOp resource handle id assignment logic is refactored and reused. PiperOrigin-RevId: 337541804 Change-Id: Ia2949a6417ccd584e3ba3c20e3c0446bc4a9dd6e
This commit is contained in:
parent
52473a84f4
commit
1eb833a909
@ -1635,7 +1635,7 @@ event: A string containing a binary-encoded tf.Event proto.
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_SummaryWriterOp : TF_Op<"SummaryWriter", []> {
|
||||
def TF_SummaryWriterOp : TF_Op<"SummaryWriter", [TF_ResourceHandleAllocatorInterface]> {
|
||||
let summary = "Returns a handle to be used to access a summary writer.";
|
||||
|
||||
let description = [{
|
||||
@ -1653,6 +1653,13 @@ writer: the summary writer resource. Scalar handle.
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_SummaryAlloc]>:$writer
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_ResourceHandleAllocatorInterface:
|
||||
ResourceHandleValueAndId GetResourceHandleValueAndId(
|
||||
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
|
||||
int64_t &next_id);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> {
|
||||
|
@ -587,3 +587,31 @@ struct DropAttributes : public OpRewritePattern<Op> {
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TF op helper functions for handling resource handles and ids.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Returns device of op if present. If op has no device set, an empty string ref
|
||||
// is returned instead.
|
||||
llvm::StringRef GetDeviceOrEmpty(Operation *op) {
|
||||
if (auto device_attr = op->getAttrOfType<StringAttr>("device"))
|
||||
return device_attr.getValue();
|
||||
return llvm::StringRef();
|
||||
}
|
||||
|
||||
// Returns resource handle value and id for resource op based on attributes. If
|
||||
// a resource handle is anonymous, a new id is always returned.
|
||||
ResourceHandleValueAndId GetResourceHandleValueAndIdBase(
|
||||
llvm::StringRef container, llvm::StringRef shared_name,
|
||||
llvm::StringRef device, Value resource,
|
||||
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
|
||||
int64_t &next_id) {
|
||||
// Always create a new ID for anonymous handle.
|
||||
if (IsResourceHandleAnonymous(shared_name)) return {resource, next_id++};
|
||||
|
||||
ResourceHandle handle(container, shared_name, device, /*op=*/nullptr);
|
||||
auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id);
|
||||
// New ID created, increment next_id.
|
||||
if (emplace_res.second) ++next_id;
|
||||
return {resource, emplace_res.first->second};
|
||||
}
|
||||
|
@ -1929,6 +1929,19 @@ bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SummaryWriterOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ResourceHandleValueAndId SummaryWriterOp::GetResourceHandleValueAndId(
|
||||
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
|
||||
int64_t &next_id) {
|
||||
llvm::StringRef device = GetDeviceOrEmpty(getOperation());
|
||||
return GetResourceHandleValueAndIdBase(container(), shared_name(), device,
|
||||
writer(), resource_handle_id_map,
|
||||
next_id);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorListReserveOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2445,18 +2458,10 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
|
||||
ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId(
|
||||
llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
|
||||
int64_t &next_id) {
|
||||
// Always create a new ID for anonymous handle.
|
||||
if (IsResourceHandleAnonymous(shared_name())) return {resource(), next_id++};
|
||||
|
||||
llvm::StringRef device;
|
||||
if (auto device_attr = getAttrOfType<StringAttr>("device"))
|
||||
device = device_attr.getValue();
|
||||
|
||||
ResourceHandle handle(container(), shared_name(), device, /*op=*/nullptr);
|
||||
auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id);
|
||||
// New ID created, increment next_id.
|
||||
if (emplace_res.second) ++next_id;
|
||||
return {resource(), emplace_res.first->second};
|
||||
llvm::StringRef device = GetDeviceOrEmpty(getOperation());
|
||||
return GetResourceHandleValueAndIdBase(container(), shared_name(), device,
|
||||
resource(), resource_handle_id_map,
|
||||
next_id);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
Reference in New Issue
Block a user