Split tf_tensor_internal into its own header
This is one of a series of changes that allows users of TF_Tensor and other C types to avoid depending on everything in TensorFlow. PiperOrigin-RevId: 253337240
This commit is contained in:
parent
09c5e94be3
commit
652d3e7bc6
@ -77,7 +77,10 @@ tf_cuda_library(
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
],
|
||||
}) + [":tf_status_internal"],
|
||||
}) + [
|
||||
":tf_status_internal",
|
||||
":tf_tensor_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -211,9 +214,10 @@ cc_library(
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -221,6 +225,24 @@ cc_library(
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "tf_tensor_internal",
|
||||
hdrs = [
|
||||
"tf_tensor.h",
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_experimental",
|
||||
srcs = [
|
||||
@ -263,7 +285,7 @@ tf_cuda_library(
|
||||
hdrs = ["tf_status_helper.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":c_api_no_xla",
|
||||
":tf_status",
|
||||
":tf_status_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
@ -329,17 +351,16 @@ tf_cuda_library(
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_no_xla",
|
||||
":c_api_internal",
|
||||
":tf_status_helper",
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_no_xla",
|
||||
":c_api_internal",
|
||||
":tf_status_helper",
|
||||
":tf_tensor",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
@ -357,6 +378,8 @@ tf_cuda_library(
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
@ -365,7 +388,7 @@ tf_cuda_library(
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
}) + [":c_api_internal"],
|
||||
}),
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
// clang-format on
|
||||
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||
@ -53,14 +54,6 @@ class ServerInterface;
|
||||
// Internal structures used by the C API. These are likely to change and should
|
||||
// not be depended on.
|
||||
|
||||
struct TF_Tensor {
|
||||
~TF_Tensor();
|
||||
|
||||
TF_DataType dtype;
|
||||
tensorflow::TensorShape shape;
|
||||
tensorflow::TensorBuffer* buffer;
|
||||
};
|
||||
|
||||
struct TF_SessionOptions {
|
||||
tensorflow::SessionOptions options;
|
||||
};
|
||||
@ -193,15 +186,6 @@ struct TF_Server {
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorCApi {
|
||||
public:
|
||||
static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
|
||||
static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
|
||||
TensorBuffer* buf) {
|
||||
return Tensor(static_cast<DataType>(type), shape, buf);
|
||||
}
|
||||
};
|
||||
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||
|
||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
|
||||
|
@ -16,12 +16,36 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_KERNELS_H_
|
||||
#define TENSORFLOW_C_KERNELS_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// Macro to control visibility of exported symbols in the shared library (.so,
|
||||
// .dylib, .dll).
|
||||
// This duplicates the TF_EXPORT macro definition in
|
||||
// tensorflow/core/platform/macros.h in order to keep this .h file independent
|
||||
// of any other includes.
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
#else
|
||||
#if defined(_WIN32)
|
||||
#ifdef TF_COMPILE_LIBRARY
|
||||
#define TF_CAPI_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // _WIN32
|
||||
#endif // SWIG
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct TF_Tensor TF_Tensor;
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// C API for TensorFlow Kernels.
|
||||
//
|
||||
|
@ -14,8 +14,12 @@ tf_kernel_library(
|
||||
prefix = "bitcast_op",
|
||||
deps = [
|
||||
"//tensorflow/c:kernels",
|
||||
"//tensorflow/c:ops",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
@ -28,6 +32,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,11 +16,12 @@ limitations under the License.
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorflow/c/kernels.h"
|
||||
#include "tensorflow/c/ops.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/selective_registration.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
// BitcastOp implements a bitcast kernel, creating an output tensor that shares
|
||||
// the same data buffer as the input but with a different shape and/or data
|
||||
@ -135,9 +136,8 @@ static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
|
||||
TF_DeleteTensor(tensor);
|
||||
}
|
||||
|
||||
static void RegisterBitcastOp() {
|
||||
void RegisterBitcastOp() {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
{
|
||||
auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU,
|
||||
&BitcastOp_Create, &BitcastOp_Compute,
|
||||
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/ops.h"
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
@ -73,7 +73,8 @@ limitations under the License.
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
|
@ -15,10 +15,13 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/coding.h"
|
||||
|
||||
@ -227,13 +230,15 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
|
||||
size_t dst_len, TF_Status* status) {
|
||||
const size_t sz = TF_StringEncodedSize(src_len);
|
||||
if (sz < src_len) {
|
||||
status->status = InvalidArgument("src string is too large to encode");
|
||||
Set_TF_Status_from_Status(
|
||||
status, InvalidArgument("src string is too large to encode"));
|
||||
return 0;
|
||||
}
|
||||
if (dst_len < sz) {
|
||||
status->status =
|
||||
Set_TF_Status_from_Status(
|
||||
status,
|
||||
InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
|
||||
src_len, "-byte string");
|
||||
src_len, "-byte string"));
|
||||
return 0;
|
||||
}
|
||||
dst = tensorflow::core::EncodeVarint64(dst, src_len);
|
||||
@ -259,7 +264,8 @@ static Status TF_StringDecode_Impl(const char* src, size_t src_len,
|
||||
|
||||
size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
|
||||
size_t* dst_len, TF_Status* status) {
|
||||
status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
|
||||
Set_TF_Status_from_Status(status,
|
||||
TF_StringDecode_Impl(src, src_len, dst, dst_len));
|
||||
if (TF_GetCode(status) != TF_OK) return 0;
|
||||
return static_cast<size_t>(*dst - src) + *dst_len;
|
||||
}
|
||||
@ -299,8 +305,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
if (!src.IsInitialized()) {
|
||||
status->status = FailedPrecondition(
|
||||
"attempt to use a tensor with an uninitialized value");
|
||||
Set_TF_Status_from_Status(
|
||||
status, FailedPrecondition(
|
||||
"attempt to use a tensor with an uninitialized value"));
|
||||
return nullptr;
|
||||
}
|
||||
if (src.NumElements() == 0) {
|
||||
@ -308,13 +315,14 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
}
|
||||
if (src.dtype() == tensorflow::DT_RESOURCE) {
|
||||
if (src.shape().dims() != 0) {
|
||||
status->status = InvalidArgument(
|
||||
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
|
||||
src.shape().DebugString(),
|
||||
"). Please file a bug at "
|
||||
"https://github.com/tensorflow/tensorflow/issues/new, "
|
||||
"ideally with a "
|
||||
"short code snippet that reproduces this error.");
|
||||
Set_TF_Status_from_Status(
|
||||
status, InvalidArgument(
|
||||
"Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
|
||||
src.shape().DebugString(),
|
||||
"). Please file a bug at "
|
||||
"https://github.com/tensorflow/tensorflow/issues/new, "
|
||||
"ideally with a "
|
||||
"short code snippet that reproduces this error."));
|
||||
return nullptr;
|
||||
}
|
||||
const string str =
|
||||
@ -353,9 +361,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
const string& s = srcarray(i);
|
||||
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
status->status = InvalidArgument(
|
||||
"invalid string tensor encoding (string #", i, " of ",
|
||||
srcarray.size(), "): ", status->status.error_message());
|
||||
Set_TF_Status_from_Status(
|
||||
status,
|
||||
InvalidArgument("invalid string tensor encoding (string #", i, " of ",
|
||||
srcarray.size(), "): ", TF_Message(status)));
|
||||
delete[] base;
|
||||
return nullptr;
|
||||
}
|
||||
@ -363,9 +372,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
|
||||
dst_len -= consumed;
|
||||
}
|
||||
if (dst != base + size) {
|
||||
status->status = InvalidArgument(
|
||||
"invalid string tensor encoding (decoded ", (dst - base),
|
||||
" bytes, but the tensor is encoded in ", size, " bytes");
|
||||
Set_TF_Status_from_Status(
|
||||
status, InvalidArgument(
|
||||
"invalid string tensor encoding (decoded ", (dst - base),
|
||||
" bytes, but the tensor is encoded in ", size, " bytes"));
|
||||
delete[] base;
|
||||
return nullptr;
|
||||
}
|
||||
|
46
tensorflow/c/tf_tensor_internal.h
Normal file
46
tensorflow/c/tf_tensor_internal.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
// Internal structures used by the C API. These are likely to change and should
|
||||
// not be depended on.
|
||||
|
||||
struct TF_Tensor {
|
||||
~TF_Tensor();
|
||||
|
||||
TF_DataType dtype;
|
||||
tensorflow::TensorShape shape;
|
||||
tensorflow::TensorBuffer* buffer;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorCApi {
|
||||
public:
|
||||
static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
|
||||
static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
|
||||
TensorBuffer* buf) {
|
||||
return Tensor(static_cast<DataType>(type), shape, buf);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
Loading…
Reference in New Issue
Block a user