diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index bc532506a46..4cf868e4714 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -91,15 +91,24 @@ cc_library( ":signature_def_function_metadata", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/types:span", ], ) cc_library( name = "signature_def_function_metadata", + srcs = [ + "signature_def_function_metadata.cc", + ], hdrs = [ "signature_def_function_metadata.h", ], + deps = [ + ":tensor_spec", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], ) cc_library( @@ -268,6 +277,20 @@ tf_cc_test( ], ) +cc_library( + name = "tensor_spec", + srcs = [ + "tensor_spec.cc", + ], + hdrs = [ + "tensor_spec.h", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + ], +) + tf_cc_test( name = "tf_concrete_function_loading_test", srcs = [ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index eaccc2fac69..257fe9ed801 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -92,6 +92,8 @@ cc_library( "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/c/experimental/saved_model/core:tensor_spec", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc index 5cc06e6c54f..e3ca2ae622f 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" +#include #include +#include #include #include "absl/types/span.h" @@ -30,14 +32,26 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" namespace tensorflow { namespace { +using StructuredValueDictEntry = + protobuf::MapPair; + +using NamedParamMap = + gtl::FlatMap; + Status AssertAllCreateResourceFunctionsHaveNoCaptures( const PartiallyRevivedObjects& objects) { for (const auto& id_and_resource : objects.restored_resources) { @@ -124,6 +138,142 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph, } } +std::vector SignatureDefParamsFromNamedParamMap( + const NamedParamMap& params) { + // The underlying functiondef associated with the SignatureDef has + // nest.flattened inputs and outputs, which are sorted by string key. + std::vector result; + result.reserve(params.size()); + for (const auto& named_param : params) { + result.push_back(SignatureDefParam(std::string(named_param.first), + TensorSpec(*named_param.second))); + } + std::sort(result.begin(), result.end(), + [](const SignatureDefParam& x, const SignatureDefParam& y) { + return x.name() < y.name(); + }); + + return result; +} + +// SignatureDefArgsFromInputs takes the "canonicalized_input_signature" +// field of a SavedConcreteFunction, ensures it conforms to the structure of +// tuple(tuple(), dict()), and "returns" a list of +// SignatureDefParams of the SignatureDefFunction's arguments. +Status SignatureDefArgsFromInputs( + const StructuredValue& canonicalized_input_signature, + std::vector* out) { + // Note(bmzhao): canonicalized_input_signature should be a tuple of + // (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of + // string keys to TensorSpecs. + if (!canonicalized_input_signature.has_tuple_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature should be " + "of form tuple(tuple(), dict()), but was instead: \n", + canonicalized_input_signature.DebugString()); + } + + const TupleValue& args_kwargs_tuple = + canonicalized_input_signature.tuple_value(); + if (args_kwargs_tuple.values_size() != 2) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature should be " + "a tuple of two elements (args, kwargs), but was instead: \n", + args_kwargs_tuple.DebugString()); + } + + const StructuredValue& args = args_kwargs_tuple.values(0); + if (!args.has_tuple_value() || !args.tuple_value().values().empty()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature's args" + "should be an empty tuple, but instead got: \n", + args.DebugString()); + } + + const StructuredValue& kwargs = args_kwargs_tuple.values(1); + if (!kwargs.has_dict_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature's kwargs" + "should be a dictionary, but instead got: \n", + kwargs.DebugString()); + } + + const DictValue& kwargs_dict = kwargs.dict_value(); + NamedParamMap result; + result.reserve(kwargs_dict.fields_size()); + + for (const auto& key_value : kwargs_dict.fields()) { + const std::string& key = key_value.first; + const StructuredValue& value = key_value.second; + if (!value.has_tensor_spec_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature's kwargs" + "dictionary contained a non-tensorspec value for key-value pair: \n", + "Key: ", key, "Value: \n", value.DebugString()); + } + result[key] = &value.tensor_spec_value(); + } + + *out = SignatureDefParamsFromNamedParamMap(result); + + return Status(); +} + +// SignatureDefReturnsFromOutputs takes the "output_signature" field of a +// SavedConcreteFunction, ensures it conforms to the structure of +// dict(), and "returns" a list of SignatureDefParams of the +// SignatureDefFunction's returns. +Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature, + std::vector* out) { + if (!output_signature.has_dict_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's output_signature must be a dictionary, but " + "instead got: ", + output_signature.DebugString()); + } + + const DictValue& output_dict = output_signature.dict_value(); + NamedParamMap result; + result.reserve(output_dict.fields_size()); + + for (const auto& key_value : output_dict.fields()) { + const std::string& key = key_value.first; + const StructuredValue& value = key_value.second; + if (!value.has_tensor_spec_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's output_signature dictionary contained a " + "non-tensorspec value for key-value pair: \n", + "Key: ", key, "Value: \n", value.DebugString()); + } + result[key] = &value.tensor_spec_value(); + } + *out = SignatureDefParamsFromNamedParamMap(result); + + return Status(); +} + +// The implementation takes advantage of the fact that SignatureDefFunction's +// "traced" Signature wrapper function always has inputs/outputs of dictionaries +// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126 +// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178 +// Additionally, we take advantage of the fact that the SignatureDefFunction's +// associated functiondef has lexicographically ordered inputs/outputs due to +// nest.flatten. +Status LoadSignatureDefFunctionMetadata( + const SavedConcreteFunction& saved_concrete_function, + SignatureDefFunctionMetadata* out) { + std::vector args; + TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs( + saved_concrete_function.canonicalized_input_signature(), &args)); + + std::vector rets; + TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs( + saved_concrete_function.output_signature(), &rets)); + + *out = SignatureDefFunctionMetadata(std::move(args), std::move(rets)); + return Status(); +} + // This function finds the necessary captures, then forwards to the builder // method Status CreateConcreteFunction(ImmediateExecutionContext* ctx, @@ -162,10 +312,14 @@ Status CreateSignatureDefFunction( &capture_handle)); captures.push_back(capture_handle); } - // TODO(bmzhao): Create Metadata here + + SignatureDefFunctionMetadata metadata; + TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata( + *builder.saved_concrete_func, &metadata)); + return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef, /*captures=*/std::move(captures), - /*metadata=*/{}, + /*metadata=*/std::move(metadata), /*ctx=*/ctx, /*out=*/out); } diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.cc b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.cc new file mode 100644 index 00000000000..4e455f08f49 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.cc @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +namespace tensorflow { + +SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec) + : name_(std::move(name)), spec_(std::move(spec)) {} + +const std::string& SignatureDefParam::name() const { return name_; } + +const TensorSpec& SignatureDefParam::spec() const { return spec_; } + +SignatureDefFunctionMetadata::SignatureDefFunctionMetadata( + std::vector arguments, + std::vector returns) + : arguments_(std::move(arguments)), returns_(std::move(returns)) {} + +const std::vector& SignatureDefFunctionMetadata::arguments() + const { + return arguments_; +} + +const std::vector& SignatureDefFunctionMetadata::returns() + const { + return returns_; +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h index 5a579676d4e..e9cc0b11b00 100644 --- a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h @@ -16,10 +16,42 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#include +#include + +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/struct.pb.h" + namespace tensorflow { +// SignatureDefParam represents a named Tensor input or output to a +// SignatureDefFunction. +class SignatureDefParam { + public: + SignatureDefParam(std::string name, TensorSpec spec); + + const std::string& name() const; + + const TensorSpec& spec() const; + + private: + std::string name_; + TensorSpec spec_; +}; + class SignatureDefFunctionMetadata { - // TODO(bmzhao): Fill in with fields as necessary + public: + SignatureDefFunctionMetadata() = default; + SignatureDefFunctionMetadata(std::vector arguments, + std::vector returns); + + const std::vector& arguments() const; + const std::vector& returns() const; + + private: + std::vector arguments_; + std::vector returns_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tensor_spec.cc b/tensorflow/c/experimental/saved_model/core/tensor_spec.cc new file mode 100644 index 00000000000..4d68ec73b1b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tensor_spec.cc @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +TensorSpec::TensorSpec() + : shape_(std::initializer_list()), dtype_(DT_FLOAT) {} + +TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype) + : shape_(std::move(shape)), dtype_(dtype) {} + +TensorSpec::TensorSpec(const TensorSpecProto& proto) + : shape_(proto.shape()), dtype_(proto.dtype()) {} + +const PartialTensorShape& TensorSpec::shape() const { return shape_; } + +DataType TensorSpec::dtype() const { return dtype_; } + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tensor_spec.h b/tensorflow/c/experimental/saved_model/core/tensor_spec.h new file mode 100644 index 00000000000..dcdff8900bd --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tensor_spec.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +// Note(bmzhao): TensorSpec deliberately does not store the "name" from a +// TensorSpecProto. From edloper@, "Names should really be associated with +// parameters, not the tensors inside those parameters. This would be +// inconsistent with the corresponding Python class, but I don't think that's +// necessarily a problem. If it turns out later that we really need a name +// attribute here, we can always add it back in; but let's see how far we can +// get without it." +class TensorSpec { + public: + // Constructs a scalar, DT_FLOAT TensorSpec + TensorSpec(); + + TensorSpec(PartialTensorShape shape, DataType dtype); + + explicit TensorSpec(const TensorSpecProto& proto); + + const PartialTensorShape& shape() const; + DataType dtype() const; + + private: + PartialTensorShape shape_; + DataType dtype_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_