Merge pull request #41841 from dnguyen28061:merge_summary
PiperOrigin-RevId: 328175218 Change-Id: I3e4aca437f68bc748595eff537799b3a809c5c24
This commit is contained in:
commit
b4b4161746
tensorflow
c/kernels
core
python
@ -53,6 +53,19 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "merge_summary_op",
|
||||
prefix = "merge_summary_op",
|
||||
deps = [
|
||||
"//tensorflow/c:kernels",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["bitcast"],
|
||||
deps = [
|
||||
@ -82,6 +95,15 @@ tf_gen_op_libs(
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["merge_summary"],
|
||||
deps = [
|
||||
"//tensorflow/c:ops",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "bitcast_op_test",
|
||||
srcs = ["bitcast_op_test.cc"],
|
||||
@ -146,6 +168,7 @@ filegroup(
|
||||
srcs = [
|
||||
"bitcast_op.cc",
|
||||
"histogram_summary_op.cc",
|
||||
"merge_summary_op.cc",
|
||||
"summary_op.cc",
|
||||
"tensor_shape_utils.cc",
|
||||
"tensor_shape_utils.h",
|
||||
@ -158,6 +181,7 @@ filegroup(
|
||||
srcs = [
|
||||
"ops/bitcast.cc",
|
||||
"ops/histogram_summary.cc",
|
||||
"ops/merge_summary.cc",
|
||||
"ops/summary.cc",
|
||||
],
|
||||
)
|
||||
|
123
tensorflow/c/kernels/merge_summary_op.cc
Normal file
123
tensorflow/c/kernels/merge_summary_op.cc
Normal file
@ -0,0 +1,123 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/c/kernels.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/framework/selective_registration.h"
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status
|
||||
struct TFTensorDeleter {
|
||||
void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
|
||||
};
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
|
||||
};
|
||||
|
||||
// Struct that wraps TF_Tensor and TF_Status to delete once out of scope
|
||||
using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
|
||||
using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
|
||||
// dummy functions used for kernel registration
|
||||
void* MergeSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
|
||||
|
||||
void MergeSummaryOp_Delete(void* kernel) {}
|
||||
|
||||
void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
|
||||
tensorflow::Summary s;
|
||||
std::unordered_set<tensorflow::string> tags;
|
||||
Safe_TF_StatusPtr status(TF_NewStatus());
|
||||
for (int input_num = 0; input_num < TF_NumInputs(ctx); ++input_num) {
|
||||
TF_Tensor* input;
|
||||
TF_GetInput(ctx, input_num, &input, status.get());
|
||||
Safe_TF_TensorPtr safe_input_ptr(input);
|
||||
if (TF_GetCode(status.get()) != TF_OK) {
|
||||
TF_OpKernelContext_Failure(ctx, status.get());
|
||||
return;
|
||||
}
|
||||
auto tags_array =
|
||||
static_cast<tensorflow::tstring*>(TF_TensorData(safe_input_ptr.get()));
|
||||
for (int i = 0; i < TF_TensorElementCount(safe_input_ptr.get()); ++i) {
|
||||
const tensorflow::tstring& s_in = tags_array[i];
|
||||
tensorflow::Summary summary_in;
|
||||
if (!tensorflow::ParseProtoUnlimited(&summary_in, s_in)) {
|
||||
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT,
|
||||
"Could not parse one of the summary inputs");
|
||||
TF_OpKernelContext_Failure(ctx, status.get());
|
||||
return;
|
||||
}
|
||||
for (int v = 0; v < summary_in.value_size(); ++v) {
|
||||
// This tag is unused by the TensorSummary op, so no need to check for
|
||||
// duplicates.
|
||||
const tensorflow::string& tag = summary_in.value(v).tag();
|
||||
if ((!tag.empty()) && !tags.insert(tag).second) {
|
||||
std::ostringstream err;
|
||||
err << "Duplicate tag " << tag << " found in summary inputs ";
|
||||
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
|
||||
TF_OpKernelContext_Failure(ctx, status.get());
|
||||
return;
|
||||
}
|
||||
*s.add_value() = summary_in.value(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
|
||||
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
|
||||
/*dims=*/nullptr, /*num_dims=*/0,
|
||||
/*len=*/sizeof(tensorflow::tstring), status.get()));
|
||||
if (TF_GetCode(status.get()) != TF_OK) {
|
||||
TF_OpKernelContext_Failure(ctx, status.get());
|
||||
return;
|
||||
}
|
||||
tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
|
||||
TF_TensorData(summary_tensor.get()));
|
||||
CHECK(SerializeToTString(s, output_tstring));
|
||||
}
|
||||
|
||||
void RegisterMergeSummaryOpKernel() {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
{
|
||||
auto* builder = TF_NewKernelBuilder(
|
||||
"MergeSummary", tensorflow::DEVICE_CPU, &MergeSummaryOp_Create,
|
||||
&MergeSummaryOp_Compute, &MergeSummaryOp_Delete);
|
||||
TF_RegisterKernelBuilder("MergeSummary", builder, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status))
|
||||
<< "Error while registering Merge Summmary kernel";
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
// A dummy static variable initialized by a lambda whose side-effect is to
|
||||
// register the Histogram Summary kernel.
|
||||
TF_ATTRIBUTE_UNUSED static bool IsMergeSummaryOpKernelRegistered = []() {
|
||||
if (SHOULD_REGISTER_OP_KERNEL("MergeSummary")) {
|
||||
RegisterMergeSummaryOpKernel();
|
||||
}
|
||||
return true;
|
||||
}();
|
||||
|
||||
} // namespace
|
51
tensorflow/c/kernels/ops/merge_summary.cc
Normal file
51
tensorflow/c/kernels/ops/merge_summary.cc
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/ops.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/framework/selective_registration.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
static void merge_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
|
||||
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
|
||||
TF_DeleteShapeHandle(result);
|
||||
}
|
||||
|
||||
void Register_MergeSummaryOp() {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
TF_OpDefinitionBuilder* op_builder =
|
||||
TF_NewOpDefinitionBuilder("MergeSummary");
|
||||
TF_OpDefinitionBuilderAddInput(op_builder, "inputs: N * string");
|
||||
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
|
||||
TF_OpDefinitionBuilderAddAttr(op_builder, "N: int >= 1");
|
||||
TF_OpDefinitionBuilderSetShapeInferenceFunction(
|
||||
op_builder, &merge_summary_shape_inference_fn);
|
||||
|
||||
TF_RegisterOpDefinition(op_builder, status);
|
||||
CHECK_EQ(TF_GetCode(status), TF_OK)
|
||||
<< "MergeSummary op registration failed: " << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_UNUSED static bool MergeSummaryOpRegistered = []() {
|
||||
if (SHOULD_REGISTER_OP("MergeSummary")) {
|
||||
Register_MergeSummaryOp();
|
||||
}
|
||||
return true;
|
||||
}();
|
@ -670,8 +670,9 @@ tf_gen_op_libs(
|
||||
":lib",
|
||||
":protos_all_cc",
|
||||
# TODO(b/162630222): remove this dependency.
|
||||
"//tensorflow/c/kernels:summary_op_lib",
|
||||
"//tensorflow/c/kernels:histogram_summary_op_lib",
|
||||
"//tensorflow/c/kernels:merge_summary_op_lib",
|
||||
"//tensorflow/c/kernels:summary_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
@ -876,6 +877,7 @@ cc_library(
|
||||
":word2vec_ops",
|
||||
"//tensorflow/c/kernels:bitcast_op_lib",
|
||||
"//tensorflow/c/kernels:histogram_summary_op_lib",
|
||||
"//tensorflow/c/kernels:merge_summary_op_lib",
|
||||
"//tensorflow/c/kernels:summary_op_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op",
|
||||
] + if_chromiumos(
|
||||
@ -984,9 +986,10 @@ cc_library(
|
||||
name = "all_kernels_impl",
|
||||
visibility = [":__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/c/kernels:histogram_summary_op",
|
||||
"//tensorflow/c/kernels:summary_op",
|
||||
"//tensorflow/c/kernels:bitcast_op",
|
||||
"//tensorflow/c/kernels:histogram_summary_op",
|
||||
"//tensorflow/c/kernels:merge_summary_op",
|
||||
"//tensorflow/c/kernels:summary_op",
|
||||
"//tensorflow/core/kernels:array",
|
||||
"//tensorflow/core/kernels:audio",
|
||||
"//tensorflow/core/kernels:batch_kernels",
|
||||
|
@ -3087,7 +3087,6 @@ cc_library(
|
||||
":logging_ops",
|
||||
":summary_audio_op",
|
||||
":summary_image_op",
|
||||
":summary_op",
|
||||
":summary_tensor_op",
|
||||
],
|
||||
)
|
||||
@ -3098,6 +3097,10 @@ LOGGING_DEPS = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
# TODO(b/162630222): remove this dependency.
|
||||
"//tensorflow/c/kernels:histogram_summary_op",
|
||||
"//tensorflow/c/kernels:merge_summary_op",
|
||||
"//tensorflow/c/kernels:summary_op",
|
||||
]
|
||||
|
||||
tf_kernel_library(
|
||||
@ -3118,12 +3121,12 @@ tf_kernel_library(
|
||||
deps = LOGGING_DEPS + ["//tensorflow/core:png_internal"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
# TODO(b/162630222): remove this target
|
||||
cc_library(
|
||||
name = "summary_op",
|
||||
prefix = "summary_op",
|
||||
deps = LOGGING_DEPS + [
|
||||
# TODO(b/162630222): remove these dependencies.
|
||||
deps = [
|
||||
"//tensorflow/c/kernels:histogram_summary_op",
|
||||
"//tensorflow/c/kernels:merge_summary_op",
|
||||
"//tensorflow/c/kernels:summary_op",
|
||||
],
|
||||
)
|
||||
@ -6217,7 +6220,6 @@ filegroup(
|
||||
"string_split_op.cc",
|
||||
"string_to_hash_bucket_op.cc",
|
||||
"substr_op.cc",
|
||||
"summary_op.cc",
|
||||
"tensor_array.cc",
|
||||
"tensor_array_ops.cc",
|
||||
"tensor_list.cc",
|
||||
|
@ -1,76 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
|
||||
// inputs or outputs in various ways.
|
||||
|
||||
// See docs in ../ops/summary_ops.cc.
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/histogram/histogram.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class SummaryMergeOp : public OpKernel {
|
||||
public:
|
||||
explicit SummaryMergeOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* c) override {
|
||||
Summary s;
|
||||
std::unordered_set<string> tags;
|
||||
for (int input_num = 0; input_num < c->num_inputs(); input_num++) {
|
||||
const Tensor& in = c->input(input_num);
|
||||
auto in_vec = in.flat<tstring>();
|
||||
for (int i = 0; i < in_vec.dimension(0); i++) {
|
||||
const string& s_in = in_vec(i);
|
||||
Summary summary_in;
|
||||
if (!ParseProtoUnlimited(&summary_in, s_in)) {
|
||||
c->SetStatus(errors::InvalidArgument(
|
||||
"Could not parse one of the summary inputs"));
|
||||
return;
|
||||
}
|
||||
|
||||
for (int v = 0; v < summary_in.value_size(); v++) {
|
||||
const string& tag = summary_in.value(v).tag();
|
||||
// The tag is unused by the TensorSummary op, so no need to check
|
||||
// for duplicates.
|
||||
if ((!tag.empty()) && !tags.insert(tag).second) {
|
||||
c->SetStatus(errors::InvalidArgument(strings::StrCat(
|
||||
"Duplicate tag ", tag, " found in summary inputs")));
|
||||
return;
|
||||
}
|
||||
*s.add_value() = summary_in.value(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor* summary_tensor = nullptr;
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
|
||||
CHECK(SerializeToTString(s, &summary_tensor->scalar<tstring>()()));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MergeSummary").Device(DEVICE_CPU),
|
||||
SummaryMergeOp);
|
||||
|
||||
} // namespace tensorflow
|
@ -116,12 +116,6 @@ REGISTER_OP("AudioSummary")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Deprecated(15, "Use AudioSummaryV2.");
|
||||
|
||||
REGISTER_OP("MergeSummary")
|
||||
.Input("inputs: N * string")
|
||||
.Output("summary: string")
|
||||
.Attr("N : int >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("Timestamp")
|
||||
.Output("ts: float64")
|
||||
.SetIsStateful()
|
||||
|
@ -2994,6 +2994,7 @@ tf_gen_op_wrapper_private_py(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/kernels:histogram_summary_op_lib",
|
||||
"//tensorflow/c/kernels:merge_summary_op_lib",
|
||||
"//tensorflow/c/kernels:summary_op_lib",
|
||||
"//tensorflow/core:logging_ops_op_lib",
|
||||
],
|
||||
|
Loading…
Reference in New Issue
Block a user