diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 73c2f7824b2..5c01ccb82bb 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -924,7 +924,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( context->GetDevicePlacementPolicy()); } -TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { +TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 070b3a9bb60..5afe3047dd7 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, // placed in memory of different devices or remote address spaces. typedef struct TFE_TensorHandle TFE_TensorHandle; -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status); // Indicates that the caller will not be using `h` any more. TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD index 93acf1bd319..045d4e6cd97 100644 --- a/tensorflow/cc/experimental/base/public/BUILD +++ b/tensorflow/cc/experimental/base/public/BUILD @@ -62,3 +62,17 @@ cc_library( "//tensorflow/c:tf_tensor", ], ) + +cc_library( + name = "tensorhandle", + hdrs = [ + "tensorhandle.h", + ], + deps = [ + ":runtime", + ":status", + ":tensor", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) diff --git a/tensorflow/cc/experimental/base/public/runtime.h b/tensorflow/cc/experimental/base/public/runtime.h index 47fd8869647..711a38c233a 100644 --- a/tensorflow/cc/experimental/base/public/runtime.h +++ b/tensorflow/cc/experimental/base/public/runtime.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" namespace tensorflow { +namespace experimental { namespace cc { // Runtime represents an opaque instance of a Tensorflow runtime, with its own @@ -40,6 +41,7 @@ class Runtime { private: friend class RuntimeBuilder; friend class SavedModelAPI; + friend class TensorHandle; // Wraps a TFE_Context. Takes ownership of ctx. explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {} @@ -63,6 +65,7 @@ class Runtime { }; } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ diff --git a/tensorflow/cc/experimental/base/public/runtime_builder.h b/tensorflow/cc/experimental/base/public/runtime_builder.h index ed3c93ae135..737e06cb2c6 100644 --- a/tensorflow/cc/experimental/base/public/runtime_builder.h +++ b/tensorflow/cc/experimental/base/public/runtime_builder.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/cc/experimental/base/public/status.h" namespace tensorflow { +namespace experimental { namespace cc { // RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime. @@ -79,6 +80,7 @@ inline std::unique_ptr RuntimeBuilder::Build(Status* status) { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ diff --git a/tensorflow/cc/experimental/base/public/status.h b/tensorflow/cc/experimental/base/public/status.h index f91f2caccd8..98c8cf6ced2 100644 --- a/tensorflow/cc/experimental/base/public/status.h +++ b/tensorflow/cc/experimental/base/public/status.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/c/tf_status.h" namespace tensorflow { +namespace experimental { namespace cc { // Status is a wrapper around an error code and an optional error message. @@ -57,6 +58,7 @@ class Status { friend class RuntimeBuilder; friend class Runtime; friend class SavedModelAPI; + friend class TensorHandle; // Wraps a TF_Status*, and takes ownership of it. explicit Status(TF_Status* status) : status_(status) {} @@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ diff --git a/tensorflow/cc/experimental/base/public/tensor.h b/tensorflow/cc/experimental/base/public/tensor.h index 26b0e5dc55e..fc447262ce1 100644 --- a/tensorflow/cc/experimental/base/public/tensor.h +++ b/tensorflow/cc/experimental/base/public/tensor.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/cc/experimental/base/public/status.h" namespace tensorflow { +namespace experimental { namespace cc { // Tensor represents an n-dimensional array of values. @@ -168,6 +169,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype, } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ diff --git a/tensorflow/cc/experimental/base/public/tensorhandle.h b/tensorflow/cc/experimental/base/public/tensorhandle.h new file mode 100644 index 00000000000..99453ee7ea8 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/tensorhandle.h @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ + +#include +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/experimental/base/public/tensor.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// An opaque representation of a tensor computed/managed by the Tensorflow +// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer +// to tensors placed in memory of different devices or remote address spaces. +// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created +// from it. +class TensorHandle { + public: + // Unwraps a Tensor from the given TensorHandle. If an error occurred, + // status->ok() will be false, and the returned Tensor must not be used. + Tensor Resolve(Status* status); + + // Constructs a TensorHandle from a Tensor. If an error occurred, + // status->ok() will be false, and the returned TensorHandle must not be used. + static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime, + Status* status); + + // TensorHandle is movable, and not copyable + TensorHandle(TensorHandle&&) = default; + TensorHandle& operator=(TensorHandle&&) = default; + + private: + // Wraps a TFE_TensorHandle. Takes ownership of handle. + explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {} + + // TensorHandle is not copyable + TensorHandle(const TensorHandle&) = delete; + TensorHandle& operator=(const TensorHandle&) = delete; + + // Returns the underlying TFE_TensorHandle that this object wraps. + // This object retains ownership of the pointer. + TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); } + + // Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle, + // and takes ownership of handle. + void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); } + + struct TFETensorHandleDeleter { + void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); } + }; + std::unique_ptr handle_; +}; + +inline Tensor TensorHandle::Resolve(Status* status) { + TF_Tensor* tensor = + TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus()); + if (!status->ok()) { + return Tensor(nullptr); + } + return Tensor(tensor); +} + +inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor, + const Runtime& runtime, + Status* status) { + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor( + runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus()); + if (!status->ok()) { + return TensorHandle(nullptr); + } + return TensorHandle(tensor_handle); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD index a2b634a70f4..f449d618f72 100644 --- a/tensorflow/cc/experimental/base/tests/BUILD +++ b/tensorflow/cc/experimental/base/tests/BUILD @@ -5,12 +5,22 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "tensor_types_test_util", + testonly = True, + hdrs = ["tensor_types_test_util.h"], + deps = [ + "//tensorflow/c:tf_datatype", + ], +) + tf_cc_test( name = "tensor_test", srcs = [ "tensor_test.cc", ], deps = [ + ":tensor_types_test_util", "//tensorflow/c:tf_datatype", "//tensorflow/cc/experimental/base/public:status", "//tensorflow/cc/experimental/base/public:tensor", @@ -19,3 +29,22 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "tensorhandle_test", + srcs = [ + "tensorhandle_test.cc", + ], + deps = [ + ":tensor_types_test_util", + "//tensorflow/c:tf_datatype", + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:runtime_builder", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/experimental/base/public:tensor", + "//tensorflow/cc/experimental/base/public:tensorhandle", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/cc/experimental/base/tests/tensor_test.cc b/tensorflow/cc/experimental/base/tests/tensor_test.cc index 86a50bac5cd..33f9ab637e8 100644 --- a/tensorflow/cc/experimental/base/tests/tensor_test.cc +++ b/tensorflow/cc/experimental/base/tests/tensor_test.cc @@ -16,69 +16,22 @@ limitations under the License. #include "tensorflow/cc/experimental/base/public/tensor.h" #include - -#include +#include #include "tensorflow/c/tf_datatype.h" +#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" -namespace tensorflow { namespace { -// Each of the following struct types have two members: a kDType that -// corresponds to a TF_Datatype enum value, and a typedef "type" -// of its corresponding C++ type. These types allow us to write Dtype-agnostic -// tests via GoogleTest's TypedTests: -// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests -struct FloatType { - using type = float; - static constexpr TF_DataType kDType = TF_FLOAT; -}; +using tensorflow::experimental::cc::Status; +using tensorflow::experimental::cc::Tensor; -struct DoubleType { - using type = double; - static constexpr TF_DataType kDType = TF_DOUBLE; -}; - -struct Int32Type { - using type = int32_t; - static constexpr TF_DataType kDType = TF_INT32; -}; - -struct UINT8Type { - using type = uint8_t; - static constexpr TF_DataType kDType = TF_UINT8; -}; - -struct INT8Type { - using type = int8_t; - static constexpr TF_DataType kDType = TF_INT8; -}; - -struct INT64Type { - using type = int64_t; - static constexpr TF_DataType kDType = TF_INT64; -}; - -struct UINT16Type { - using type = uint16_t; - static constexpr TF_DataType kDType = TF_UINT16; -}; - -struct UINT32Type { - using type = uint32_t; - static constexpr TF_DataType kDType = TF_UINT32; -}; - -struct UINT64Type { - using type = uint64_t; - static constexpr TF_DataType kDType = TF_UINT64; -}; - -using SimpleTypes = - ::testing::Types; +using SimpleTypes = ::testing::Types< + tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type, + tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type, + tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>; template class ConstructScalarTensorTest : public ::testing::Test {}; @@ -88,14 +41,13 @@ TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes); // and verifies the expected dimensions, dtype, value, number of bytes, and // number of elements. TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) { - cc::Status status; + Status status; TF_DataType dtype = TypeParam::kDType; typename TypeParam::type value = 42; - cc::Tensor tensor = - cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, - /*data=*/&value, - /*len=*/sizeof(value), - /*deleter=*/[](void*, size_t) {}, &status); + Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, + /*data=*/&value, + /*len=*/sizeof(value), + /*deleter=*/[](void*, size_t) {}, &status); ASSERT_TRUE(status.ok()) << status.message(); EXPECT_EQ(tensor.dims(), 0); @@ -113,7 +65,7 @@ TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes); // and verifies the expected dimensions, dtype, value, number of bytes, and // number of elements. TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) { - cc::Status status; + Status status; TF_DataType dtype = TypeParam::kDType; // This is our 1D tensor of varying dtype. std::vector value = {42, 100, 0, 1, 4, 29}; @@ -121,7 +73,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) { std::vector shape; shape.push_back(value.size()); - cc::Tensor tensor = cc::Tensor::FromBuffer( + Tensor tensor = Tensor::FromBuffer( /*dtype=*/dtype, /*shape=*/shape, /*data=*/value.data(), /*len=*/value.size() * sizeof(typename TypeParam::type), @@ -130,7 +82,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) { EXPECT_EQ(tensor.dims(), 1); EXPECT_EQ(tensor.dtype(), dtype); - gtl::ArraySlice tensor_view( + tensorflow::gtl::ArraySlice tensor_view( reinterpret_cast(tensor.data()), value.size()); EXPECT_EQ(tensor_view[0], 42); EXPECT_EQ(tensor_view[1], 100); @@ -152,14 +104,14 @@ TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes); // and verifies the expected dimensions, dtype, value, number of bytes, and // number of elements. TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) { - cc::Status status; + Status status; TF_DataType dtype = TypeParam::kDType; // This is our 1D tensor of varying dtype. std::vector value = {42, 100, 0, 1, 4, 29}; // Shape is Rank 2 vector with shape 2 x 3. std::vector shape({2, 3}); - cc::Tensor tensor = cc::Tensor::FromBuffer( + Tensor tensor = Tensor::FromBuffer( /*dtype=*/dtype, /*shape=*/shape, /*data=*/value.data(), /*len=*/value.size() * sizeof(typename TypeParam::type), @@ -169,7 +121,7 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) { EXPECT_EQ(tensor.dims(), 2); EXPECT_EQ(tensor.dtype(), dtype); - gtl::ArraySlice tensor_view( + tensorflow::gtl::ArraySlice tensor_view( reinterpret_cast(tensor.data()), value.size()); EXPECT_EQ(tensor_view[0], 42); EXPECT_EQ(tensor_view[1], 100); @@ -185,22 +137,22 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) { TEST(CPPTensorAPI, ConstructTensorFromBuffer) { bool done = false; - cc::Status status; + Status status; std::vector data_vector({12, 14, 20, 18, 39, 42, 100}); { // data_vector is a rank 1 tensor. std::vector shape; shape.push_back(data_vector.size()); - cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) { + Tensor::DeleterCallback callback = [&done](void* data, size_t len) { done = true; }; - cc::Tensor tensor = - cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape, - /*data=*/data_vector.data(), - /*len=*/data_vector.size() * sizeof(int32_t), - /*deleter=*/callback, &status); + Tensor tensor = + Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape, + /*data=*/data_vector.data(), + /*len=*/data_vector.size() * sizeof(int32_t), + /*deleter=*/callback, &status); ASSERT_TRUE(status.ok()) << status.message(); } // At this point, tensor has been destroyed, and the deleter callback should @@ -209,4 +161,3 @@ TEST(CPPTensorAPI, ConstructTensorFromBuffer) { } } // namespace -} // namespace tensorflow diff --git a/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h new file mode 100644 index 00000000000..af9cad7529b --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ + +#include + +#include "tensorflow/c/tf_datatype.h" + +namespace tensorflow { + +// Each of the following struct types have two members: a kDType that +// corresponds to a TF_Datatype enum value, and a typedef "type" +// of its corresponding C++ type. These types allow us to write Dtype-agnostic +// tests via GoogleTest's TypedTests: +// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests +struct FloatType { + using type = float; + static constexpr TF_DataType kDType = TF_FLOAT; +}; + +struct DoubleType { + using type = double; + static constexpr TF_DataType kDType = TF_DOUBLE; +}; + +struct Int32Type { + using type = int32_t; + static constexpr TF_DataType kDType = TF_INT32; +}; + +struct UINT8Type { + using type = uint8_t; + static constexpr TF_DataType kDType = TF_UINT8; +}; + +struct INT8Type { + using type = int8_t; + static constexpr TF_DataType kDType = TF_INT8; +}; + +struct INT64Type { + using type = int64_t; + static constexpr TF_DataType kDType = TF_INT64; +}; + +struct UINT16Type { + using type = uint16_t; + static constexpr TF_DataType kDType = TF_UINT16; +}; + +struct UINT32Type { + using type = uint32_t; + static constexpr TF_DataType kDType = TF_UINT32; +}; + +struct UINT64Type { + using type = uint64_t; + static constexpr TF_DataType kDType = TF_UINT64; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ diff --git a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc new file mode 100644 index 00000000000..cfeaba4e392 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/experimental/base/public/tensorhandle.h" + +#include +#include + +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/runtime_builder.h" +#include "tensorflow/cc/experimental/base/public/tensor.h" +#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using tensorflow::experimental::cc::Runtime; +using tensorflow::experimental::cc::RuntimeBuilder; +using tensorflow::experimental::cc::Status; +using tensorflow::experimental::cc::Tensor; +using tensorflow::experimental::cc::TensorHandle; + +using SimpleTypes = ::testing::Types< + tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type, + tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type, + tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>; + +template +class ConstructScalarTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes); + +// This test constructs a scalar tensor for each of the types in "SimpleTypes", +// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and +// verify the expected dims, dtype, value, num bytes, and num elements. +TYPED_TEST(ConstructScalarTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + typename TypeParam::type value = 42; + Tensor original_tensor = + Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, + /*data=*/&value, + /*len=*/sizeof(value), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 0); + EXPECT_EQ(tensor.dtype(), dtype); + EXPECT_EQ(*reinterpret_cast(tensor.data()), 42); + EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), 1); +} + +template +class Construct1DTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes); + +// This test constructs a 1D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct1DTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. + std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 1 vector. + std::vector shape; + shape.push_back(value.size()); + + Tensor original_tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 1); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +template +class Construct2DTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes); + +// This test constructs a 2D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct2DTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. + std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 2 vector with shape 2 x 3. + std::vector shape({2, 3}); + + Tensor original_tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 2); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function.h b/tensorflow/cc/saved_model/experimental/public/concrete_function.h index f57ba052f1a..1adaf70b01a 100644 --- a/tensorflow/cc/saved_model/experimental/public/concrete_function.h +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/experimental/public/function_metadata.h" namespace tensorflow { +namespace experimental { namespace cc { // ConcreteFunction is an executable "function" loaded from a SavedModelAPI. @@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h index bab95278eac..88cb779ef15 100644 --- a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" namespace tensorflow { +namespace experimental { namespace cc { // ConcreteFunctionList helps convert an opaque pointer to an array of @@ -56,6 +57,7 @@ inline std::vector ConcreteFunctionList::ToVector() { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/function_metadata.h b/tensorflow/cc/saved_model/experimental/public/function_metadata.h index c3dcc45af0e..11e1a860d84 100644 --- a/tensorflow/cc/saved_model/experimental/public/function_metadata.h +++ b/tensorflow/cc/saved_model/experimental/public/function_metadata.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" namespace tensorflow { +namespace experimental { namespace cc { // FunctionMetadata stores additional function information, including @@ -40,6 +41,7 @@ class FunctionMetadata final { }; } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h index 814479de213..04018bf2aab 100644 --- a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h +++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h" namespace tensorflow { +namespace experimental { namespace cc { // SavedModelAPI offers a way to load Tensorflow Saved Models @@ -155,6 +156,7 @@ inline std::vector SavedModelAPI::ListFunctions() { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc index 155c58604bf..7f7f6b09a6d 100644 --- a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc +++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc @@ -26,10 +26,14 @@ limitations under the License. #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" -namespace tensorflow { namespace { +using tensorflow::experimental::cc::Runtime; +using tensorflow::experimental::cc::RuntimeBuilder; +using tensorflow::experimental::cc::SavedModelAPI; +using tensorflow::experimental::cc::Status; + constexpr char kTestData[] = "cc/saved_model/testdata"; std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { @@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { class CPPSavedModelAPITest : public ::testing::TestWithParam {}; TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) { - cc::Status status; - cc::RuntimeBuilder builder; + Status status; + RuntimeBuilder builder; bool use_tfrt = GetParam(); if (use_tfrt) { GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. } builder.SetUseTFRT(use_tfrt); - std::unique_ptr runtime = builder.Build(&status); + std::unique_ptr runtime = builder.Build(&status); ASSERT_TRUE(status.ok()) << status.message(); std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); std::unordered_set tags = {"serve"}; - std::unique_ptr model = - cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags); + std::unique_ptr model = + SavedModelAPI::Load(model_dir, *runtime, &status, &tags); // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. // That unblocks writing other tests that require a TF_SavedModel*, @@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) { } TEST_P(CPPSavedModelAPITest, LoadsSavedModel) { - cc::Status status; - cc::RuntimeBuilder builder; + Status status; + RuntimeBuilder builder; bool use_tfrt = GetParam(); if (use_tfrt) { GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. } builder.SetUseTFRT(use_tfrt); - std::unique_ptr runtime = builder.Build(&status); + std::unique_ptr runtime = builder.Build(&status); ASSERT_TRUE(status.ok()) << status.message(); std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); - std::unique_ptr model = - cc::SavedModelAPI::Load(model_dir, *runtime, &status); + std::unique_ptr model = + SavedModelAPI::Load(model_dir, *runtime, &status); // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. // That unblocks writing other tests that require a TF_SavedModel*, @@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests, } // namespace -} // namespace tensorflow