Continue splitting SignatureDefFunction from ConcreteFunction.

PiperOrigin-RevId: 328682987
Change-Id: I4d32d48bbfb5f461c5e3edc50513057808c8c47f
This commit is contained in:
Brian Zhao 2020-08-26 23:33:17 -07:00 committed by TensorFlower Gardener
parent c5e13d84ca
commit ecc6fe9066
3 changed files with 209 additions and 0 deletions

View File

@ -81,3 +81,26 @@ cc_library(
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tf_signature_def_function",
srcs = [
"tf_signature_def_function.cc",
],
hdrs = [
"tf_signature_def_function.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)

View File

@ -0,0 +1,96 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
#include <memory>
#include <string>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
TFSignatureDefFunction::TFSignatureDefFunction(
const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx)
: name_(name),
captures_(std::move(captures)),
metadata_(std::move(metadata)),
ctx_(ctx) {}
TFSignatureDefFunction::~TFSignatureDefFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
Status TFSignatureDefFunction::Create(
const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx,
std::unique_ptr<TFSignatureDefFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
out->reset(new TFSignatureDefFunction(function_def->signature().name(),
std::move(captures),
std::move(metadata), ctx));
return Status();
}
const SignatureDefFunctionMetadata&
TFSignatureDefFunction::GetFunctionMetadata() const {
return metadata_;
}
Status TFSignatureDefFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
out->reset(ctx_->CreateOperation());
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
// This is the TF eager runtime implementation of SignatureDefFunction (separate
// from the TFRT implementation). The user-facing API of SignatureDefFunctions
// and their semantic differences from ConcreteFunction are described here:
// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/cc/saved_model/experimental/public/signature_def_function.h#L30-L59
// Additional implementation notes are available here:
// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/c/experimental/saved_model/core/signature_def_function.h#L31-L48
class TFSignatureDefFunction : public SignatureDefFunction {
public:
// Factory function for creating a TFSignatureDefFunction.
//
// Params:
// function_def - The function_def associated with the created
// TFSignatureDefFunction. TFSignatureDefFunction will
// register this function_def with `ctx` on creation, and
// de-register it on destruction. function_def must be
// non-null, but otherwise has no lifetime requirements.
// captures - The captured TensorHandles associated with this
// TFConcreteFunction.
// metadata - FunctionMetadata associated with this TFSignatureDefFunction.
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
// outlive TFSignatureDefFunction.
// out - The output TFSignatureDefFunction.
static Status Create(const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
SignatureDefFunctionMetadata metadata,
ImmediateExecutionContext* ctx,
std::unique_ptr<TFSignatureDefFunction>* out);
// This method creates a "Call" Op used to execute the function.
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const override;
const SignatureDefFunctionMetadata& GetFunctionMetadata() const override;
~TFSignatureDefFunction() override;
private:
TFSignatureDefFunction(const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
SignatureDefFunctionMetadata metadata,
ImmediateExecutionContext* ctx);
TFSignatureDefFunction(const TFSignatureDefFunction&) = delete;
TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete;
// Name of the FunctionDef corresponding to this TFSignatureDefFunction
std::string name_;
std::vector<ImmediateExecutionTensorHandle*> captures_;
SignatureDefFunctionMetadata metadata_;
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_