Initial checkin of C++ header-only TensorHandle as part of RFC https://github.com/tensorflow/community/pull/207.
PiperOrigin-RevId: 311179503 Change-Id: Ib3cfb2547150d09ee655db6ca6bc72ef3ef7adde
This commit is contained in:
parent
88dfd8ce6d
commit
5100abc4af
@ -924,7 +924,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
|||||||
context->GetDevicePlacementPolicy());
|
context->GetDevicePlacementPolicy());
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||||
tensorflow::Tensor tensor;
|
tensorflow::Tensor tensor;
|
||||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||||
if (!status->status.ok()) return nullptr;
|
if (!status->status.ok()) return nullptr;
|
||||||
|
@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
|||||||
// placed in memory of different devices or remote address spaces.
|
// placed in memory of different devices or remote address spaces.
|
||||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
|
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
// Indicates that the caller will not be using `h` any more.
|
// Indicates that the caller will not be using `h` any more.
|
||||||
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
||||||
|
@ -62,3 +62,17 @@ cc_library(
|
|||||||
"//tensorflow/c:tf_tensor",
|
"//tensorflow/c:tf_tensor",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensorhandle",
|
||||||
|
hdrs = [
|
||||||
|
"tensorhandle.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":runtime",
|
||||||
|
":status",
|
||||||
|
":tensor",
|
||||||
|
"//tensorflow/c/eager:c_api",
|
||||||
|
"//tensorflow/c/eager:c_api_experimental",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace experimental {
|
||||||
namespace cc {
|
namespace cc {
|
||||||
|
|
||||||
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
|
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
|
||||||
@ -40,6 +41,7 @@ class Runtime {
|
|||||||
private:
|
private:
|
||||||
friend class RuntimeBuilder;
|
friend class RuntimeBuilder;
|
||||||
friend class SavedModelAPI;
|
friend class SavedModelAPI;
|
||||||
|
friend class TensorHandle;
|
||||||
|
|
||||||
// Wraps a TFE_Context. Takes ownership of ctx.
|
// Wraps a TFE_Context. Takes ownership of ctx.
|
||||||
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
|
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
|
||||||
@ -63,6 +65,7 @@ class Runtime {
|
|||||||
};
|
};
|
||||||
|
|
||||||
} // namespace cc
|
} // namespace cc
|
||||||
|
} // namespace experimental
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace experimental {
|
||||||
namespace cc {
|
namespace cc {
|
||||||
|
|
||||||
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
|
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
|
||||||
@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cc
|
} // namespace cc
|
||||||
|
} // namespace experimental
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace experimental {
|
||||||
namespace cc {
|
namespace cc {
|
||||||
|
|
||||||
// Status is a wrapper around an error code and an optional error message.
|
// Status is a wrapper around an error code and an optional error message.
|
||||||
@ -57,6 +58,7 @@ class Status {
|
|||||||
friend class RuntimeBuilder;
|
friend class RuntimeBuilder;
|
||||||
friend class Runtime;
|
friend class Runtime;
|
||||||
friend class SavedModelAPI;
|
friend class SavedModelAPI;
|
||||||
|
friend class TensorHandle;
|
||||||
|
|
||||||
// Wraps a TF_Status*, and takes ownership of it.
|
// Wraps a TF_Status*, and takes ownership of it.
|
||||||
explicit Status(TF_Status* status) : status_(status) {}
|
explicit Status(TF_Status* status) : status_(status) {}
|
||||||
@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cc
|
} // namespace cc
|
||||||
|
} // namespace experimental
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace experimental {
|
||||||
namespace cc {
|
namespace cc {
|
||||||
|
|
||||||
// Tensor represents an n-dimensional array of values.
|
// Tensor represents an n-dimensional array of values.
|
||||||
@ -168,6 +169,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cc
|
} // namespace cc
|
||||||
|
} // namespace experimental
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
||||||
|
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||||
|
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||||
|
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||||
|
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace experimental {
|
||||||
|
namespace cc {
|
||||||
|
|
||||||
|
// An opaque representation of a tensor computed/managed by the Tensorflow
|
||||||
|
// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer
|
||||||
|
// to tensors placed in memory of different devices or remote address spaces.
|
||||||
|
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
|
||||||
|
// from it.
|
||||||
|
class TensorHandle {
|
||||||
|
public:
|
||||||
|
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
|
||||||
|
// status->ok() will be false, and the returned Tensor must not be used.
|
||||||
|
Tensor Resolve(Status* status);
|
||||||
|
|
||||||
|
// Constructs a TensorHandle from a Tensor. If an error occurred,
|
||||||
|
// status->ok() will be false, and the returned TensorHandle must not be used.
|
||||||
|
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
|
||||||
|
Status* status);
|
||||||
|
|
||||||
|
// TensorHandle is movable, and not copyable
|
||||||
|
TensorHandle(TensorHandle&&) = default;
|
||||||
|
TensorHandle& operator=(TensorHandle&&) = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Wraps a TFE_TensorHandle. Takes ownership of handle.
|
||||||
|
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
|
||||||
|
|
||||||
|
// TensorHandle is not copyable
|
||||||
|
TensorHandle(const TensorHandle&) = delete;
|
||||||
|
TensorHandle& operator=(const TensorHandle&) = delete;
|
||||||
|
|
||||||
|
// Returns the underlying TFE_TensorHandle that this object wraps.
|
||||||
|
// This object retains ownership of the pointer.
|
||||||
|
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
|
||||||
|
|
||||||
|
// Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
|
||||||
|
// and takes ownership of handle.
|
||||||
|
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
|
||||||
|
|
||||||
|
struct TFETensorHandleDeleter {
|
||||||
|
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
|
||||||
|
};
|
||||||
|
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline Tensor TensorHandle::Resolve(Status* status) {
|
||||||
|
TF_Tensor* tensor =
|
||||||
|
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
|
||||||
|
if (!status->ok()) {
|
||||||
|
return Tensor(nullptr);
|
||||||
|
}
|
||||||
|
return Tensor(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
|
||||||
|
const Runtime& runtime,
|
||||||
|
Status* status) {
|
||||||
|
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
|
||||||
|
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
|
||||||
|
if (!status->ok()) {
|
||||||
|
return TensorHandle(nullptr);
|
||||||
|
}
|
||||||
|
return TensorHandle(tensor_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cc
|
||||||
|
} // namespace experimental
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
@ -5,12 +5,22 @@ package(
|
|||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_types_test_util",
|
||||||
|
testonly = True,
|
||||||
|
hdrs = ["tensor_types_test_util.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:tf_datatype",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "tensor_test",
|
name = "tensor_test",
|
||||||
srcs = [
|
srcs = [
|
||||||
"tensor_test.cc",
|
"tensor_test.cc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":tensor_types_test_util",
|
||||||
"//tensorflow/c:tf_datatype",
|
"//tensorflow/c:tf_datatype",
|
||||||
"//tensorflow/cc/experimental/base/public:status",
|
"//tensorflow/cc/experimental/base/public:status",
|
||||||
"//tensorflow/cc/experimental/base/public:tensor",
|
"//tensorflow/cc/experimental/base/public:tensor",
|
||||||
@ -19,3 +29,22 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "tensorhandle_test",
|
||||||
|
srcs = [
|
||||||
|
"tensorhandle_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tensor_types_test_util",
|
||||||
|
"//tensorflow/c:tf_datatype",
|
||||||
|
"//tensorflow/cc/experimental/base/public:runtime",
|
||||||
|
"//tensorflow/cc/experimental/base/public:runtime_builder",
|
||||||
|
"//tensorflow/cc/experimental/base/public:status",
|
||||||
|
"//tensorflow/cc/experimental/base/public:tensor",
|
||||||
|
"//tensorflow/cc/experimental/base/public:tensorhandle",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -16,69 +16,22 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
#include <cstdint>
|
|
||||||
|
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Each of the following struct types have two members: a kDType that
|
using tensorflow::experimental::cc::Status;
|
||||||
// corresponds to a TF_Datatype enum value, and a typedef "type"
|
using tensorflow::experimental::cc::Tensor;
|
||||||
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
|
|
||||||
// tests via GoogleTest's TypedTests:
|
|
||||||
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
|
|
||||||
struct FloatType {
|
|
||||||
using type = float;
|
|
||||||
static constexpr TF_DataType kDType = TF_FLOAT;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DoubleType {
|
using SimpleTypes = ::testing::Types<
|
||||||
using type = double;
|
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||||
static constexpr TF_DataType kDType = TF_DOUBLE;
|
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||||
};
|
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||||
|
|
||||||
struct Int32Type {
|
|
||||||
using type = int32_t;
|
|
||||||
static constexpr TF_DataType kDType = TF_INT32;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct UINT8Type {
|
|
||||||
using type = uint8_t;
|
|
||||||
static constexpr TF_DataType kDType = TF_UINT8;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct INT8Type {
|
|
||||||
using type = int8_t;
|
|
||||||
static constexpr TF_DataType kDType = TF_INT8;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct INT64Type {
|
|
||||||
using type = int64_t;
|
|
||||||
static constexpr TF_DataType kDType = TF_INT64;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct UINT16Type {
|
|
||||||
using type = uint16_t;
|
|
||||||
static constexpr TF_DataType kDType = TF_UINT16;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct UINT32Type {
|
|
||||||
using type = uint32_t;
|
|
||||||
static constexpr TF_DataType kDType = TF_UINT32;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct UINT64Type {
|
|
||||||
using type = uint64_t;
|
|
||||||
static constexpr TF_DataType kDType = TF_UINT64;
|
|
||||||
};
|
|
||||||
|
|
||||||
using SimpleTypes =
|
|
||||||
::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
|
|
||||||
INT64Type, UINT16Type, UINT32Type, UINT64Type>;
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class ConstructScalarTensorTest : public ::testing::Test {};
|
class ConstructScalarTensorTest : public ::testing::Test {};
|
||||||
@ -88,11 +41,10 @@ TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
|
|||||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||||
// number of elements.
|
// number of elements.
|
||||||
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
|
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||||
cc::Status status;
|
Status status;
|
||||||
TF_DataType dtype = TypeParam::kDType;
|
TF_DataType dtype = TypeParam::kDType;
|
||||||
typename TypeParam::type value = 42;
|
typename TypeParam::type value = 42;
|
||||||
cc::Tensor tensor =
|
Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||||
cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
|
||||||
/*data=*/&value,
|
/*data=*/&value,
|
||||||
/*len=*/sizeof(value),
|
/*len=*/sizeof(value),
|
||||||
/*deleter=*/[](void*, size_t) {}, &status);
|
/*deleter=*/[](void*, size_t) {}, &status);
|
||||||
@ -113,7 +65,7 @@ TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
|
|||||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||||
// number of elements.
|
// number of elements.
|
||||||
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||||
cc::Status status;
|
Status status;
|
||||||
TF_DataType dtype = TypeParam::kDType;
|
TF_DataType dtype = TypeParam::kDType;
|
||||||
// This is our 1D tensor of varying dtype.
|
// This is our 1D tensor of varying dtype.
|
||||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||||
@ -121,7 +73,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
|||||||
std::vector<int64_t> shape;
|
std::vector<int64_t> shape;
|
||||||
shape.push_back(value.size());
|
shape.push_back(value.size());
|
||||||
|
|
||||||
cc::Tensor tensor = cc::Tensor::FromBuffer(
|
Tensor tensor = Tensor::FromBuffer(
|
||||||
/*dtype=*/dtype, /*shape=*/shape,
|
/*dtype=*/dtype, /*shape=*/shape,
|
||||||
/*data=*/value.data(),
|
/*data=*/value.data(),
|
||||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||||
@ -130,7 +82,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
|||||||
|
|
||||||
EXPECT_EQ(tensor.dims(), 1);
|
EXPECT_EQ(tensor.dims(), 1);
|
||||||
EXPECT_EQ(tensor.dtype(), dtype);
|
EXPECT_EQ(tensor.dtype(), dtype);
|
||||||
gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||||
EXPECT_EQ(tensor_view[0], 42);
|
EXPECT_EQ(tensor_view[0], 42);
|
||||||
EXPECT_EQ(tensor_view[1], 100);
|
EXPECT_EQ(tensor_view[1], 100);
|
||||||
@ -152,14 +104,14 @@ TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
|
|||||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||||
// number of elements.
|
// number of elements.
|
||||||
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||||
cc::Status status;
|
Status status;
|
||||||
TF_DataType dtype = TypeParam::kDType;
|
TF_DataType dtype = TypeParam::kDType;
|
||||||
// This is our 1D tensor of varying dtype.
|
// This is our 1D tensor of varying dtype.
|
||||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||||
// Shape is Rank 2 vector with shape 2 x 3.
|
// Shape is Rank 2 vector with shape 2 x 3.
|
||||||
std::vector<int64_t> shape({2, 3});
|
std::vector<int64_t> shape({2, 3});
|
||||||
|
|
||||||
cc::Tensor tensor = cc::Tensor::FromBuffer(
|
Tensor tensor = Tensor::FromBuffer(
|
||||||
/*dtype=*/dtype, /*shape=*/shape,
|
/*dtype=*/dtype, /*shape=*/shape,
|
||||||
/*data=*/value.data(),
|
/*data=*/value.data(),
|
||||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||||
@ -169,7 +121,7 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
|||||||
|
|
||||||
EXPECT_EQ(tensor.dims(), 2);
|
EXPECT_EQ(tensor.dims(), 2);
|
||||||
EXPECT_EQ(tensor.dtype(), dtype);
|
EXPECT_EQ(tensor.dtype(), dtype);
|
||||||
gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||||
EXPECT_EQ(tensor_view[0], 42);
|
EXPECT_EQ(tensor_view[0], 42);
|
||||||
EXPECT_EQ(tensor_view[1], 100);
|
EXPECT_EQ(tensor_view[1], 100);
|
||||||
@ -185,19 +137,19 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
|||||||
|
|
||||||
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
||||||
bool done = false;
|
bool done = false;
|
||||||
cc::Status status;
|
Status status;
|
||||||
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
|
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
|
||||||
{
|
{
|
||||||
// data_vector is a rank 1 tensor.
|
// data_vector is a rank 1 tensor.
|
||||||
std::vector<int64_t> shape;
|
std::vector<int64_t> shape;
|
||||||
shape.push_back(data_vector.size());
|
shape.push_back(data_vector.size());
|
||||||
|
|
||||||
cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
||||||
done = true;
|
done = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
cc::Tensor tensor =
|
Tensor tensor =
|
||||||
cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
||||||
/*data=*/data_vector.data(),
|
/*data=*/data_vector.data(),
|
||||||
/*len=*/data_vector.size() * sizeof(int32_t),
|
/*len=*/data_vector.size() * sizeof(int32_t),
|
||||||
/*deleter=*/callback, &status);
|
/*deleter=*/callback, &status);
|
||||||
@ -209,4 +161,3 @@ TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
|
||||||
|
@ -0,0 +1,76 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||||
|
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Each of the following struct types have two members: a kDType that
|
||||||
|
// corresponds to a TF_Datatype enum value, and a typedef "type"
|
||||||
|
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
|
||||||
|
// tests via GoogleTest's TypedTests:
|
||||||
|
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
|
||||||
|
struct FloatType {
|
||||||
|
using type = float;
|
||||||
|
static constexpr TF_DataType kDType = TF_FLOAT;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DoubleType {
|
||||||
|
using type = double;
|
||||||
|
static constexpr TF_DataType kDType = TF_DOUBLE;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Int32Type {
|
||||||
|
using type = int32_t;
|
||||||
|
static constexpr TF_DataType kDType = TF_INT32;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct UINT8Type {
|
||||||
|
using type = uint8_t;
|
||||||
|
static constexpr TF_DataType kDType = TF_UINT8;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct INT8Type {
|
||||||
|
using type = int8_t;
|
||||||
|
static constexpr TF_DataType kDType = TF_INT8;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct INT64Type {
|
||||||
|
using type = int64_t;
|
||||||
|
static constexpr TF_DataType kDType = TF_INT64;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct UINT16Type {
|
||||||
|
using type = uint16_t;
|
||||||
|
static constexpr TF_DataType kDType = TF_UINT16;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct UINT32Type {
|
||||||
|
using type = uint32_t;
|
||||||
|
static constexpr TF_DataType kDType = TF_UINT32;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct UINT64Type {
|
||||||
|
using type = uint64_t;
|
||||||
|
static constexpr TF_DataType kDType = TF_UINT64;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||||
|
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
|
||||||
|
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||||
|
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using tensorflow::experimental::cc::Runtime;
|
||||||
|
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||||
|
using tensorflow::experimental::cc::Status;
|
||||||
|
using tensorflow::experimental::cc::Tensor;
|
||||||
|
using tensorflow::experimental::cc::TensorHandle;
|
||||||
|
|
||||||
|
using SimpleTypes = ::testing::Types<
|
||||||
|
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||||
|
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||||
|
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class ConstructScalarTensorHandleTest : public ::testing::Test {};
|
||||||
|
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
|
||||||
|
|
||||||
|
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||||
|
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
|
||||||
|
// verify the expected dims, dtype, value, num bytes, and num elements.
|
||||||
|
TYPED_TEST(ConstructScalarTensorHandleTest,
|
||||||
|
ValidTensorAttributesAfterConstruction) {
|
||||||
|
Status status;
|
||||||
|
RuntimeBuilder runtime_builder;
|
||||||
|
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
TF_DataType dtype = TypeParam::kDType;
|
||||||
|
typename TypeParam::type value = 42;
|
||||||
|
Tensor original_tensor =
|
||||||
|
Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||||
|
/*data=*/&value,
|
||||||
|
/*len=*/sizeof(value),
|
||||||
|
/*deleter=*/[](void*, size_t) {}, &status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
TensorHandle handle =
|
||||||
|
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
Tensor tensor = handle.Resolve(&status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
EXPECT_EQ(tensor.dims(), 0);
|
||||||
|
EXPECT_EQ(tensor.dtype(), dtype);
|
||||||
|
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
|
||||||
|
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
|
||||||
|
EXPECT_EQ(tensor.num_elements(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class Construct1DTensorHandleTest : public ::testing::Test {};
|
||||||
|
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
|
||||||
|
|
||||||
|
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||||
|
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||||
|
// number of elements.
|
||||||
|
TYPED_TEST(Construct1DTensorHandleTest,
|
||||||
|
ValidTensorAttributesAfterConstruction) {
|
||||||
|
Status status;
|
||||||
|
RuntimeBuilder runtime_builder;
|
||||||
|
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
TF_DataType dtype = TypeParam::kDType;
|
||||||
|
// This is our 1D tensor of varying dtype.
|
||||||
|
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||||
|
// Shape is Rank 1 vector.
|
||||||
|
std::vector<int64_t> shape;
|
||||||
|
shape.push_back(value.size());
|
||||||
|
|
||||||
|
Tensor original_tensor = Tensor::FromBuffer(
|
||||||
|
/*dtype=*/dtype, /*shape=*/shape,
|
||||||
|
/*data=*/value.data(),
|
||||||
|
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||||
|
/*deleter=*/[](void*, size_t) {}, &status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
TensorHandle handle =
|
||||||
|
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
Tensor tensor = handle.Resolve(&status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
EXPECT_EQ(tensor.dims(), 1);
|
||||||
|
EXPECT_EQ(tensor.dtype(), dtype);
|
||||||
|
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||||
|
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||||
|
EXPECT_EQ(tensor_view[0], 42);
|
||||||
|
EXPECT_EQ(tensor_view[1], 100);
|
||||||
|
EXPECT_EQ(tensor_view[2], 0);
|
||||||
|
EXPECT_EQ(tensor_view[3], 1);
|
||||||
|
EXPECT_EQ(tensor_view[4], 4);
|
||||||
|
EXPECT_EQ(tensor_view[5], 29);
|
||||||
|
|
||||||
|
EXPECT_EQ(tensor.num_bytes(),
|
||||||
|
value.size() * sizeof(typename TypeParam::type));
|
||||||
|
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class Construct2DTensorHandleTest : public ::testing::Test {};
|
||||||
|
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
|
||||||
|
|
||||||
|
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||||
|
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||||
|
// number of elements.
|
||||||
|
TYPED_TEST(Construct2DTensorHandleTest,
|
||||||
|
ValidTensorAttributesAfterConstruction) {
|
||||||
|
Status status;
|
||||||
|
RuntimeBuilder runtime_builder;
|
||||||
|
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
TF_DataType dtype = TypeParam::kDType;
|
||||||
|
// This is our 1D tensor of varying dtype.
|
||||||
|
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||||
|
// Shape is Rank 2 vector with shape 2 x 3.
|
||||||
|
std::vector<int64_t> shape({2, 3});
|
||||||
|
|
||||||
|
Tensor original_tensor = Tensor::FromBuffer(
|
||||||
|
/*dtype=*/dtype, /*shape=*/shape,
|
||||||
|
/*data=*/value.data(),
|
||||||
|
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||||
|
/*deleter=*/[](void*, size_t) {}, &status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
TensorHandle handle =
|
||||||
|
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
Tensor tensor = handle.Resolve(&status);
|
||||||
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
|
EXPECT_EQ(tensor.dims(), 2);
|
||||||
|
EXPECT_EQ(tensor.dtype(), dtype);
|
||||||
|
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||||
|
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||||
|
EXPECT_EQ(tensor_view[0], 42);
|
||||||
|
EXPECT_EQ(tensor_view[1], 100);
|
||||||
|
EXPECT_EQ(tensor_view[2], 0);
|
||||||
|
EXPECT_EQ(tensor_view[3], 1);
|
||||||
|
EXPECT_EQ(tensor_view[4], 4);
|
||||||
|
EXPECT_EQ(tensor_view[5], 29);
|
||||||
|
|
||||||
|
EXPECT_EQ(tensor.num_bytes(),
|
||||||
|
value.size() * sizeof(typename TypeParam::type));
|
||||||
|
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
|
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace experimental {
|
||||||
namespace cc {
|
namespace cc {
|
||||||
|
|
||||||
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
|
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
|
||||||
@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cc
|
} // namespace cc
|
||||||
|
} // namespace experimental
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace experimental {
|
||||||
namespace cc {
|
namespace cc {
|
||||||
|
|
||||||
// ConcreteFunctionList helps convert an opaque pointer to an array of
|
// ConcreteFunctionList helps convert an opaque pointer to an array of
|
||||||
@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cc
|
} // namespace cc
|
||||||
|
} // namespace experimental
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace experimental {
|
||||||
namespace cc {
|
namespace cc {
|
||||||
|
|
||||||
// FunctionMetadata stores additional function information, including
|
// FunctionMetadata stores additional function information, including
|
||||||
@ -40,6 +41,7 @@ class FunctionMetadata final {
|
|||||||
};
|
};
|
||||||
|
|
||||||
} // namespace cc
|
} // namespace cc
|
||||||
|
} // namespace experimental
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
|
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
namespace experimental {
|
||||||
namespace cc {
|
namespace cc {
|
||||||
|
|
||||||
// SavedModelAPI offers a way to load Tensorflow Saved Models
|
// SavedModelAPI offers a way to load Tensorflow Saved Models
|
||||||
@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cc
|
} // namespace cc
|
||||||
|
} // namespace experimental
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
||||||
|
@ -26,10 +26,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/stringpiece.h"
|
#include "tensorflow/core/platform/stringpiece.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using tensorflow::experimental::cc::Runtime;
|
||||||
|
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||||
|
using tensorflow::experimental::cc::SavedModelAPI;
|
||||||
|
using tensorflow::experimental::cc::Status;
|
||||||
|
|
||||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||||
|
|
||||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||||
@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
|||||||
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||||
|
|
||||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||||
cc::Status status;
|
Status status;
|
||||||
cc::RuntimeBuilder builder;
|
RuntimeBuilder builder;
|
||||||
bool use_tfrt = GetParam();
|
bool use_tfrt = GetParam();
|
||||||
if (use_tfrt) {
|
if (use_tfrt) {
|
||||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||||
}
|
}
|
||||||
|
|
||||||
builder.SetUseTFRT(use_tfrt);
|
builder.SetUseTFRT(use_tfrt);
|
||||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||||
ASSERT_TRUE(status.ok()) << status.message();
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||||
std::unordered_set<std::string> tags = {"serve"};
|
std::unordered_set<std::string> tags = {"serve"};
|
||||||
std::unique_ptr<cc::SavedModelAPI> model =
|
std::unique_ptr<SavedModelAPI> model =
|
||||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||||
|
|
||||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||||
@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
|
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
|
||||||
cc::Status status;
|
Status status;
|
||||||
cc::RuntimeBuilder builder;
|
RuntimeBuilder builder;
|
||||||
bool use_tfrt = GetParam();
|
bool use_tfrt = GetParam();
|
||||||
if (use_tfrt) {
|
if (use_tfrt) {
|
||||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||||
}
|
}
|
||||||
|
|
||||||
builder.SetUseTFRT(use_tfrt);
|
builder.SetUseTFRT(use_tfrt);
|
||||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||||
ASSERT_TRUE(status.ok()) << status.message();
|
ASSERT_TRUE(status.ok()) << status.message();
|
||||||
|
|
||||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||||
std::unique_ptr<cc::SavedModelAPI> model =
|
std::unique_ptr<SavedModelAPI> model =
|
||||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
|
SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||||
|
|
||||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||||
@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user