diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index bf4c8a2a135..c814153eb43 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -1635,7 +1635,7 @@ event: A string containing a binary-encoded tf.Event proto. let results = (outs); } -def TF_SummaryWriterOp : TF_Op<"SummaryWriter", []> { +def TF_SummaryWriterOp : TF_Op<"SummaryWriter", [TF_ResourceHandleAllocatorInterface]> { let summary = "Returns a handle to be used to access a summary writer."; let description = [{ @@ -1653,6 +1653,13 @@ writer: the summary writer resource. Scalar handle. let results = (outs Res:$writer ); + + let extraClassDeclaration = [{ + // TF_ResourceHandleAllocatorInterface: + ResourceHandleValueAndId GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id); + }]; } def TF_WriteAudioSummaryOp : TF_Op<"WriteAudioSummary", []> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc index 44df2b12d88..72ca50b5c37 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc @@ -587,3 +587,31 @@ struct DropAttributes : public OpRewritePattern { } }; +//===----------------------------------------------------------------------===// +// TF op helper functions for handling resource handles and ids. +//===----------------------------------------------------------------------===// + +// Returns device of op if present. If op has no device set, an empty string ref +// is returned instead. +llvm::StringRef GetDeviceOrEmpty(Operation *op) { + if (auto device_attr = op->getAttrOfType("device")) + return device_attr.getValue(); + return llvm::StringRef(); +} + +// Returns resource handle value and id for resource op based on attributes. If +// a resource handle is anonymous, a new id is always returned. +ResourceHandleValueAndId GetResourceHandleValueAndIdBase( + llvm::StringRef container, llvm::StringRef shared_name, + llvm::StringRef device, Value resource, + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id) { + // Always create a new ID for anonymous handle. + if (IsResourceHandleAnonymous(shared_name)) return {resource, next_id++}; + + ResourceHandle handle(container, shared_name, device, /*op=*/nullptr); + auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id); + // New ID created, increment next_id. + if (emplace_res.second) ++next_id; + return {resource, emplace_res.first->second}; +} diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index e9ccbed53db..b99c99029ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -1929,6 +1929,19 @@ bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( return true; } +//===----------------------------------------------------------------------===// +// SummaryWriterOp +//===----------------------------------------------------------------------===// + +ResourceHandleValueAndId SummaryWriterOp::GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id) { + llvm::StringRef device = GetDeviceOrEmpty(getOperation()); + return GetResourceHandleValueAndIdBase(container(), shared_name(), device, + writer(), resource_handle_id_map, + next_id); +} + //===----------------------------------------------------------------------===// // TensorListReserveOp //===----------------------------------------------------------------------===// @@ -2445,18 +2458,10 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId( llvm::SmallDenseMap &resource_handle_id_map, int64_t &next_id) { - // Always create a new ID for anonymous handle. - if (IsResourceHandleAnonymous(shared_name())) return {resource(), next_id++}; - - llvm::StringRef device; - if (auto device_attr = getAttrOfType("device")) - device = device_attr.getValue(); - - ResourceHandle handle(container(), shared_name(), device, /*op=*/nullptr); - auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id); - // New ID created, increment next_id. - if (emplace_res.second) ++next_id; - return {resource(), emplace_res.first->second}; + llvm::StringRef device = GetDeviceOrEmpty(getOperation()); + return GetResourceHandleValueAndIdBase(container(), shared_name(), device, + resource(), resource_handle_id_map, + next_id); } //===----------------------------------------------------------------------===//