Adding SignatureDefFunctionMetadata to SignatureDefFunctions. This allows users to verify that the input tensor's dtype + shape match up with the function's expected inputs. Additionally, this is necessary to know the number of TFE_TensorHandles to pre-allocated for TFE_Execute on the SignatureDefFunction.
PiperOrigin-RevId: 334482407 Change-Id: I8961f6f409efecfadbdb077d0cdbcc5c0f31a9b6
This commit is contained in:
parent
8c7d50fcfd
commit
2d660e99c7
@ -91,15 +91,24 @@ cc_library(
|
||||
":signature_def_function_metadata",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "signature_def_function_metadata",
|
||||
srcs = [
|
||||
"signature_def_function_metadata.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"signature_def_function_metadata.h",
|
||||
],
|
||||
deps = [
|
||||
":tensor_spec",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -268,6 +277,20 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_spec",
|
||||
srcs = [
|
||||
"tensor_spec.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"tensor_spec.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tf_concrete_function_loading_test",
|
||||
srcs = [
|
||||
|
@ -92,6 +92,8 @@ cc_library(
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
@ -30,14 +32,26 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
using StructuredValueDictEntry =
|
||||
protobuf::MapPair<std::string, StructuredValue>;
|
||||
|
||||
using NamedParamMap =
|
||||
gtl::FlatMap<StringPiece, const TensorSpecProto*, StringPieceHasher>;
|
||||
|
||||
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
|
||||
const PartiallyRevivedObjects& objects) {
|
||||
for (const auto& id_and_resource : objects.restored_resources) {
|
||||
@ -124,6 +138,142 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<SignatureDefParam> SignatureDefParamsFromNamedParamMap(
|
||||
const NamedParamMap& params) {
|
||||
// The underlying functiondef associated with the SignatureDef has
|
||||
// nest.flattened inputs and outputs, which are sorted by string key.
|
||||
std::vector<SignatureDefParam> result;
|
||||
result.reserve(params.size());
|
||||
for (const auto& named_param : params) {
|
||||
result.push_back(SignatureDefParam(std::string(named_param.first),
|
||||
TensorSpec(*named_param.second)));
|
||||
}
|
||||
std::sort(result.begin(), result.end(),
|
||||
[](const SignatureDefParam& x, const SignatureDefParam& y) {
|
||||
return x.name() < y.name();
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// SignatureDefArgsFromInputs takes the "canonicalized_input_signature"
|
||||
// field of a SavedConcreteFunction, ensures it conforms to the structure of
|
||||
// tuple(tuple(), dict<string,TensorSpec>()), and "returns" a list of
|
||||
// SignatureDefParams of the SignatureDefFunction's arguments.
|
||||
Status SignatureDefArgsFromInputs(
|
||||
const StructuredValue& canonicalized_input_signature,
|
||||
std::vector<SignatureDefParam>* out) {
|
||||
// Note(bmzhao): canonicalized_input_signature should be a tuple of
|
||||
// (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of
|
||||
// string keys to TensorSpecs.
|
||||
if (!canonicalized_input_signature.has_tuple_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||
"of form tuple(tuple(), dict()), but was instead: \n",
|
||||
canonicalized_input_signature.DebugString());
|
||||
}
|
||||
|
||||
const TupleValue& args_kwargs_tuple =
|
||||
canonicalized_input_signature.tuple_value();
|
||||
if (args_kwargs_tuple.values_size() != 2) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||
"a tuple of two elements (args, kwargs), but was instead: \n",
|
||||
args_kwargs_tuple.DebugString());
|
||||
}
|
||||
|
||||
const StructuredValue& args = args_kwargs_tuple.values(0);
|
||||
if (!args.has_tuple_value() || !args.tuple_value().values().empty()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature's args"
|
||||
"should be an empty tuple, but instead got: \n",
|
||||
args.DebugString());
|
||||
}
|
||||
|
||||
const StructuredValue& kwargs = args_kwargs_tuple.values(1);
|
||||
if (!kwargs.has_dict_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||
"should be a dictionary, but instead got: \n",
|
||||
kwargs.DebugString());
|
||||
}
|
||||
|
||||
const DictValue& kwargs_dict = kwargs.dict_value();
|
||||
NamedParamMap result;
|
||||
result.reserve(kwargs_dict.fields_size());
|
||||
|
||||
for (const auto& key_value : kwargs_dict.fields()) {
|
||||
const std::string& key = key_value.first;
|
||||
const StructuredValue& value = key_value.second;
|
||||
if (!value.has_tensor_spec_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||
"dictionary contained a non-tensorspec value for key-value pair: \n",
|
||||
"Key: ", key, "Value: \n", value.DebugString());
|
||||
}
|
||||
result[key] = &value.tensor_spec_value();
|
||||
}
|
||||
|
||||
*out = SignatureDefParamsFromNamedParamMap(result);
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
// SignatureDefReturnsFromOutputs takes the "output_signature" field of a
|
||||
// SavedConcreteFunction, ensures it conforms to the structure of
|
||||
// dict<string,TensorSpec>(), and "returns" a list of SignatureDefParams of the
|
||||
// SignatureDefFunction's returns.
|
||||
Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
|
||||
std::vector<SignatureDefParam>* out) {
|
||||
if (!output_signature.has_dict_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's output_signature must be a dictionary, but "
|
||||
"instead got: ",
|
||||
output_signature.DebugString());
|
||||
}
|
||||
|
||||
const DictValue& output_dict = output_signature.dict_value();
|
||||
NamedParamMap result;
|
||||
result.reserve(output_dict.fields_size());
|
||||
|
||||
for (const auto& key_value : output_dict.fields()) {
|
||||
const std::string& key = key_value.first;
|
||||
const StructuredValue& value = key_value.second;
|
||||
if (!value.has_tensor_spec_value()) {
|
||||
return errors::FailedPrecondition(
|
||||
"SignatureDefFunction's output_signature dictionary contained a "
|
||||
"non-tensorspec value for key-value pair: \n",
|
||||
"Key: ", key, "Value: \n", value.DebugString());
|
||||
}
|
||||
result[key] = &value.tensor_spec_value();
|
||||
}
|
||||
*out = SignatureDefParamsFromNamedParamMap(result);
|
||||
|
||||
return Status();
|
||||
}
|
||||
|
||||
// The implementation takes advantage of the fact that SignatureDefFunction's
|
||||
// "traced" Signature wrapper function always has inputs/outputs of dictionaries
|
||||
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126
|
||||
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178
|
||||
// Additionally, we take advantage of the fact that the SignatureDefFunction's
|
||||
// associated functiondef has lexicographically ordered inputs/outputs due to
|
||||
// nest.flatten.
|
||||
Status LoadSignatureDefFunctionMetadata(
|
||||
const SavedConcreteFunction& saved_concrete_function,
|
||||
SignatureDefFunctionMetadata* out) {
|
||||
std::vector<SignatureDefParam> args;
|
||||
TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs(
|
||||
saved_concrete_function.canonicalized_input_signature(), &args));
|
||||
|
||||
std::vector<SignatureDefParam> rets;
|
||||
TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs(
|
||||
saved_concrete_function.output_signature(), &rets));
|
||||
|
||||
*out = SignatureDefFunctionMetadata(std::move(args), std::move(rets));
|
||||
return Status();
|
||||
}
|
||||
|
||||
// This function finds the necessary captures, then forwards to the builder
|
||||
// method
|
||||
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
|
||||
@ -162,10 +312,14 @@ Status CreateSignatureDefFunction(
|
||||
&capture_handle));
|
||||
captures.push_back(capture_handle);
|
||||
}
|
||||
// TODO(bmzhao): Create Metadata here
|
||||
|
||||
SignatureDefFunctionMetadata metadata;
|
||||
TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata(
|
||||
*builder.saved_concrete_func, &metadata));
|
||||
|
||||
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
|
||||
/*captures=*/std::move(captures),
|
||||
/*metadata=*/{},
|
||||
/*metadata=*/std::move(metadata),
|
||||
/*ctx=*/ctx,
|
||||
/*out=*/out);
|
||||
}
|
||||
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec)
|
||||
: name_(std::move(name)), spec_(std::move(spec)) {}
|
||||
|
||||
const std::string& SignatureDefParam::name() const { return name_; }
|
||||
|
||||
const TensorSpec& SignatureDefParam::spec() const { return spec_; }
|
||||
|
||||
SignatureDefFunctionMetadata::SignatureDefFunctionMetadata(
|
||||
std::vector<SignatureDefParam> arguments,
|
||||
std::vector<SignatureDefParam> returns)
|
||||
: arguments_(std::move(arguments)), returns_(std::move(returns)) {}
|
||||
|
||||
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::arguments()
|
||||
const {
|
||||
return arguments_;
|
||||
}
|
||||
|
||||
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::returns()
|
||||
const {
|
||||
return returns_;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -16,10 +16,42 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// SignatureDefParam represents a named Tensor input or output to a
|
||||
// SignatureDefFunction.
|
||||
class SignatureDefParam {
|
||||
public:
|
||||
SignatureDefParam(std::string name, TensorSpec spec);
|
||||
|
||||
const std::string& name() const;
|
||||
|
||||
const TensorSpec& spec() const;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
TensorSpec spec_;
|
||||
};
|
||||
|
||||
class SignatureDefFunctionMetadata {
|
||||
// TODO(bmzhao): Fill in with fields as necessary
|
||||
public:
|
||||
SignatureDefFunctionMetadata() = default;
|
||||
SignatureDefFunctionMetadata(std::vector<SignatureDefParam> arguments,
|
||||
std::vector<SignatureDefParam> returns);
|
||||
|
||||
const std::vector<SignatureDefParam>& arguments() const;
|
||||
const std::vector<SignatureDefParam>& returns() const;
|
||||
|
||||
private:
|
||||
std::vector<SignatureDefParam> arguments_;
|
||||
std::vector<SignatureDefParam> returns_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
38
tensorflow/c/experimental/saved_model/core/tensor_spec.cc
Normal file
38
tensorflow/c/experimental/saved_model/core/tensor_spec.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TensorSpec::TensorSpec()
|
||||
: shape_(std::initializer_list<int64>()), dtype_(DT_FLOAT) {}
|
||||
|
||||
TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype)
|
||||
: shape_(std::move(shape)), dtype_(dtype) {}
|
||||
|
||||
TensorSpec::TensorSpec(const TensorSpecProto& proto)
|
||||
: shape_(proto.shape()), dtype_(proto.dtype()) {}
|
||||
|
||||
const PartialTensorShape& TensorSpec::shape() const { return shape_; }
|
||||
|
||||
DataType TensorSpec::dtype() const { return dtype_; }
|
||||
|
||||
} // namespace tensorflow
|
51
tensorflow/c/experimental/saved_model/core/tensor_spec.h
Normal file
51
tensorflow/c/experimental/saved_model/core/tensor_spec.h
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
||||
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Note(bmzhao): TensorSpec deliberately does not store the "name" from a
|
||||
// TensorSpecProto. From edloper@, "Names should really be associated with
|
||||
// parameters, not the tensors inside those parameters. This would be
|
||||
// inconsistent with the corresponding Python class, but I don't think that's
|
||||
// necessarily a problem. If it turns out later that we really need a name
|
||||
// attribute here, we can always add it back in; but let's see how far we can
|
||||
// get without it."
|
||||
class TensorSpec {
|
||||
public:
|
||||
// Constructs a scalar, DT_FLOAT TensorSpec
|
||||
TensorSpec();
|
||||
|
||||
TensorSpec(PartialTensorShape shape, DataType dtype);
|
||||
|
||||
explicit TensorSpec(const TensorSpecProto& proto);
|
||||
|
||||
const PartialTensorShape& shape() const;
|
||||
DataType dtype() const;
|
||||
|
||||
private:
|
||||
PartialTensorShape shape_;
|
||||
DataType dtype_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
Loading…
Reference in New Issue
Block a user