From 1f63a2e6100924ac51042defcf485198e9259309 Mon Sep 17 00:00:00 2001 From: Brian Zhao Date: Tue, 29 Sep 2020 17:20:49 -0700 Subject: [PATCH] Adding an opaque TF_Shape type to the C API. This is necessary to represent shapes for TF_TensorSpec. PiperOrigin-RevId: 334495752 Change-Id: I707f8c6b2f8568c7187e638d426e983dc3484412 --- tensorflow/c/BUILD | 24 +++++++++++++++ tensorflow/c/tf_shape.cc | 39 +++++++++++++++++++++++++ tensorflow/c/tf_shape.h | 50 ++++++++++++++++++++++++++++++++ tensorflow/c/tf_shape_internal.h | 30 +++++++++++++++++++ 4 files changed, 143 insertions(+) create mode 100644 tensorflow/c/tf_shape.cc create mode 100644 tensorflow/c/tf_shape.h create mode 100644 tensorflow/c/tf_shape_internal.h diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 16f6b860308..677ab3355ff 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -254,6 +254,30 @@ tf_cuda_library( }), ) +cc_library( + name = "tf_shape", + srcs = ["tf_shape.cc"], + hdrs = ["tf_shape.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":c_api_macros", + ":tf_shape_internal", + "//tensorflow/core:framework", + ], +) + +cc_library( + name = "tf_shape_internal", + hdrs = ["tf_shape_internal.h"], + copts = tf_copts(), + visibility = ["//tensorflow:internal"], + deps = [ + ":conversion_macros", + "//tensorflow/core:framework", + ], +) + cc_library( name = "tf_status", srcs = ["tf_status.cc"], diff --git a/tensorflow/c/tf_shape.cc b/tensorflow/c/tf_shape.cc new file mode 100644 index 00000000000..a715544a13f --- /dev/null +++ b/tensorflow/c/tf_shape.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/tf_shape.h" + +#include + +#include "tensorflow/c/tf_shape_internal.h" +#include "tensorflow/core/framework/tensor_shape.h" + +extern "C" { + +TF_Shape* TF_NewShape() { + return tensorflow::wrap(new tensorflow::PartialTensorShape()); +} + +int TF_ShapeDims(const TF_Shape* shape) { + return tensorflow::unwrap(shape)->dims(); +} + +int64_t TF_ShapeDimSize(const TF_Shape* shape, int d) { + return tensorflow::unwrap(shape)->dim_size(d); +} + +void TF_DeleteShape(TF_Shape* shape) { delete tensorflow::unwrap(shape); } + +} // end extern "C" diff --git a/tensorflow/c/tf_shape.h b/tensorflow/c/tf_shape.h new file mode 100644 index 00000000000..f218d05e274 --- /dev/null +++ b/tensorflow/c/tf_shape.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/c_api_macros.h" + +#ifndef TENSORFLOW_C_TF_SHAPE_H_ +#define TENSORFLOW_C_TF_SHAPE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// An opaque type corresponding to a shape in tensorflow. In the future, +// we may expose the ABI of TF_Shape for performance reasons. +typedef struct TF_Shape TF_Shape; + +// Return a new, unknown rank shape object. The caller is responsible for +// calling TF_DeleteShape to deallocate and destroy the returned shape. +TF_CAPI_EXPORT extern TF_Shape* TF_NewShape(); + +// Returns the rank of `shape`. If `shape` has unknown rank, returns -1. +TF_CAPI_EXPORT extern int TF_ShapeDims(const TF_Shape* shape); + +// Returns the `d`th dimension of `shape`. If `shape` has unknown rank, +// invoking this function is undefined behavior. Returns -1 if dimension is +// unknown. +TF_CAPI_EXPORT extern int64_t TF_ShapeDimSize(const TF_Shape* shape, int d); + +// Deletes `shape`. +TF_CAPI_EXPORT extern void TF_DeleteShape(TF_Shape* shape); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_SHAPE_H_ diff --git a/tensorflow/c/tf_shape_internal.h b/tensorflow/c/tf_shape_internal.h new file mode 100644 index 00000000000..fe97726460f --- /dev/null +++ b/tensorflow/c/tf_shape_internal.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ +#define TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/core/framework/tensor_shape.h" + +typedef struct TF_Shape TF_Shape; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::PartialTensorShape, TF_Shape); + +} + +#endif // TENSORFLOW_C_TF_SHAPE_INTERNAL_H_