Merge pull request #41834 from dnguyen28061:histogram_summary

PiperOrigin-RevId: 326305643
Change-Id: I74cb683ccc2a040487d95322a366fc32f56d7bd1
This commit is contained in:
TensorFlower Gardener 2020-08-12 14:13:56 -07:00
commit f0030f31d1
8 changed files with 245 additions and 64 deletions

View File

@ -39,6 +39,20 @@ tf_kernel_library(
], ],
) )
tf_kernel_library(
name = "histogram_summary_op",
prefix = "histogram_summary_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
],
)
tf_gen_op_libs( tf_gen_op_libs(
op_lib_names = ["bitcast"], op_lib_names = ["bitcast"],
deps = [ deps = [
@ -59,6 +73,15 @@ tf_gen_op_libs(
], ],
) )
tf_gen_op_libs(
op_lib_names = ["histogram_summary"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
],
)
tf_cc_test( tf_cc_test(
name = "bitcast_op_test", name = "bitcast_op_test",
srcs = ["bitcast_op_test.cc"], srcs = ["bitcast_op_test.cc"],
@ -122,6 +145,7 @@ filegroup(
name = "android_all_op_kernels", name = "android_all_op_kernels",
srcs = [ srcs = [
"bitcast_op.cc", "bitcast_op.cc",
"histogram_summary_op.cc",
"summary_op.cc", "summary_op.cc",
"tensor_shape_utils.cc", "tensor_shape_utils.cc",
"tensor_shape_utils.h", "tensor_shape_utils.h",
@ -133,6 +157,7 @@ filegroup(
name = "android_all_ops", name = "android_all_ops",
srcs = [ srcs = [
"ops/bitcast.cc", "ops/bitcast.cc",
"ops/histogram_summary.cc",
"ops/summary.cc", "ops/summary.cc",
], ],
) )

View File

@ -0,0 +1,163 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
namespace {
// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status.
struct TFTensorDeleter {
void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
};
struct TFStatusDeleter {
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
};
// Struct that wraps TF_Tensor and TF_Status to delete once out of scope.
using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
// Used to pass the operation node name from kernel construction to
// kernel computation.
struct HistogramSummaryOp {
std::string op_node_name;
};
void* HistogramSummaryOp_Create(TF_OpKernelConstruction* ctx) {
HistogramSummaryOp* kernel = new HistogramSummaryOp;
TF_StringView string_view_name = TF_OpKernelConstruction_GetName(ctx);
kernel->op_node_name =
std::string(string_view_name.data, string_view_name.len);
return kernel;
}
void HistogramSummaryOp_Delete(void* kernel) {
delete static_cast<HistogramSummaryOp*>(kernel);
}
template <typename T>
void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
HistogramSummaryOp* k = static_cast<HistogramSummaryOp*>(kernel);
TF_Tensor* tags;
TF_Tensor* values;
Safe_TF_StatusPtr status(TF_NewStatus());
TF_GetInput(ctx, 0, &tags, status.get());
Safe_TF_TensorPtr safe_tags_ptr(tags);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
TF_GetInput(ctx, 1, &values, status.get());
Safe_TF_TensorPtr safe_values_ptr(values);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
if (TF_NumDims(safe_tags_ptr.get()) != 0) {
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, "tags must be scalar");
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
// Cast values to array to access tensor elements by index
auto values_array = static_cast<T*>(TF_TensorData(safe_values_ptr.get()));
tensorflow::histogram::Histogram histo;
for (int64_t i = 0; i < TF_TensorElementCount(safe_values_ptr.get()); ++i) {
const double double_val = static_cast<double>(values_array[i]);
if (Eigen::numext::isnan(double_val)) {
std::ostringstream err;
err << "Nan in summary histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
return;
} else if (Eigen::numext::isinf(double_val)) {
std::ostringstream err;
err << "Infinity in Histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
return;
}
histo.Add(double_val);
}
tensorflow::Summary s;
tensorflow::Summary::Value* v = s.add_value();
const tensorflow::tstring& tag =
*(static_cast<tensorflow::tstring*>(TF_TensorData(safe_tags_ptr.get())));
v->set_tag(tag.data(), tag.size());
histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
/*dims=*/nullptr, /*num_dims=*/0,
/*len=*/sizeof(tensorflow::tstring), status.get()));
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
TF_TensorData(summary_tensor.get()));
CHECK(SerializeToTString(s, output_tstring));
}
template <typename T>
void RegisterHistogramSummaryOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder(
"HistogramSummary", tensorflow::DEVICE_CPU, &HistogramSummaryOp_Create,
&HistogramSummaryOp_Compute<T>, &HistogramSummaryOp_Delete);
TF_KernelBuilder_TypeConstraint(
builder, "T",
static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
TF_RegisterKernelBuilder("HistogramSummary", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering Histogram Summmary kernel";
}
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the Histogram Summary kernel.
TF_ATTRIBUTE_UNUSED static bool IsHistogramSummaryOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("HistogramSummary")) {
RegisterHistogramSummaryOpKernel<tensorflow::int64>();
RegisterHistogramSummaryOpKernel<tensorflow::uint64>();
RegisterHistogramSummaryOpKernel<tensorflow::int32>();
RegisterHistogramSummaryOpKernel<tensorflow::uint32>();
RegisterHistogramSummaryOpKernel<tensorflow::uint16>();
RegisterHistogramSummaryOpKernel<tensorflow::int16>();
RegisterHistogramSummaryOpKernel<tensorflow::int8>();
RegisterHistogramSummaryOpKernel<tensorflow::uint8>();
RegisterHistogramSummaryOpKernel<Eigen::half>();
RegisterHistogramSummaryOpKernel<tensorflow::bfloat16>();
RegisterHistogramSummaryOpKernel<float>();
RegisterHistogramSummaryOpKernel<double>();
}
return true;
}();
} // namespace

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
static void histogram_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
TF_DeleteShapeHandle(result);
}
void Register_HistogramSummaryOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder =
TF_NewOpDefinitionBuilder("HistogramSummary");
TF_OpDefinitionBuilderAddInput(op_builder, "tag: string");
TF_OpDefinitionBuilderAddInput(op_builder, "values: T");
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
TF_OpDefinitionBuilderAddAttr(op_builder, "T: realnumbertype = DT_FLOAT");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
op_builder, &histogram_summary_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "HistogramSummary op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
TF_ATTRIBUTE_UNUSED static bool HistogramSummaryOpRegistered = []() {
if (SHOULD_REGISTER_OP("HistogramSummary")) {
Register_HistogramSummaryOp();
}
return true;
}();

View File

@ -674,6 +674,7 @@ tf_gen_op_libs(
":protos_all_cc", ":protos_all_cc",
# TODO(b/162630222): remove this dependency. # TODO(b/162630222): remove this dependency.
"//tensorflow/c/kernels:summary_op_lib", "//tensorflow/c/kernels:summary_op_lib",
"//tensorflow/c/kernels:histogram_summary_op_lib",
], ],
) )
@ -877,6 +878,7 @@ cc_library(
":user_ops_op_lib", ":user_ops_op_lib",
":word2vec_ops", ":word2vec_ops",
"//tensorflow/c/kernels:bitcast_op_lib", "//tensorflow/c/kernels:bitcast_op_lib",
"//tensorflow/c/kernels:histogram_summary_op_lib",
"//tensorflow/c/kernels:summary_op_lib", "//tensorflow/c/kernels:summary_op_lib",
"//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op", "//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op",
] + if_chromiumos( ] + if_chromiumos(
@ -985,6 +987,7 @@ cc_library(
name = "all_kernels_impl", name = "all_kernels_impl",
visibility = [":__subpackages__"], visibility = [":__subpackages__"],
deps = [ deps = [
"//tensorflow/c/kernels:histogram_summary_op",
"//tensorflow/c/kernels:summary_op", "//tensorflow/c/kernels:summary_op",
"//tensorflow/c/kernels:bitcast_op", "//tensorflow/c/kernels:bitcast_op",
"//tensorflow/core/kernels:array", "//tensorflow/core/kernels:array",

View File

@ -3120,7 +3120,9 @@ tf_kernel_library(
tf_kernel_library( tf_kernel_library(
name = "summary_op", name = "summary_op",
prefix = "summary_op", prefix = "summary_op",
deps = LOGGING_DEPS, deps = LOGGING_DEPS + [
"//tensorflow/c/kernels:histogram_summary_op",
],
) )
tf_kernel_library( tf_kernel_library(

View File

@ -31,62 +31,6 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
template <typename T>
class SummaryHistoOp : public OpKernel {
public:
// SummaryHistoOp could be extended to take a list of custom bucket
// boundaries as an option.
explicit SummaryHistoOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* c) override {
const Tensor& tags = c->input(0);
const Tensor& values = c->input(1);
const auto flat = values.flat<T>();
OP_REQUIRES(c, TensorShapeUtils::IsScalar(tags.shape()),
errors::InvalidArgument("tags must be scalar"));
// Build histogram of values in "values" tensor
histogram::Histogram histo;
for (int64 i = 0; i < flat.size(); i++) {
const double double_val = static_cast<double>(flat(i));
if (Eigen::numext::isnan(double_val)) {
c->SetStatus(
errors::InvalidArgument("Nan in summary histogram for: ", name()));
break;
} else if (Eigen::numext::isinf(double_val)) {
c->SetStatus(errors::InvalidArgument(
"Infinity in summary histogram for: ", name()));
break;
}
histo.Add(double_val);
}
Summary s;
Summary::Value* v = s.add_value();
const tstring& tags0 = tags.scalar<tstring>()();
v->set_tag(tags0.data(), tags0.size());
histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
Tensor* summary_tensor = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
CHECK(SerializeToTString(s, &summary_tensor->scalar<tstring>()()));
}
};
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER( \
Name("HistogramSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
SummaryHistoOp<T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER)
#undef REGISTER
struct HistogramResource : public ResourceBase {
histogram::ThreadSafeHistogram histogram;
string DebugString() const override {
return "A histogram summary. Stats ...";
}
};
class SummaryMergeOp : public OpKernel { class SummaryMergeOp : public OpKernel {
public: public:
explicit SummaryMergeOp(OpKernelConstruction* context) : OpKernel(context) {} explicit SummaryMergeOp(OpKernelConstruction* context) : OpKernel(context) {}

View File

@ -87,13 +87,6 @@ REGISTER_OP("TensorSummary")
.Attr("display_name: string = ''") .Attr("display_name: string = ''")
.SetShapeFn(shape_inference::ScalarShape); .SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("HistogramSummary")
.Input("tag: string")
.Input("values: T")
.Output("summary: string")
.Attr("T: realnumbertype = DT_FLOAT")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ImageSummary") REGISTER_OP("ImageSummary")
.Input("tag: string") .Input("tag: string")
.Input("tensor: T") .Input("tensor: T")

View File

@ -2993,6 +2993,7 @@ tf_gen_op_wrapper_private_py(
"//tensorflow/python/kernel_tests:__pkg__", "//tensorflow/python/kernel_tests:__pkg__",
], ],
deps = [ deps = [
"//tensorflow/c/kernels:histogram_summary_op_lib",
"//tensorflow/c/kernels:summary_op_lib", "//tensorflow/c/kernels:summary_op_lib",
"//tensorflow/core:logging_ops_op_lib", "//tensorflow/core:logging_ops_op_lib",
], ],