Adding SignatureDefFunctionMetadata to SignatureDefFunctions. This allows users to verify that the input tensor's dtype + shape match up with the function's expected inputs. Additionally, this is necessary to know the number of TFE_TensorHandles to pre-allocated for TFE_Execute on the SignatureDefFunction.

PiperOrigin-RevId: 334482407
Change-Id: I8961f6f409efecfadbdb077d0cdbcc5c0f31a9b6
This commit is contained in:
Brian Zhao 2020-09-29 16:06:22 -07:00 committed by TensorFlower Gardener
parent 8c7d50fcfd
commit 2d660e99c7
7 changed files with 345 additions and 3 deletions

View File

@ -91,15 +91,24 @@ cc_library(
":signature_def_function_metadata",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "signature_def_function_metadata",
srcs = [
"signature_def_function_metadata.cc",
],
hdrs = [
"signature_def_function_metadata.h",
],
deps = [
":tensor_spec",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
@ -268,6 +277,20 @@ tf_cc_test(
],
)
cc_library(
name = "tensor_spec",
srcs = [
"tensor_spec.cc",
],
hdrs = [
"tensor_spec.h",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "tf_concrete_function_loading_test",
srcs = [

View File

@ -92,6 +92,8 @@ cc_library(
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include "absl/types/span.h"
@ -30,14 +32,26 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace {
using StructuredValueDictEntry =
protobuf::MapPair<std::string, StructuredValue>;
using NamedParamMap =
gtl::FlatMap<StringPiece, const TensorSpecProto*, StringPieceHasher>;
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
const PartiallyRevivedObjects& objects) {
for (const auto& id_and_resource : objects.restored_resources) {
@ -124,6 +138,142 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
}
}
std::vector<SignatureDefParam> SignatureDefParamsFromNamedParamMap(
const NamedParamMap& params) {
// The underlying functiondef associated with the SignatureDef has
// nest.flattened inputs and outputs, which are sorted by string key.
std::vector<SignatureDefParam> result;
result.reserve(params.size());
for (const auto& named_param : params) {
result.push_back(SignatureDefParam(std::string(named_param.first),
TensorSpec(*named_param.second)));
}
std::sort(result.begin(), result.end(),
[](const SignatureDefParam& x, const SignatureDefParam& y) {
return x.name() < y.name();
});
return result;
}
// SignatureDefArgsFromInputs takes the "canonicalized_input_signature"
// field of a SavedConcreteFunction, ensures it conforms to the structure of
// tuple(tuple(), dict<string,TensorSpec>()), and "returns" a list of
// SignatureDefParams of the SignatureDefFunction's arguments.
Status SignatureDefArgsFromInputs(
const StructuredValue& canonicalized_input_signature,
std::vector<SignatureDefParam>* out) {
// Note(bmzhao): canonicalized_input_signature should be a tuple of
// (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of
// string keys to TensorSpecs.
if (!canonicalized_input_signature.has_tuple_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature should be "
"of form tuple(tuple(), dict()), but was instead: \n",
canonicalized_input_signature.DebugString());
}
const TupleValue& args_kwargs_tuple =
canonicalized_input_signature.tuple_value();
if (args_kwargs_tuple.values_size() != 2) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature should be "
"a tuple of two elements (args, kwargs), but was instead: \n",
args_kwargs_tuple.DebugString());
}
const StructuredValue& args = args_kwargs_tuple.values(0);
if (!args.has_tuple_value() || !args.tuple_value().values().empty()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature's args"
"should be an empty tuple, but instead got: \n",
args.DebugString());
}
const StructuredValue& kwargs = args_kwargs_tuple.values(1);
if (!kwargs.has_dict_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature's kwargs"
"should be a dictionary, but instead got: \n",
kwargs.DebugString());
}
const DictValue& kwargs_dict = kwargs.dict_value();
NamedParamMap result;
result.reserve(kwargs_dict.fields_size());
for (const auto& key_value : kwargs_dict.fields()) {
const std::string& key = key_value.first;
const StructuredValue& value = key_value.second;
if (!value.has_tensor_spec_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature's kwargs"
"dictionary contained a non-tensorspec value for key-value pair: \n",
"Key: ", key, "Value: \n", value.DebugString());
}
result[key] = &value.tensor_spec_value();
}
*out = SignatureDefParamsFromNamedParamMap(result);
return Status();
}
// SignatureDefReturnsFromOutputs takes the "output_signature" field of a
// SavedConcreteFunction, ensures it conforms to the structure of
// dict<string,TensorSpec>(), and "returns" a list of SignatureDefParams of the
// SignatureDefFunction's returns.
Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
std::vector<SignatureDefParam>* out) {
if (!output_signature.has_dict_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's output_signature must be a dictionary, but "
"instead got: ",
output_signature.DebugString());
}
const DictValue& output_dict = output_signature.dict_value();
NamedParamMap result;
result.reserve(output_dict.fields_size());
for (const auto& key_value : output_dict.fields()) {
const std::string& key = key_value.first;
const StructuredValue& value = key_value.second;
if (!value.has_tensor_spec_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's output_signature dictionary contained a "
"non-tensorspec value for key-value pair: \n",
"Key: ", key, "Value: \n", value.DebugString());
}
result[key] = &value.tensor_spec_value();
}
*out = SignatureDefParamsFromNamedParamMap(result);
return Status();
}
// The implementation takes advantage of the fact that SignatureDefFunction's
// "traced" Signature wrapper function always has inputs/outputs of dictionaries
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178
// Additionally, we take advantage of the fact that the SignatureDefFunction's
// associated functiondef has lexicographically ordered inputs/outputs due to
// nest.flatten.
Status LoadSignatureDefFunctionMetadata(
const SavedConcreteFunction& saved_concrete_function,
SignatureDefFunctionMetadata* out) {
std::vector<SignatureDefParam> args;
TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs(
saved_concrete_function.canonicalized_input_signature(), &args));
std::vector<SignatureDefParam> rets;
TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs(
saved_concrete_function.output_signature(), &rets));
*out = SignatureDefFunctionMetadata(std::move(args), std::move(rets));
return Status();
}
// This function finds the necessary captures, then forwards to the builder
// method
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
@ -162,10 +312,14 @@ Status CreateSignatureDefFunction(
&capture_handle));
captures.push_back(capture_handle);
}
// TODO(bmzhao): Create Metadata here
SignatureDefFunctionMetadata metadata;
TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata(
*builder.saved_concrete_func, &metadata));
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
/*captures=*/std::move(captures),
/*metadata=*/{},
/*metadata=*/std::move(metadata),
/*ctx=*/ctx,
/*out=*/out);
}

View File

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
namespace tensorflow {
SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec)
: name_(std::move(name)), spec_(std::move(spec)) {}
const std::string& SignatureDefParam::name() const { return name_; }
const TensorSpec& SignatureDefParam::spec() const { return spec_; }
SignatureDefFunctionMetadata::SignatureDefFunctionMetadata(
std::vector<SignatureDefParam> arguments,
std::vector<SignatureDefParam> returns)
: arguments_(std::move(arguments)), returns_(std::move(returns)) {}
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::arguments()
const {
return arguments_;
}
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::returns()
const {
return returns_;
}
} // namespace tensorflow

View File

@ -16,10 +16,42 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
#include <string>
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
// SignatureDefParam represents a named Tensor input or output to a
// SignatureDefFunction.
class SignatureDefParam {
public:
SignatureDefParam(std::string name, TensorSpec spec);
const std::string& name() const;
const TensorSpec& spec() const;
private:
std::string name_;
TensorSpec spec_;
};
class SignatureDefFunctionMetadata {
// TODO(bmzhao): Fill in with fields as necessary
public:
SignatureDefFunctionMetadata() = default;
SignatureDefFunctionMetadata(std::vector<SignatureDefParam> arguments,
std::vector<SignatureDefParam> returns);
const std::vector<SignatureDefParam>& arguments() const;
const std::vector<SignatureDefParam>& returns() const;
private:
std::vector<SignatureDefParam> arguments_;
std::vector<SignatureDefParam> returns_;
};
} // namespace tensorflow

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include <initializer_list>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
TensorSpec::TensorSpec()
: shape_(std::initializer_list<int64>()), dtype_(DT_FLOAT) {}
TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype)
: shape_(std::move(shape)), dtype_(dtype) {}
TensorSpec::TensorSpec(const TensorSpecProto& proto)
: shape_(proto.shape()), dtype_(proto.dtype()) {}
const PartialTensorShape& TensorSpec::shape() const { return shape_; }
DataType TensorSpec::dtype() const { return dtype_; }
} // namespace tensorflow

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
// Note(bmzhao): TensorSpec deliberately does not store the "name" from a
// TensorSpecProto. From edloper@, "Names should really be associated with
// parameters, not the tensors inside those parameters. This would be
// inconsistent with the corresponding Python class, but I don't think that's
// necessarily a problem. If it turns out later that we really need a name
// attribute here, we can always add it back in; but let's see how far we can
// get without it."
class TensorSpec {
public:
// Constructs a scalar, DT_FLOAT TensorSpec
TensorSpec();
TensorSpec(PartialTensorShape shape, DataType dtype);
explicit TensorSpec(const TensorSpecProto& proto);
const PartialTensorShape& shape() const;
DataType dtype() const;
private:
PartialTensorShape shape_;
DataType dtype_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_