Merge pull request #41209 from dnguyen28061:summary_op

PiperOrigin-RevId: 324685033
Change-Id: I6364d4545366ddb37e238b76d0fd74ff5a29c6df
This commit is contained in:
TensorFlower Gardener 2020-08-03 14:44:58 -07:00
commit 0d85fa03ef
13 changed files with 630 additions and 158 deletions

View File

@ -24,6 +24,21 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "summary_op",
prefix = "summary_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/c/kernels:tensor_shape_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
],
)
tf_gen_op_libs(
op_lib_names = ["bitcast"],
deps = [
@ -35,6 +50,15 @@ tf_gen_op_libs(
],
)
tf_gen_op_libs(
op_lib_names = ["summary"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "bitcast_op_test",
srcs = ["bitcast_op_test.cc"],
@ -48,6 +72,45 @@ tf_cc_test(
],
)
tf_cc_test(
name = "summary_op_test",
srcs = ["summary_op_test.cc"],
deps = [
":summary_op",
"//tensorflow/c:kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "tensor_shape_utils",
srcs = ["tensor_shape_utils.cc"],
hdrs = ["tensor_shape_utils.h"],
visibility = ["//visibility:private"],
deps = [
"//tensorflow/c:tf_tensor",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "tensor_shape_utils_test",
srcs = ["tensor_shape_utils_test.cc"],
deps = [
":tensor_shape_utils",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Changes to the Android srcs here should be replicated in
# tensorflow/contrib/makefile/tf_op_files.txt.
#
@ -59,11 +122,17 @@ filegroup(
name = "android_all_op_kernels",
srcs = [
"bitcast_op.cc",
"summary_op.cc",
"tensor_shape_utils.cc",
"tensor_shape_utils.h",
],
)
# LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt)
filegroup(
name = "android_all_ops",
srcs = ["ops/bitcast.cc"],
srcs = [
"ops/bitcast.cc",
"ops/summary.cc",
],
)

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
static void scalar_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
TF_DeleteShapeHandle(result);
}
void Register_ScalarSummaryOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder =
TF_NewOpDefinitionBuilder("ScalarSummary");
TF_OpDefinitionBuilderAddInput(op_builder, "tags: string");
TF_OpDefinitionBuilderAddInput(op_builder, "values: T");
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
TF_OpDefinitionBuilderAddAttr(op_builder, "T: realnumbertype");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
op_builder, &scalar_summary_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "ScalarSummary op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
TF_ATTRIBUTE_UNUSED static bool SummaryScalarOpRegistered = []() {
if (SHOULD_REGISTER_OP("ScalarSummary")) {
Register_ScalarSummaryOp();
}
return true;
}();

View File

@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
namespace {
// Struct that stores the status and TF_Tensor inputs to the opkernel.
// Used to delete tensor and status in its destructor upon kernel return.
struct Params {
TF_Tensor* tags;
TF_Tensor* values;
TF_Status* status;
explicit Params(TF_OpKernelContext* ctx)
: tags(nullptr), values(nullptr), status(nullptr) {
status = TF_NewStatus();
TF_GetInput(ctx, 0, &tags, status);
if (TF_GetCode(status) == TF_OK) {
TF_GetInput(ctx, 1, &values, status);
}
}
~Params() {
TF_DeleteStatus(status);
TF_DeleteTensor(tags);
TF_DeleteTensor(values);
}
};
// dummy functions used for kernel registration
void* ScalarSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
void ScalarSummaryOp_Delete(void* kernel) {}
// Helper functions for compute method
bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2);
// Returns a string representation of a single tag or empty string if there
// are multiple tags
std::string SingleTag(TF_Tensor* tags);
template <typename T>
void ScalarSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
Params params(ctx);
if (TF_GetCode(params.status) != TF_OK) {
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
if (!IsSameSize(params.tags, params.values)) {
std::ostringstream err;
err << "tags and values are not the same shape: "
<< tensorflow::ShapeDebugString(params.tags)
<< " != " << tensorflow::ShapeDebugString(params.values)
<< SingleTag(params.tags);
TF_SetStatus(params.status, TF_INVALID_ARGUMENT, err.str().c_str());
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
// Convert tags and values tensor to array to access elements by index
tensorflow::Summary s;
auto tags_array =
static_cast<tensorflow::tstring*>(TF_TensorData(params.tags));
auto values_array = static_cast<T*>(TF_TensorData(params.values));
// Copy tags and values into summary protobuf
for (int i = 0; i < TF_TensorElementCount(params.tags); ++i) {
tensorflow::Summary::Value* v = s.add_value();
const tensorflow::tstring& Ttags_i = tags_array[i];
v->set_tag(Ttags_i.data(), Ttags_i.size());
v->set_simple_value(static_cast<float>(values_array[i]));
}
TF_Tensor* summary_tensor =
TF_AllocateOutput(ctx, 0, TF_ExpectedOutputDataType(ctx, 0), nullptr, 0,
sizeof(tensorflow::tstring), params.status);
if (TF_GetCode(params.status) != TF_OK) {
TF_DeleteTensor(summary_tensor);
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
tensorflow::tstring* output_tstring =
reinterpret_cast<tensorflow::tstring*>(TF_TensorData(summary_tensor));
CHECK(SerializeToTString(s, output_tstring));
TF_DeleteTensor(summary_tensor);
}
bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2) {
if (TF_NumDims(tensor1) != TF_NumDims(tensor2)) {
return false;
}
for (int d = 0; d < TF_NumDims(tensor1); d++) {
if (TF_Dim(tensor1, d) != TF_Dim(tensor2, d)) {
return false;
}
}
return true;
}
std::string SingleTag(TF_Tensor* tags) {
if (TF_TensorElementCount(tags) == 1) {
const char* single_tag =
static_cast<tensorflow::tstring*>(TF_TensorData(tags))->c_str();
return tensorflow::strings::StrCat(" (tag '", single_tag, "')");
} else {
return "";
}
}
template <typename T>
void RegisterScalarSummaryOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder(
"ScalarSummary", tensorflow::DEVICE_CPU, &ScalarSummaryOp_Create,
&ScalarSummaryOp_Compute<T>, &ScalarSummaryOp_Delete);
TF_KernelBuilder_TypeConstraint(
builder, "T",
static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
TF_RegisterKernelBuilder("ScalarSummary", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering Scalar Summmary kernel";
}
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the ScalarSummary kernel.
TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary")) {
RegisterScalarSummaryOpKernel<tensorflow::int64>();
RegisterScalarSummaryOpKernel<tensorflow::uint64>();
RegisterScalarSummaryOpKernel<tensorflow::int32>();
RegisterScalarSummaryOpKernel<tensorflow::uint32>();
RegisterScalarSummaryOpKernel<tensorflow::uint16>();
RegisterScalarSummaryOpKernel<tensorflow::int16>();
RegisterScalarSummaryOpKernel<tensorflow::int8>();
RegisterScalarSummaryOpKernel<tensorflow::uint8>();
RegisterScalarSummaryOpKernel<Eigen::half>();
RegisterScalarSummaryOpKernel<tensorflow::bfloat16>();
RegisterScalarSummaryOpKernel<float>();
RegisterScalarSummaryOpKernel<double>();
}
return true;
}();
} // namespace

View File

@ -0,0 +1,186 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
namespace {
class DummyDevice : public DeviceBase {
public:
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
};
// Helper for comparing ouput and expected output
void ExpectSummaryMatches(const Summary& actual, const string& expected_str) {
Summary expected;
ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected));
EXPECT_EQ(expected.DebugString(), actual.DebugString());
}
void TestScalarSummaryOp(Tensor* tags, Tensor* values, string expected_output,
error::Code expected_code) {
// Initialize node used to fetch OpKernel
Status status;
NodeDef def;
def.set_op("ScalarSummary");
def.set_device(DEVICE_CPU);
AttrValue valuesTypeAttr;
SetAttrValue(values->dtype(), &valuesTypeAttr);
(*def.mutable_attr())["T"] = valuesTypeAttr;
def.add_input(strings::StrCat("input1: ", DataTypeString(tags->dtype())));
def.add_input(strings::StrCat("input2: ", DataTypeString(values->dtype())));
std::unique_ptr<OpKernel> kernel =
CreateOpKernel(DeviceType(DEVICE_CPU), nullptr, nullptr, def, 1, &status);
ASSERT_TRUE(status.ok()) << status.ToString();
OpKernelContext::Params params;
DummyDevice dummy_device(nullptr);
params.device = &dummy_device;
params.op_kernel = kernel.get();
AllocatorAttributes alloc_attrs;
params.output_attr_array = &alloc_attrs;
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.emplace_back(tags);
inputs.emplace_back(values);
params.inputs = &inputs;
OpKernelContext ctx(&params, 1);
kernel->Compute(&ctx);
ASSERT_EQ(expected_code, ctx.status().code());
if (expected_code == error::OK) {
Summary summary;
ASSERT_TRUE(ParseProtoUnlimited(
&summary, ctx.mutable_output(0)->scalar<tstring>()()));
ExpectSummaryMatches(summary, expected_output);
} else {
EXPECT_TRUE(absl::StrContains(ctx.status().ToString(), expected_output))
<< ctx.status();
}
}
TEST(ScalarSummaryOpTest, SimpleFloat) {
int vectorSize = 3;
Tensor tags(DT_STRING, {vectorSize});
Tensor values(DT_FLOAT, {vectorSize});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
tags.vec<tstring>()(2) = "tag3";
values.vec<float>()(0) = 1.0f;
values.vec<float>()(1) = -0.73f;
values.vec<float>()(2) = 10000.0f;
TestScalarSummaryOp(&tags, &values, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -0.73}
value { tag: 'tag3' simple_value: 10000.0})",
error::OK);
}
TEST(ScalarSummaryOpTest, SimpleDouble) {
int vectorSize = 3;
Tensor tags(DT_STRING, {vectorSize});
Tensor values(DT_DOUBLE, {vectorSize});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
tags.vec<tstring>()(2) = "tag3";
values.vec<double>()(0) = 1.0;
values.vec<double>()(1) = -0.73;
values.vec<double>()(2) = 10000.0;
TestScalarSummaryOp(&tags, &values, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -0.73}
value { tag: 'tag3' simple_value: 10000.0})",
error::OK);
}
TEST(ScalarSummaryOpTest, SimpleHalf) {
int vectorSize = 3;
Tensor tags(DT_STRING, {vectorSize});
Tensor values(DT_HALF, {vectorSize});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
tags.vec<tstring>()(2) = "tag3";
values.vec<Eigen::half>()(0) = Eigen::half(1.0);
values.vec<Eigen::half>()(1) = Eigen::half(-2.0);
values.vec<Eigen::half>()(2) = Eigen::half(10000.0);
TestScalarSummaryOp(&tags, &values, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -2.0}
value { tag: 'tag3' simple_value: 10000.0})",
error::OK);
}
TEST(ScalarSummaryOpTest, Error_WrongDimsTags) {
Tensor tags(DT_STRING, {2, 1});
Tensor values(DT_FLOAT, {2});
tags.matrix<tstring>()(0, 0) = "tag1";
tags.matrix<tstring>()(1, 0) = "tag2";
values.vec<float>()(0) = 1.0f;
values.vec<float>()(1) = -2.0f;
TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape",
error::INVALID_ARGUMENT);
}
TEST(ScalarSummaryOpTest, Error_WrongValuesTags) {
Tensor tags(DT_STRING, {2});
Tensor values(DT_FLOAT, {2, 1});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
values.matrix<float>()(0, 0) = 1.0f;
values.matrix<float>()(1, 0) = -2.0f;
TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape",
error::INVALID_ARGUMENT);
}
TEST(ScalarSummaryOpTest, Error_WrongWithSingleTag) {
Tensor tags(DT_STRING, {1});
Tensor values(DT_FLOAT, {2, 1});
tags.vec<tstring>()(0) = "tag1";
values.matrix<float>()(0, 0) = 1.0f;
values.matrix<float>()(1, 0) = -2.0f;
TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape",
error::INVALID_ARGUMENT);
}
TEST(ScalarSummaryOpTest, IsRegistered) {
const OpRegistrationData* reg;
TF_CHECK_OK(OpRegistry::Global()->LookUp("ScalarSummary", &reg));
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include <string>
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
namespace tensorflow {
std::string ShapeDebugString(TF_Tensor* tensor) {
// A TF_Tensor cannot have an unknown rank.
CHECK_GE(TF_NumDims(tensor), 0);
tensorflow::string s = "[";
for (int i = 0; i < TF_NumDims(tensor); ++i) {
if (i > 0) tensorflow::strings::StrAppend(&s, ",");
int64_t dim = TF_Dim(tensor, i);
// A TF_Tensor cannot have an unknown dimension.
CHECK_GE(dim, 0);
tensorflow::strings::StrAppend(&s, dim);
}
tensorflow::strings::StrAppend(&s, "]");
return s;
}
} // namespace tensorflow

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains shape utilities to be used by kernels and is not part of
// the C API. As such, it is subject to change at any time.
#ifndef TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_
#define TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_
#include <string>
#include "tensorflow/c/tf_tensor.h"
namespace tensorflow {
// The following are utils for the shape of a TF_Tensor type.
// These functions may later be subsumed by the methods for a
// TF_TensorShape type.
// Returns a string representation of the TF_Tensor shape.
std::string ShapeDebugString(TF_Tensor* tensor);
} // namespace tensorflow
#endif // TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// A wrapper that will automatically delete the allocated TF_Tensor
// once out of scope.
struct TF_TensorWrapper {
TF_Tensor* tf_tensor;
explicit TF_TensorWrapper(TF_Tensor* tensor) { tf_tensor = tensor; }
~TF_TensorWrapper() { TF_DeleteTensor(tf_tensor); }
};
void TestShapeMatch(TensorShape shape) {
Tensor tensor(DT_FLOAT, shape);
Status status;
TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, &status);
TF_TensorWrapper tensor_wrapper = TF_TensorWrapper(tf_tensor);
ASSERT_TRUE(status.ok()) << status.ToString();
ASSERT_EQ(tensor.shape().DebugString(), ShapeDebugString(tf_tensor));
}
TEST(ShapeDebugString, RegularShape) { TestShapeMatch(TensorShape({5, 4, 7})); }
TEST(ShapeDebugString, ScalarShape) { TestShapeMatch(TensorShape({})); }
} // namespace
} // namespace tensorflow

View File

@ -630,7 +630,6 @@ tf_gen_op_libs(
"linalg_ops",
"list_ops",
"lookup_ops",
"logging_ops",
"manip_ops",
"math_ops",
"mkl_nn_ops",
@ -664,6 +663,19 @@ tf_gen_op_libs(
],
)
tf_gen_op_libs(
is_external = False,
op_lib_names = [
"logging_ops",
],
deps = [
":lib",
":protos_all_cc",
# TODO(b/162630222): remove this dependency.
"//tensorflow/c/kernels:summary_op_lib",
],
)
tf_gen_op_libs(
op_lib_names = [
"string_ops",
@ -863,6 +875,7 @@ cc_library(
":user_ops_op_lib",
":word2vec_ops",
"//tensorflow/c/kernels:bitcast_op_lib",
"//tensorflow/c/kernels:summary_op_lib",
"//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op",
] + if_chromiumos(
[],
@ -970,6 +983,7 @@ cc_library(
name = "all_kernels_impl",
visibility = [":__subpackages__"],
deps = [
"//tensorflow/c/kernels:summary_op",
"//tensorflow/c/kernels:bitcast_op",
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:audio",

View File

@ -3858,6 +3858,8 @@ LOGGING_DEPS = [
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
# TODO(b/162630222): remove this dependency.
"//tensorflow/c/kernels:summary_op",
]
tf_kernel_library(

View File

@ -31,47 +31,6 @@ limitations under the License.
namespace tensorflow {
template <typename T>
class SummaryScalarOp : public OpKernel {
public:
explicit SummaryScalarOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* c) override {
const Tensor& tags = c->input(0);
const Tensor& values = c->input(1);
OP_REQUIRES(
c,
tags.IsSameSize(values) || (TensorShapeUtils::IsScalar(tags.shape()) &&
TensorShapeUtils::IsScalar(values.shape())),
errors::InvalidArgument(
"tags and values not the same shape: ", tags.shape().DebugString(),
" != ", values.shape().DebugString(), SingleTag(tags)));
auto Ttags = tags.flat<tstring>();
auto Tvalues = values.flat<T>();
Summary s;
for (int i = 0; i < Ttags.size(); i++) {
Summary::Value* v = s.add_value();
const tstring& Ttags_i = Ttags(i);
v->set_tag(Ttags_i.data(), Ttags_i.size());
v->set_simple_value(float(Tvalues(i)));
}
Tensor* summary_tensor = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
CHECK(SerializeToTString(s, &summary_tensor->scalar<tstring>()()));
}
// If there's only one tag, include it in the error message
static string SingleTag(const Tensor& tags) {
if (tags.NumElements() == 1) {
return strings::StrCat(" (tag '", tags.flat<tstring>()(0), "')");
} else {
return "";
}
}
};
template <typename T>
class SummaryHistoOp : public OpKernel {
public:
@ -114,9 +73,6 @@ class SummaryHistoOp : public OpKernel {
};
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER( \
Name("ScalarSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
SummaryScalarOp<T>); \
REGISTER_KERNEL_BUILDER( \
Name("HistogramSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
SummaryHistoOp<T>);

View File

@ -45,111 +45,6 @@ static void EXPECT_SummaryMatches(const Summary& actual,
EXPECT_EQ(expected.DebugString(), actual.DebugString());
}
class SummaryScalarOpTest : public OpsTestBase {
protected:
void MakeOp(DataType dt) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScalarSummary")
.Input(FakeInput())
.Input(FakeInput(dt))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};
TEST_F(SummaryScalarOpTest, SimpleFloat) {
MakeOp(DT_FLOAT);
// Feed and run
AddInputFromArray<tstring>(TensorShape({3}), {"tag1", "tag2", "tag3"});
AddInputFromArray<float>(TensorShape({3}), {1.0f, -0.73f, 10000.0f});
TF_ASSERT_OK(RunOpKernel());
// Check the output size.
Tensor* out_tensor = GetOutput(0);
ASSERT_EQ(0, out_tensor->dims());
Summary summary;
ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
EXPECT_SummaryMatches(summary, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -0.73 }
value { tag: 'tag3' simple_value: 10000.0 }
)");
}
TEST_F(SummaryScalarOpTest, SimpleDouble) {
MakeOp(DT_DOUBLE);
// Feed and run
AddInputFromArray<tstring>(TensorShape({3}), {"tag1", "tag2", "tag3"});
AddInputFromArray<double>(TensorShape({3}), {1.0, -0.73, 10000.0});
TF_ASSERT_OK(RunOpKernel());
// Check the output size.
Tensor* out_tensor = GetOutput(0);
ASSERT_EQ(0, out_tensor->dims());
Summary summary;
ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
EXPECT_SummaryMatches(summary, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -0.73 }
value { tag: 'tag3' simple_value: 10000.0 }
)");
}
TEST_F(SummaryScalarOpTest, SimpleHalf) {
MakeOp(DT_HALF);
// Feed and run
AddInputFromList<tstring>(TensorShape({3}), {"tag1", "tag2", "tag3"});
AddInputFromList<Eigen::half>(TensorShape({3}), {1.0, -2.0, 10000.0});
TF_ASSERT_OK(RunOpKernel());
// Check the output size.
Tensor* out_tensor = GetOutput(0);
ASSERT_EQ(0, out_tensor->dims());
Summary summary;
ParseProtoUnlimited(&summary, out_tensor->scalar<tstring>()());
EXPECT_SummaryMatches(summary, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -2.0 }
value { tag: 'tag3' simple_value: 10000.0 }
)");
}
TEST_F(SummaryScalarOpTest, Error_MismatchedSize) {
MakeOp(DT_FLOAT);
// Feed and run
AddInputFromArray<tstring>(TensorShape({2}), {"tag1", "tag2"});
AddInputFromArray<float>(TensorShape({3}), {1.0f, -0.73f, 10000.0f});
Status s = RunOpKernel();
EXPECT_TRUE(absl::StrContains(s.ToString(), "not the same shape")) << s;
}
TEST_F(SummaryScalarOpTest, Error_WrongDimsTags) {
MakeOp(DT_FLOAT);
// Feed and run
AddInputFromArray<tstring>(TensorShape({2, 1}), {"tag1", "tag2"});
AddInputFromArray<float>(TensorShape({2}), {1.0f, -0.73f});
Status s = RunOpKernel();
EXPECT_TRUE(
absl::StrContains(s.ToString(), "tags and values not the same shape"))
<< s;
}
TEST_F(SummaryScalarOpTest, Error_WrongDimsValues) {
MakeOp(DT_FLOAT);
// Feed and run
AddInputFromArray<tstring>(TensorShape({2}), {"tag1", "tag2"});
AddInputFromArray<float>(TensorShape({2, 1}), {1.0f, -0.73f});
Status s = RunOpKernel();
EXPECT_TRUE(
absl::StrContains(s.ToString(), "tags and values not the same shape"))
<< s;
}
// --------------------------------------------------------------------------
// SummaryHistoOp
// --------------------------------------------------------------------------

View File

@ -87,13 +87,6 @@ REGISTER_OP("TensorSummary")
.Attr("display_name: string = ''")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ScalarSummary")
.Input("tags: string")
.Input("values: T")
.Output("summary: string")
.Attr("T: realnumbertype")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("HistogramSummary")
.Input("tag: string")
.Input("values: T")

View File

@ -2903,6 +2903,10 @@ tf_gen_op_wrapper_private_py(
"//learning/brain/python/ops:__pkg__",
"//tensorflow/python/kernel_tests:__pkg__",
],
deps = [
"//tensorflow/c/kernels:summary_op_lib",
"//tensorflow/core:logging_ops_op_lib",
],
)
tf_gen_op_wrapper_private_py(