Merge pull request from dnguyen28061:merge_summary

PiperOrigin-RevId: 328175218
Change-Id: I3e4aca437f68bc748595eff537799b3a809c5c24
This commit is contained in:
TensorFlower Gardener 2020-08-24 11:20:27 -07:00
commit b4b4161746
8 changed files with 213 additions and 91 deletions

View File

@ -53,6 +53,19 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "merge_summary_op",
prefix = "merge_summary_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
tf_gen_op_libs(
op_lib_names = ["bitcast"],
deps = [
@ -82,6 +95,15 @@ tf_gen_op_libs(
],
)
tf_gen_op_libs(
op_lib_names = ["merge_summary"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "bitcast_op_test",
srcs = ["bitcast_op_test.cc"],
@ -146,6 +168,7 @@ filegroup(
srcs = [
"bitcast_op.cc",
"histogram_summary_op.cc",
"merge_summary_op.cc",
"summary_op.cc",
"tensor_shape_utils.cc",
"tensor_shape_utils.h",
@ -158,6 +181,7 @@ filegroup(
srcs = [
"ops/bitcast.cc",
"ops/histogram_summary.cc",
"ops/merge_summary.cc",
"ops/summary.cc",
],
)

View File

@ -0,0 +1,123 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <sstream>
#include <unordered_set>
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
namespace {
// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status
struct TFTensorDeleter {
void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
};
struct TFStatusDeleter {
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
};
// Struct that wraps TF_Tensor and TF_Status to delete once out of scope
using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
// dummy functions used for kernel registration
void* MergeSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
void MergeSummaryOp_Delete(void* kernel) {}
void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
tensorflow::Summary s;
std::unordered_set<tensorflow::string> tags;
Safe_TF_StatusPtr status(TF_NewStatus());
for (int input_num = 0; input_num < TF_NumInputs(ctx); ++input_num) {
TF_Tensor* input;
TF_GetInput(ctx, input_num, &input, status.get());
Safe_TF_TensorPtr safe_input_ptr(input);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
auto tags_array =
static_cast<tensorflow::tstring*>(TF_TensorData(safe_input_ptr.get()));
for (int i = 0; i < TF_TensorElementCount(safe_input_ptr.get()); ++i) {
const tensorflow::tstring& s_in = tags_array[i];
tensorflow::Summary summary_in;
if (!tensorflow::ParseProtoUnlimited(&summary_in, s_in)) {
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT,
"Could not parse one of the summary inputs");
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
for (int v = 0; v < summary_in.value_size(); ++v) {
// This tag is unused by the TensorSummary op, so no need to check for
// duplicates.
const tensorflow::string& tag = summary_in.value(v).tag();
if ((!tag.empty()) && !tags.insert(tag).second) {
std::ostringstream err;
err << "Duplicate tag " << tag << " found in summary inputs ";
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
*s.add_value() = summary_in.value(v);
}
}
}
Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
/*dims=*/nullptr, /*num_dims=*/0,
/*len=*/sizeof(tensorflow::tstring), status.get()));
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
TF_TensorData(summary_tensor.get()));
CHECK(SerializeToTString(s, output_tstring));
}
void RegisterMergeSummaryOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder(
"MergeSummary", tensorflow::DEVICE_CPU, &MergeSummaryOp_Create,
&MergeSummaryOp_Compute, &MergeSummaryOp_Delete);
TF_RegisterKernelBuilder("MergeSummary", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering Merge Summmary kernel";
}
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the Histogram Summary kernel.
TF_ATTRIBUTE_UNUSED static bool IsMergeSummaryOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("MergeSummary")) {
RegisterMergeSummaryOpKernel();
}
return true;
}();
} // namespace

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
static void merge_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
TF_DeleteShapeHandle(result);
}
void Register_MergeSummaryOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder =
TF_NewOpDefinitionBuilder("MergeSummary");
TF_OpDefinitionBuilderAddInput(op_builder, "inputs: N * string");
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
TF_OpDefinitionBuilderAddAttr(op_builder, "N: int >= 1");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
op_builder, &merge_summary_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "MergeSummary op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
TF_ATTRIBUTE_UNUSED static bool MergeSummaryOpRegistered = []() {
if (SHOULD_REGISTER_OP("MergeSummary")) {
Register_MergeSummaryOp();
}
return true;
}();

View File

@ -670,8 +670,9 @@ tf_gen_op_libs(
":lib",
":protos_all_cc",
# TODO(b/162630222): remove this dependency.
"//tensorflow/c/kernels:summary_op_lib",
"//tensorflow/c/kernels:histogram_summary_op_lib",
"//tensorflow/c/kernels:merge_summary_op_lib",
"//tensorflow/c/kernels:summary_op_lib",
],
)
@ -876,6 +877,7 @@ cc_library(
":word2vec_ops",
"//tensorflow/c/kernels:bitcast_op_lib",
"//tensorflow/c/kernels:histogram_summary_op_lib",
"//tensorflow/c/kernels:merge_summary_op_lib",
"//tensorflow/c/kernels:summary_op_lib",
"//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op",
] + if_chromiumos(
@ -984,9 +986,10 @@ cc_library(
name = "all_kernels_impl",
visibility = [":__subpackages__"],
deps = [
"//tensorflow/c/kernels:histogram_summary_op",
"//tensorflow/c/kernels:summary_op",
"//tensorflow/c/kernels:bitcast_op",
"//tensorflow/c/kernels:histogram_summary_op",
"//tensorflow/c/kernels:merge_summary_op",
"//tensorflow/c/kernels:summary_op",
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:audio",
"//tensorflow/core/kernels:batch_kernels",

View File

@ -3087,7 +3087,6 @@ cc_library(
":logging_ops",
":summary_audio_op",
":summary_image_op",
":summary_op",
":summary_tensor_op",
],
)
@ -3098,6 +3097,10 @@ LOGGING_DEPS = [
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
# TODO(b/162630222): remove this dependency.
"//tensorflow/c/kernels:histogram_summary_op",
"//tensorflow/c/kernels:merge_summary_op",
"//tensorflow/c/kernels:summary_op",
]
tf_kernel_library(
@ -3118,12 +3121,12 @@ tf_kernel_library(
deps = LOGGING_DEPS + ["//tensorflow/core:png_internal"],
)
tf_kernel_library(
# TODO(b/162630222): remove this target
cc_library(
name = "summary_op",
prefix = "summary_op",
deps = LOGGING_DEPS + [
# TODO(b/162630222): remove these dependencies.
deps = [
"//tensorflow/c/kernels:histogram_summary_op",
"//tensorflow/c/kernels:merge_summary_op",
"//tensorflow/c/kernels:summary_op",
],
)
@ -6217,7 +6220,6 @@ filegroup(
"string_split_op.cc",
"string_to_hash_bucket_op.cc",
"substr_op.cc",
"summary_op.cc",
"tensor_array.cc",
"tensor_array_ops.cc",
"tensor_list.cc",

View File

@ -1,76 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
// inputs or outputs in various ways.
// See docs in ../ops/summary_ops.cc.
#include <unordered_set>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
class SummaryMergeOp : public OpKernel {
public:
explicit SummaryMergeOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* c) override {
Summary s;
std::unordered_set<string> tags;
for (int input_num = 0; input_num < c->num_inputs(); input_num++) {
const Tensor& in = c->input(input_num);
auto in_vec = in.flat<tstring>();
for (int i = 0; i < in_vec.dimension(0); i++) {
const string& s_in = in_vec(i);
Summary summary_in;
if (!ParseProtoUnlimited(&summary_in, s_in)) {
c->SetStatus(errors::InvalidArgument(
"Could not parse one of the summary inputs"));
return;
}
for (int v = 0; v < summary_in.value_size(); v++) {
const string& tag = summary_in.value(v).tag();
// The tag is unused by the TensorSummary op, so no need to check
// for duplicates.
if ((!tag.empty()) && !tags.insert(tag).second) {
c->SetStatus(errors::InvalidArgument(strings::StrCat(
"Duplicate tag ", tag, " found in summary inputs")));
return;
}
*s.add_value() = summary_in.value(v);
}
}
}
Tensor* summary_tensor = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
CHECK(SerializeToTString(s, &summary_tensor->scalar<tstring>()()));
}
};
REGISTER_KERNEL_BUILDER(Name("MergeSummary").Device(DEVICE_CPU),
SummaryMergeOp);
} // namespace tensorflow

View File

@ -116,12 +116,6 @@ REGISTER_OP("AudioSummary")
.SetShapeFn(shape_inference::ScalarShape)
.Deprecated(15, "Use AudioSummaryV2.");
REGISTER_OP("MergeSummary")
.Input("inputs: N * string")
.Output("summary: string")
.Attr("N : int >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("Timestamp")
.Output("ts: float64")
.SetIsStateful()

View File

@ -2994,6 +2994,7 @@ tf_gen_op_wrapper_private_py(
],
deps = [
"//tensorflow/c/kernels:histogram_summary_op_lib",
"//tensorflow/c/kernels:merge_summary_op_lib",
"//tensorflow/c/kernels:summary_op_lib",
"//tensorflow/core:logging_ops_op_lib",
],