diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index dbe1b6d656c..bc9a5fd9442 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -3,6 +3,10 @@ # Targets in this directory are pure C++ "Classes" underlying the C API types # under tf/c/experimental/saved_model/public/. They are subject to change and # have visibility limited to Tensorflow's implementation only. +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) package( default_visibility = [ @@ -47,6 +51,22 @@ cc_library( ], ) +cc_library( + name = "saved_model_utils", + srcs = [ + "saved_model_utils.cc", + ], + hdrs = [ + "saved_model_utils.h", + ], + deps = [ + "//tensorflow/c:tf_tensor_internal", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "tf_saved_model_impl", srcs = [ @@ -84,3 +104,26 @@ filegroup( ], visibility = ["//tensorflow/core:__pkg__"], ) + +tf_cc_test( + name = "saved_model_utils_test", + srcs = [ + "saved_model_utils_test.cc", + ], + deps = [ + ":saved_model_utils", + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:core_cpu_lib", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:core", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD new file mode 100644 index 00000000000..ad3844e00a0 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -0,0 +1,39 @@ +# This package contains classes corresponding to Revived SavedObjectGraph types +# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62 +package( + default_visibility = [ + # Restricting visibility for now + "//tensorflow/c/experimental/saved_model/core:__pkg__", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "constant", + srcs = [ + "constant.cc", + ], + hdrs = [ + "constant.h", + ], + deps = [ + ":tensorhandle_convertible", + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:tensor_handle", + ], +) + +cc_library( + name = "tensorhandle_convertible", + hdrs = [ + "tensorhandle_convertible.h", + ], + deps = [ + "//tensorflow/c/eager:immediate_execution_tensor_handle", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/constant.cc b/tensorflow/c/experimental/saved_model/core/revived_types/constant.cc new file mode 100644 index 00000000000..0cabf83a123 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/constant.cc @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" + +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +Constant::Constant(ImmediateTensorHandlePtr handle) + : TensorHandleConvertible(std::move(handle)) {} + +Status Constant::Create(ImmediateExecutionContext* ctx, + AbstractTensorInterface* tensor, + std::unique_ptr* output) { + ImmediateExecutionTensorHandle* handle = ctx->CreateLocalHandle(tensor); + if (handle == nullptr) { + return errors::Internal("Failed to convert tensor to tensorhandle"); + } + output->reset(new Constant(ImmediateTensorHandlePtr(handle))); + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/constant.h b/tensorflow/c/experimental/saved_model/core/revived_types/constant.h new file mode 100644 index 00000000000..845a6f391c0 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/constant.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_ + +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { + +// This class corresponds to python's tf.constant, which is effectively a +// TensorHandle explicitly initialized to some value. +// For now this doesn't do much beyond wrap Context's CreateLocalHandle method, +// and offer a subclass of TensorHandleConvertible. Note that similar to +// the python's eager mode logic, we bypass calling the "Const" op: +// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301 +class Constant : public TensorHandleConvertible { + public: + static Status Create(ImmediateExecutionContext* ctx, + AbstractTensorInterface* tensor, + std::unique_ptr* output); + + // RevivedConstant is movable, but not copyable. + Constant(Constant&& other) = default; + Constant& operator=(Constant&& other) = default; + + ~Constant() override = default; + + private: + explicit Constant(ImmediateTensorHandlePtr handle); + Constant(const Constant&) = delete; + Constant& operator=(const Constant&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h b/tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h new file mode 100644 index 00000000000..98179586e83 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_ + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" + +namespace tensorflow { + +// A common interface for objects that can be converted to a TensorHandle. +// Examples of objects that implement this include Variables, Constants, Assets, +// etc. This is used to convert captured objects into a ConcreteFunction's +// captured TensorHandles: +// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240 +class TensorHandleConvertible { + public: + explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle) + : handle_(std::move(handle)) {} + + ImmediateExecutionTensorHandle* handle() { return handle_.get(); } + + // TensorHandleConvertible is movable, but not copyable. + TensorHandleConvertible(TensorHandleConvertible&& other) = default; + TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default; + + virtual ~TensorHandleConvertible() = default; + + protected: + TensorHandleConvertible(const TensorHandleConvertible&) = delete; + TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete; + ImmediateTensorHandlePtr handle_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc new file mode 100644 index 00000000000..9fe9caa27d7 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" + +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/tf_tensor_internal.h" + +namespace tensorflow { +namespace internal { + +Status TensorProtoToConstant(ImmediateExecutionContext* ctx, + const TensorProto& proto, + std::unique_ptr* output) { + tensorflow::Tensor tensor; + bool parse_result = tensor.FromProto(proto); + if (!parse_result) { + return errors::Internal("Failed to parse tensor from tensorproto"); + } + + TensorInterface tensor_interface(std::move(tensor)); + return Constant::Create(ctx, &tensor_interface, output); +} + +} // namespace internal +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h new file mode 100644 index 00000000000..5223f1c5f7d --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_ + +// Some internal utility functions for the SavedModelAPI, factored out into a +// separately unit-testable header. + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { +namespace internal { + +// Load a TensorProto into a tensorflow::Constant. This is similar to the +// constant loading logic in python: +// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437 +Status TensorProtoToConstant(ImmediateExecutionContext* ctx, + const TensorProto& proto, + std::unique_ptr* output); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils_test.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils_test.cc new file mode 100644 index 00000000000..483162574f7 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils_test.cc @@ -0,0 +1,199 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" + +#include + +#include +#include + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// Converts a tensorflow::DatatypeSet to std::vector. +// This is needed for GTest's ::testing::ValuesIn, since +// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable. +std::vector DataTypeSetToVector(DataTypeSet set) { + std::vector result; + result.reserve(set.size()); + for (DataType dt : set) { + result.push_back(dt); + } + return result; +} + +// Returns a vector of shapes intended to be "interesting" test cases. +std::vector> InterestingShapes() { + std::vector> interesting_shapes; + interesting_shapes.push_back({}); // Scalar + interesting_shapes.push_back({10}); // 1D Vector + interesting_shapes.push_back({3, 3}); // 2D Matrix + interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor + return interesting_shapes; +} + +// Fills a numeric tensor with `value`. +void FillNumericTensor(Tensor* tensor, int8 value) { + switch (tensor->dtype()) { +#define CASE(type) \ + case DataTypeToEnum::value: { \ + const auto& flattened = tensor->flat(); \ + for (int i = 0; i < tensor->NumElements(); ++i) { \ + flattened(i) = value; \ + } \ + break; \ + } + TF_CALL_INTEGRAL_TYPES(CASE); + TF_CALL_double(CASE); + TF_CALL_float(CASE); +#undef CASE + default: + CHECK(false) << "Unsupported data type: " + << DataTypeString(tensor->dtype()); + break; + } +} + +// Checks the underlying data is equal for the buffers for two numeric tensors. +// Note: The caller must ensure to check that the dtypes and sizes of the +// underlying buffers are the same before calling this. +void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a, + void* b) { + switch (dtype) { +#define CASE(type) \ + case DataTypeToEnum::value: { \ + type* typed_a = static_cast(a); \ + type* typed_b = static_cast(b); \ + for (int64 i = 0; i < num_elements; ++i) { \ + if (DataTypeIsFloating(dtype)) { \ + EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \ + } else { \ + EXPECT_EQ(typed_a[i], typed_b[i]); \ + } \ + } \ + break; \ + } + TF_CALL_INTEGRAL_TYPES(CASE); + TF_CALL_double(CASE); + TF_CALL_float(CASE); +#undef CASE + default: + CHECK(false) << "Unsupported data type: " << DataTypeString(dtype); + } +} + +class ConstantTest : public ::testing::TestWithParam< + std::tuple, bool>> { + public: + ConstantTest() + : device_mgr_(std::make_unique(DeviceFactory::NewDevice( + "CPU", {}, "/job:localhost/replica:0/task:0"))), + ctx_(new EagerContext( + SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + tensorflow::ContextMirroringPolicy::MIRRORING_NONE, + /* async= */ false, + /* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(), + /* device_mgr_owned= */ false, /* rendezvous= */ nullptr, + /* custom_kernel_creator= */ nullptr, + /* cluster_flr= */ nullptr)) {} + + EagerContext* context() { return ctx_.get(); } + + private: + std::unique_ptr device_mgr_; + EagerContextPtr ctx_; +}; + +// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant +// preserves values. +TEST_P(ConstantTest, CreateConstantSuccessful) { + // Get test parameters + auto& test_params = GetParam(); + DataType dtype = std::get<0>(test_params); + TensorShape shape(std::get<1>(test_params)); + bool tensorproto_use_tensor_content = std::get<2>(test_params); + + // Construct a Tensor with the given dtype + shape + Tensor expected(dtype, shape); + FillNumericTensor(&expected, 42); + + // Serialize it to a Tensorproto + TensorProto proto; + if (tensorproto_use_tensor_content) { + expected.AsProtoTensorContent(&proto); + } else { + expected.AsProtoField(&proto); + } + + // Revival should succeed w/o errors + std::unique_ptr revived; + TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived)); + + // The revived tensorhandle should have the exact same dtype, shape, + + // approx equivalent data to the original. + ImmediateExecutionTensorHandle* handle = revived->handle(); + Status status; + AbstractTensorPtr revived_tensor(handle->Resolve(&status)); + TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor"; + EXPECT_EQ(revived_tensor->Type(), expected.dtype()); + EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements()); + EXPECT_EQ(revived_tensor->NumDims(), expected.dims()); + for (int i = 0; i < expected.dims(); ++i) { + EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i)); + } + + CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(), + revived_tensor->Data(), expected.data()); +} + +// Test against combinations of tensors that are +// 1. Varying dtypes +// 2. Varying shapes +// 3. TensorProto serialized using tensor_content vs repeated type +INSTANTIATE_TEST_SUITE_P( + ConstantIntegerDtypesTest, ConstantTest, + ::testing::Combine( + ::testing::ValuesIn(DataTypeSetToVector(kDataTypeIsInteger)), + ::testing::ValuesIn(InterestingShapes()), + ::testing::Values(false, true))); + +INSTANTIATE_TEST_SUITE_P( + ConstantFloatingDtypesTest, ConstantTest, + ::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE), + ::testing::ValuesIn(InterestingShapes()), + ::testing::Values(false, true))); + +} // namespace +} // namespace tensorflow