Split tf_tensor_internal into its own header

This is one of a series of changes that allows users of TF_Tensor and
other C types to avoid depending on everything in TensorFlow.

PiperOrigin-RevId: 253337240
This commit is contained in:
James Ring 2019-06-14 19:20:22 -07:00 committed by TensorFlower Gardener
parent 09c5e94be3
commit 652d3e7bc6
9 changed files with 146 additions and 53 deletions

View File

@ -77,7 +77,10 @@ tf_cuda_library(
"//tensorflow/core:op_gen_lib", "//tensorflow/core:op_gen_lib",
"//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:server_lib",
], ],
}) + [":tf_status_internal"], }) + [
":tf_status_internal",
":tf_tensor_internal",
],
) )
cc_library( cc_library(
@ -211,9 +214,10 @@ cc_library(
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:android_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
":c_api_internal",
":tf_datatype", ":tf_datatype",
":tf_status", ":tf_status",
":tf_status_helper",
":tf_tensor_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
@ -221,6 +225,24 @@ cc_library(
}), }),
) )
tf_cuda_library(
name = "tf_tensor_internal",
hdrs = [
"tf_tensor.h",
"tf_tensor_internal.h",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tf_datatype",
":tf_status",
"//tensorflow/core:framework",
],
}),
)
tf_cuda_library( tf_cuda_library(
name = "c_api_experimental", name = "c_api_experimental",
srcs = [ srcs = [
@ -263,7 +285,7 @@ tf_cuda_library(
hdrs = ["tf_status_helper.h"], hdrs = ["tf_status_helper.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":c_api_no_xla", ":tf_status",
":tf_status_internal", ":tf_status_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
@ -329,17 +351,16 @@ tf_cuda_library(
], ],
copts = tf_copts(), copts = tf_copts(),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = select({ deps = [
":tf_status",
":tf_status_helper",
] + select({
"//tensorflow:android": [ "//tensorflow:android": [
":c_api_no_xla",
":c_api_internal", ":c_api_internal",
":tf_status_helper",
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:android_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
":c_api_no_xla",
":c_api_internal", ":c_api_internal",
":tf_status_helper",
":tf_tensor", ":tf_tensor",
"//tensorflow/core:framework", "//tensorflow/core:framework",
], ],
@ -357,6 +378,8 @@ tf_cuda_library(
copts = tf_copts(), copts = tf_copts(),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":tf_datatype",
":tf_status",
":tf_status_helper", ":tf_status_helper",
] + select({ ] + select({
"//tensorflow:android": [ "//tensorflow:android": [
@ -365,7 +388,7 @@ tf_cuda_library(
"//conditions:default": [ "//conditions:default": [
"//tensorflow/core:framework", "//tensorflow/core:framework",
], ],
}) + [":c_api_internal"], }),
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@ -30,6 +30,7 @@ limitations under the License.
// clang-format on // clang-format on
#include "tensorflow/c/tf_status_internal.h" #include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/framework/op_gen_lib.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
@ -53,14 +54,6 @@ class ServerInterface;
// Internal structures used by the C API. These are likely to change and should // Internal structures used by the C API. These are likely to change and should
// not be depended on. // not be depended on.
struct TF_Tensor {
~TF_Tensor();
TF_DataType dtype;
tensorflow::TensorShape shape;
tensorflow::TensorBuffer* buffer;
};
struct TF_SessionOptions { struct TF_SessionOptions {
tensorflow::SessionOptions options; tensorflow::SessionOptions options;
}; };
@ -193,15 +186,6 @@ struct TF_Server {
namespace tensorflow { namespace tensorflow {
class TensorCApi {
public:
static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
TensorBuffer* buf) {
return Tensor(static_cast<DataType>(type), shape, buf);
}
};
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);

View File

@ -16,12 +16,36 @@ limitations under the License.
#ifndef TENSORFLOW_C_KERNELS_H_ #ifndef TENSORFLOW_C_KERNELS_H_
#define TENSORFLOW_C_KERNELS_H_ #define TENSORFLOW_C_KERNELS_H_
#include "tensorflow/c/c_api.h" #include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
typedef struct TF_Tensor TF_Tensor;
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
// C API for TensorFlow Kernels. // C API for TensorFlow Kernels.
// //

View File

@ -14,8 +14,12 @@ tf_kernel_library(
prefix = "bitcast_op", prefix = "bitcast_op",
deps = [ deps = [
"//tensorflow/c:kernels", "//tensorflow/c:kernels",
"//tensorflow/c:ops",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:ops", "//tensorflow/core:lib",
], ],
) )
@ -28,6 +32,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib",
], ],
) )

View File

@ -16,11 +16,12 @@ limitations under the License.
#include <sstream> #include <sstream>
#include "tensorflow/c/kernels.h" #include "tensorflow/c/kernels.h"
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/selective_registration.h" #include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.h"
// BitcastOp implements a bitcast kernel, creating an output tensor that shares // BitcastOp implements a bitcast kernel, creating an output tensor that shares
// the same data buffer as the input but with a different shape and/or data // the same data buffer as the input but with a different shape and/or data
@ -135,9 +136,8 @@ static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
TF_DeleteTensor(tensor); TF_DeleteTensor(tensor);
} }
static void RegisterBitcastOp() { void RegisterBitcastOp() {
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
{ {
auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU, auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU,
&BitcastOp_Create, &BitcastOp_Compute, &BitcastOp_Create, &BitcastOp_Compute,

View File

@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/c/ops.h" #include "tensorflow/c/ops.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"

View File

@ -73,7 +73,8 @@ limitations under the License.
#include <stdint.h> #include <stdint.h>
#include <stdlib.h> #include <stdlib.h>
#include "tensorflow/c/c_api.h" #include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#ifdef SWIG #ifdef SWIG
#define TF_CAPI_EXPORT #define TF_CAPI_EXPORT

View File

@ -15,10 +15,13 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/coding.h"
@ -227,13 +230,15 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
size_t dst_len, TF_Status* status) { size_t dst_len, TF_Status* status) {
const size_t sz = TF_StringEncodedSize(src_len); const size_t sz = TF_StringEncodedSize(src_len);
if (sz < src_len) { if (sz < src_len) {
status->status = InvalidArgument("src string is too large to encode"); Set_TF_Status_from_Status(
status, InvalidArgument("src string is too large to encode"));
return 0; return 0;
} }
if (dst_len < sz) { if (dst_len < sz) {
status->status = Set_TF_Status_from_Status(
status,
InvalidArgument("dst_len (", dst_len, ") too small to encode a ", InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
src_len, "-byte string"); src_len, "-byte string"));
return 0; return 0;
} }
dst = tensorflow::core::EncodeVarint64(dst, src_len); dst = tensorflow::core::EncodeVarint64(dst, src_len);
@ -259,7 +264,8 @@ static Status TF_StringDecode_Impl(const char* src, size_t src_len,
size_t TF_StringDecode(const char* src, size_t src_len, const char** dst, size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
size_t* dst_len, TF_Status* status) { size_t* dst_len, TF_Status* status) {
status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len); Set_TF_Status_from_Status(status,
TF_StringDecode_Impl(src, src_len, dst, dst_len));
if (TF_GetCode(status) != TF_OK) return 0; if (TF_GetCode(status) != TF_OK) return 0;
return static_cast<size_t>(*dst - src) + *dst_len; return static_cast<size_t>(*dst - src) + *dst_len;
} }
@ -299,8 +305,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
TF_Status* status) { TF_Status* status) {
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
if (!src.IsInitialized()) { if (!src.IsInitialized()) {
status->status = FailedPrecondition( Set_TF_Status_from_Status(
"attempt to use a tensor with an uninitialized value"); status, FailedPrecondition(
"attempt to use a tensor with an uninitialized value"));
return nullptr; return nullptr;
} }
if (src.NumElements() == 0) { if (src.NumElements() == 0) {
@ -308,13 +315,14 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
} }
if (src.dtype() == tensorflow::DT_RESOURCE) { if (src.dtype() == tensorflow::DT_RESOURCE) {
if (src.shape().dims() != 0) { if (src.shape().dims() != 0) {
status->status = InvalidArgument( Set_TF_Status_from_Status(
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ", status, InvalidArgument(
src.shape().DebugString(), "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
"). Please file a bug at " src.shape().DebugString(),
"https://github.com/tensorflow/tensorflow/issues/new, " "). Please file a bug at "
"ideally with a " "https://github.com/tensorflow/tensorflow/issues/new, "
"short code snippet that reproduces this error."); "ideally with a "
"short code snippet that reproduces this error."));
return nullptr; return nullptr;
} }
const string str = const string str =
@ -353,9 +361,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
const string& s = srcarray(i); const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
if (TF_GetCode(status) != TF_OK) { if (TF_GetCode(status) != TF_OK) {
status->status = InvalidArgument( Set_TF_Status_from_Status(
"invalid string tensor encoding (string #", i, " of ", status,
srcarray.size(), "): ", status->status.error_message()); InvalidArgument("invalid string tensor encoding (string #", i, " of ",
srcarray.size(), "): ", TF_Message(status)));
delete[] base; delete[] base;
return nullptr; return nullptr;
} }
@ -363,9 +372,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
dst_len -= consumed; dst_len -= consumed;
} }
if (dst != base + size) { if (dst != base + size) {
status->status = InvalidArgument( Set_TF_Status_from_Status(
"invalid string tensor encoding (decoded ", (dst - base), status, InvalidArgument(
" bytes, but the tensor is encoded in ", size, " bytes"); "invalid string tensor encoding (decoded ", (dst - base),
" bytes, but the tensor is encoded in ", size, " bytes"));
delete[] base; delete[] base;
return nullptr; return nullptr;
} }

View File

@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
struct TF_Tensor {
~TF_Tensor();
TF_DataType dtype;
tensorflow::TensorShape shape;
tensorflow::TensorBuffer* buffer;
};
namespace tensorflow {
class TensorCApi {
public:
static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
TensorBuffer* buf) {
return Tensor(static_cast<DataType>(type), shape, buf);
}
};
} // namespace tensorflow
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_