From 9849fde5e7b4da4b630ffbc517fad68b2b811c0c Mon Sep 17 00:00:00 2001 From: James Ring <sjr@google.com> Date: Thu, 23 May 2019 11:40:35 -0700 Subject: [PATCH] Split TF_DataType into its own header. This change allows users of TF_DataType to include it without including the entire (heavyweight) C API. It is one in a series of changes to allow op definitions in an ABI-stable C API. PiperOrigin-RevId: 249684540 --- tensorflow/c/BUILD | 19 ++++++ tensorflow/c/c_api.cc | 6 -- tensorflow/c/c_api.h | 39 +------------ tensorflow/c/tf_datatype.cc | 23 ++++++++ tensorflow/c/tf_datatype.h | 83 +++++++++++++++++++++++++++ tensorflow/contrib/makefile/Makefile | 1 + tensorflow/core/framework/types.proto | 2 +- tensorflow/python/pywrap_tfe.i | 1 + 8 files changed, 130 insertions(+), 44 deletions(-) create mode 100644 tensorflow/c/tf_datatype.cc create mode 100644 tensorflow/c/tf_datatype.h diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 978c65a865e..99eb28c1295 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -22,6 +22,7 @@ filegroup( "c_api.h", "c_api_experimental.h", "tf_attrtype.h", + "tf_datatype.h", "tf_status.h", ], visibility = ["//tensorflow:__subpackages__"], @@ -53,6 +54,7 @@ tf_cuda_library( hdrs = [ "c_api.h", "c_api_internal.h", + "tf_datatype.h", "tf_status.h", ], visibility = [ @@ -86,6 +88,7 @@ tf_cuda_library( hdrs = [ "c_api.h", "tf_attrtype.h", + "tf_datatype.h", "tf_status.h", ], copts = tf_copts(), @@ -117,6 +120,7 @@ tf_cuda_library( deps = [ ":c_api_internal", ":tf_attrtype", + ":tf_datatype", ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", @@ -159,6 +163,21 @@ cc_library( }), ) +cc_library( + name = "tf_datatype", + srcs = ["tf_datatype.cc"], + hdrs = ["tf_datatype.h"], + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }), +) + tf_cuda_library( name = "c_api_experimental", srcs = [ diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 7dce62f90f7..4f519a7bd11 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -108,12 +108,6 @@ extern "C" { // -------------------------------------------------------------------------- const char* TF_Version() { return TF_VERSION_STRING; } -// -------------------------------------------------------------------------- -size_t TF_DataTypeSize(TF_DataType dt) { - return static_cast<size_t>( - tensorflow::DataTypeSize(static_cast<DataType>(dt))); -} - // -------------------------------------------------------------------------- namespace { diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 5387a57204d..9a538cb98db 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -20,6 +20,7 @@ limitations under the License. #include <stdint.h> #include "tensorflow/c/tf_attrtype.h" +#include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" // -------------------------------------------------------------------------- @@ -72,7 +73,7 @@ limitations under the License. // .dylib, .dll). // This duplicates the TF_EXPORT macro definition in // tensorflow/core/platform/macros.h in order to keep this .h file independent -// of any other includes.$a +// of any other includes. #ifdef SWIG #define TF_CAPI_EXPORT #else @@ -96,42 +97,6 @@ extern "C" { // TensorFlow library. TensorFlow using semantic versioning. TF_CAPI_EXPORT extern const char* TF_Version(void); -// -------------------------------------------------------------------------- -// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. -// The enum values here are identical to corresponding values in types.proto. -typedef enum TF_DataType { - TF_FLOAT = 1, - TF_DOUBLE = 2, - TF_INT32 = 3, // Int32 tensors are always in 'host' memory. - TF_UINT8 = 4, - TF_INT16 = 5, - TF_INT8 = 6, - TF_STRING = 7, - TF_COMPLEX64 = 8, // Single-precision complex - TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility - TF_INT64 = 9, - TF_BOOL = 10, - TF_QINT8 = 11, // Quantized int8 - TF_QUINT8 = 12, // Quantized uint8 - TF_QINT32 = 13, // Quantized int32 - TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops. - TF_QINT16 = 15, // Quantized int16 - TF_QUINT16 = 16, // Quantized uint16 - TF_UINT16 = 17, - TF_COMPLEX128 = 18, // Double-precision complex - TF_HALF = 19, - TF_RESOURCE = 20, - TF_VARIANT = 21, - TF_UINT32 = 22, - TF_UINT64 = 23, -} TF_DataType; - -// TF_DataTypeSize returns the sizeof() for the underlying type corresponding -// to the given TF_DataType enum value. Returns 0 for variable length types -// (eg. TF_STRING) or on failure. -TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt); - - // -------------------------------------------------------------------------- // TF_Buffer holds a pointer to a block of data and its associated length. // Typically, the data consists of a serialized protocol buffer, but other data diff --git a/tensorflow/c/tf_datatype.cc b/tensorflow/c/tf_datatype.cc new file mode 100644 index 00000000000..d2a66d99dac --- /dev/null +++ b/tensorflow/c/tf_datatype.cc @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/tf_datatype.h" + +#include "tensorflow/core/framework/types.h" + +size_t TF_DataTypeSize(TF_DataType dt) { + return static_cast<size_t>( + tensorflow::DataTypeSize(static_cast<tensorflow::DataType>(dt))); +} diff --git a/tensorflow/c/tf_datatype.h b/tensorflow/c/tf_datatype.h new file mode 100644 index 00000000000..3e6121bf989 --- /dev/null +++ b/tensorflow/c/tf_datatype.h @@ -0,0 +1,83 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_DATATYPE_H_ +#define TENSORFLOW_C_TF_DATATYPE_H_ + +#include <stddef.h> + +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes. +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. +// The enum values here are identical to corresponding values in types.proto. +typedef enum TF_DataType { + TF_FLOAT = 1, + TF_DOUBLE = 2, + TF_INT32 = 3, // Int32 tensors are always in 'host' memory. + TF_UINT8 = 4, + TF_INT16 = 5, + TF_INT8 = 6, + TF_STRING = 7, + TF_COMPLEX64 = 8, // Single-precision complex + TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility + TF_INT64 = 9, + TF_BOOL = 10, + TF_QINT8 = 11, // Quantized int8 + TF_QUINT8 = 12, // Quantized uint8 + TF_QINT32 = 13, // Quantized int32 + TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops. + TF_QINT16 = 15, // Quantized int16 + TF_QUINT16 = 16, // Quantized uint16 + TF_UINT16 = 17, + TF_COMPLEX128 = 18, // Double-precision complex + TF_HALF = 19, + TF_RESOURCE = 20, + TF_VARIANT = 21, + TF_UINT32 = 22, + TF_UINT64 = 23, +} TF_DataType; + +// TF_DataTypeSize returns the sizeof() for the underlying type corresponding +// to the given TF_DataType enum value. Returns 0 for variable length types +// (eg. TF_STRING) or on failure. +TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_DATATYPE_H_ diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 05d17c28f82..ba0ea348ef8 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -636,6 +636,7 @@ CORE_CC_ALL_SRCS := \ $(ABSL_CC_SRCS) \ tensorflow/c/c_api.cc \ tensorflow/c/kernels.cc \ +tensorflow/c/tf_datatype.cc \ tensorflow/c/tf_status.cc \ tensorflow/c/tf_status_helper.cc \ $(wildcard tensorflow/core/*.cc) \ diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto index 432fbf5bed3..5356f9f9c99 100644 --- a/tensorflow/core/framework/types.proto +++ b/tensorflow/core/framework/types.proto @@ -67,7 +67,7 @@ enum DataType { DT_UINT64_REF = 123; } // LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/c/tf_datatype.h, // https://www.tensorflow.org/code/tensorflow/go/tensor.go, // https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, // https://www.tensorflow.org/code/tensorflow/core/framework/types.h, diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 712029d8ee0..deb85122629 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ %include "tensorflow/python/platform/base.i" +%include "tensorflow/c/tf_datatype.h" %include "tensorflow/c/tf_status.h" %ignore "";