diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 572367c40a6..72279ff9b1c 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -77,7 +77,10 @@ tf_cuda_library( "//tensorflow/core:op_gen_lib", "//tensorflow/core/distributed_runtime:server_lib", ], - }) + [":tf_status_internal"], + }) + [ + ":tf_status_internal", + ":tf_tensor_internal", + ], ) cc_library( @@ -211,9 +214,10 @@ cc_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":c_api_internal", ":tf_datatype", ":tf_status", + ":tf_status_helper", + ":tf_tensor_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -221,6 +225,24 @@ cc_library( }), ) +tf_cuda_library( + name = "tf_tensor_internal", + hdrs = [ + "tf_tensor.h", + "tf_tensor_internal.h", + ], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + ":tf_datatype", + ":tf_status", + "//tensorflow/core:framework", + ], + }), +) + tf_cuda_library( name = "c_api_experimental", srcs = [ @@ -263,7 +285,7 @@ tf_cuda_library( hdrs = ["tf_status_helper.h"], visibility = ["//visibility:public"], deps = [ - ":c_api_no_xla", + ":tf_status", ":tf_status_internal", "//tensorflow/core:lib", ], @@ -329,17 +351,16 @@ tf_cuda_library( ], copts = tf_copts(), visibility = ["//visibility:public"], - deps = select({ + deps = [ + ":tf_status", + ":tf_status_helper", + ] + select({ "//tensorflow:android": [ - ":c_api_no_xla", ":c_api_internal", - ":tf_status_helper", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":c_api_no_xla", ":c_api_internal", - ":tf_status_helper", ":tf_tensor", "//tensorflow/core:framework", ], @@ -357,6 +378,8 @@ tf_cuda_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":tf_datatype", + ":tf_status", ":tf_status_helper", ] + select({ "//tensorflow:android": [ @@ -365,7 +388,7 @@ tf_cuda_library( "//conditions:default": [ "//tensorflow/core:framework", ], - }) + [":c_api_internal"], + }), ) # ----------------------------------------------------------------------------- diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 7254127cd2e..0310ccf247e 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -30,6 +30,7 @@ limitations under the License. // clang-format on #include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/c/tf_tensor_internal.h" #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/core/framework/op_gen_lib.h" #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) @@ -53,14 +54,6 @@ class ServerInterface; // Internal structures used by the C API. These are likely to change and should // not be depended on. -struct TF_Tensor { - ~TF_Tensor(); - - TF_DataType dtype; - tensorflow::TensorShape shape; - tensorflow::TensorBuffer* buffer; -}; - struct TF_SessionOptions { tensorflow::SessionOptions options; }; @@ -193,15 +186,6 @@ struct TF_Server { namespace tensorflow { -class TensorCApi { - public: - static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } - static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, - TensorBuffer* buf) { - return Tensor(static_cast(type), shape, buf); - } -}; - Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index c47bfa8aa3a..582e494ba94 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -16,12 +16,36 @@ limitations under the License. #ifndef TENSORFLOW_C_KERNELS_H_ #define TENSORFLOW_C_KERNELS_H_ -#include "tensorflow/c/c_api.h" +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" + +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes. +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG #ifdef __cplusplus extern "C" { #endif +typedef struct TF_Tensor TF_Tensor; + // -------------------------------------------------------------------------- // C API for TensorFlow Kernels. // diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index ce86db26faf..86458bc67b2 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -14,8 +14,12 @@ tf_kernel_library( prefix = "bitcast_op", deps = [ "//tensorflow/c:kernels", + "//tensorflow/c:ops", + "//tensorflow/c:tf_datatype", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_tensor", "//tensorflow/core:framework", - "//tensorflow/core:ops", + "//tensorflow/core:lib", ], ) @@ -28,6 +32,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/c/kernels/bitcast_op.cc b/tensorflow/c/kernels/bitcast_op.cc index f2f313af386..15ef53aa669 100644 --- a/tensorflow/c/kernels/bitcast_op.cc +++ b/tensorflow/c/kernels/bitcast_op.cc @@ -16,11 +16,12 @@ limitations under the License. #include #include "tensorflow/c/kernels.h" +#include "tensorflow/c/ops.h" +#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/selective_registration.h" #include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/types.h" // BitcastOp implements a bitcast kernel, creating an output tensor that shares // the same data buffer as the input but with a different shape and/or data @@ -135,9 +136,8 @@ static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) { TF_DeleteTensor(tensor); } -static void RegisterBitcastOp() { +void RegisterBitcastOp() { TF_Status* status = TF_NewStatus(); - { auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU, &BitcastOp_Create, &BitcastOp_Compute, diff --git a/tensorflow/c/ops.cc b/tensorflow/c/ops.cc index d806a16cbc0..f3c8bf5cf04 100644 --- a/tensorflow/c/ops.cc +++ b/tensorflow/c/ops.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/c/ops.h" -#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/shape_inference.h" diff --git a/tensorflow/c/ops.h b/tensorflow/c/ops.h index 6f941a06fbc..14868e40260 100644 --- a/tensorflow/c/ops.h +++ b/tensorflow/c/ops.h @@ -73,7 +73,8 @@ limitations under the License. #include #include -#include "tensorflow/c/c_api.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" #ifdef SWIG #define TF_CAPI_EXPORT diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 97b760f84f8..deb36166a47 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -15,10 +15,13 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" -#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/coding.h" @@ -227,13 +230,15 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst, size_t dst_len, TF_Status* status) { const size_t sz = TF_StringEncodedSize(src_len); if (sz < src_len) { - status->status = InvalidArgument("src string is too large to encode"); + Set_TF_Status_from_Status( + status, InvalidArgument("src string is too large to encode")); return 0; } if (dst_len < sz) { - status->status = + Set_TF_Status_from_Status( + status, InvalidArgument("dst_len (", dst_len, ") too small to encode a ", - src_len, "-byte string"); + src_len, "-byte string")); return 0; } dst = tensorflow::core::EncodeVarint64(dst, src_len); @@ -259,7 +264,8 @@ static Status TF_StringDecode_Impl(const char* src, size_t src_len, size_t TF_StringDecode(const char* src, size_t src_len, const char** dst, size_t* dst_len, TF_Status* status) { - status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len); + Set_TF_Status_from_Status(status, + TF_StringDecode_Impl(src, src_len, dst, dst_len)); if (TF_GetCode(status) != TF_OK) return 0; return static_cast(*dst - src) + *dst_len; } @@ -299,8 +305,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); if (!src.IsInitialized()) { - status->status = FailedPrecondition( - "attempt to use a tensor with an uninitialized value"); + Set_TF_Status_from_Status( + status, FailedPrecondition( + "attempt to use a tensor with an uninitialized value")); return nullptr; } if (src.NumElements() == 0) { @@ -308,13 +315,14 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, } if (src.dtype() == tensorflow::DT_RESOURCE) { if (src.shape().dims() != 0) { - status->status = InvalidArgument( - "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ", - src.shape().DebugString(), - "). Please file a bug at " - "https://github.com/tensorflow/tensorflow/issues/new, " - "ideally with a " - "short code snippet that reproduces this error."); + Set_TF_Status_from_Status( + status, InvalidArgument( + "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ", + src.shape().DebugString(), + "). Please file a bug at " + "https://github.com/tensorflow/tensorflow/issues/new, " + "ideally with a " + "short code snippet that reproduces this error.")); return nullptr; } const string str = @@ -353,9 +361,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, const string& s = srcarray(i); size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); if (TF_GetCode(status) != TF_OK) { - status->status = InvalidArgument( - "invalid string tensor encoding (string #", i, " of ", - srcarray.size(), "): ", status->status.error_message()); + Set_TF_Status_from_Status( + status, + InvalidArgument("invalid string tensor encoding (string #", i, " of ", + srcarray.size(), "): ", TF_Message(status))); delete[] base; return nullptr; } @@ -363,9 +372,10 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, dst_len -= consumed; } if (dst != base + size) { - status->status = InvalidArgument( - "invalid string tensor encoding (decoded ", (dst - base), - " bytes, but the tensor is encoded in ", size, " bytes"); + Set_TF_Status_from_Status( + status, InvalidArgument( + "invalid string tensor encoding (decoded ", (dst - base), + " bytes, but the tensor is encoded in ", size, " bytes")); delete[] base; return nullptr; } diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h new file mode 100644 index 00000000000..6def66c9412 --- /dev/null +++ b/tensorflow/c/tf_tensor_internal.h @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ +#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +// Internal structures used by the C API. These are likely to change and should +// not be depended on. + +struct TF_Tensor { + ~TF_Tensor(); + + TF_DataType dtype; + tensorflow::TensorShape shape; + tensorflow::TensorBuffer* buffer; +}; + +namespace tensorflow { + +class TensorCApi { + public: + static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } + static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, + TensorBuffer* buf) { + return Tensor(static_cast(type), shape, buf); + } +}; + +} // namespace tensorflow +#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_