Continue splitting SignatureDefFunction from ConcreteFunction.
PiperOrigin-RevId: 328682987 Change-Id: I4d32d48bbfb5f461c5e3edc50513057808c8c47f
This commit is contained in:
parent
c5e13d84ca
commit
ecc6fe9066
@ -81,3 +81,26 @@ cc_library(
|
|||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tf_signature_def_function",
|
||||||
|
srcs = [
|
||||||
|
"tf_signature_def_function.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"tf_signature_def_function.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tensorhandle_convertible",
|
||||||
|
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_operation",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -0,0 +1,96 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/types/span.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
TFSignatureDefFunction::TFSignatureDefFunction(
|
||||||
|
const std::string& name,
|
||||||
|
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||||
|
SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx)
|
||||||
|
: name_(name),
|
||||||
|
captures_(std::move(captures)),
|
||||||
|
metadata_(std::move(metadata)),
|
||||||
|
ctx_(ctx) {}
|
||||||
|
|
||||||
|
TFSignatureDefFunction::~TFSignatureDefFunction() {
|
||||||
|
Status status = ctx_->RemoveFunction(name_);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TFSignatureDefFunction::Create(
|
||||||
|
const FunctionDef* function_def,
|
||||||
|
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||||
|
SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx,
|
||||||
|
std::unique_ptr<TFSignatureDefFunction>* out) {
|
||||||
|
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
|
||||||
|
out->reset(new TFSignatureDefFunction(function_def->signature().name(),
|
||||||
|
std::move(captures),
|
||||||
|
std::move(metadata), ctx));
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
const SignatureDefFunctionMetadata&
|
||||||
|
TFSignatureDefFunction::GetFunctionMetadata() const {
|
||||||
|
return metadata_;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TFSignatureDefFunction::MakeCallOp(
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
|
||||||
|
out->reset(ctx_->CreateOperation());
|
||||||
|
// In eager mode, TF2 python executes functions by constructing an op with
|
||||||
|
// the name of the functiondef:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
|
||||||
|
// In graph mode, we create a PartitionedCallOp instead:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
|
||||||
|
|
||||||
|
// TODO(bmzhao): After discussing with Allen, we should execute this via a
|
||||||
|
// PartitionedCallOp for compatibility with "tooling that assumes functions in
|
||||||
|
// graphs are PartitionedCallOps".
|
||||||
|
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
|
||||||
|
|
||||||
|
// Adding the user-provided inputs to the function.
|
||||||
|
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
|
||||||
|
|
||||||
|
absl::Span<AbstractTensorHandle* const> captures(
|
||||||
|
reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
|
||||||
|
captures_.size());
|
||||||
|
|
||||||
|
// Adding the captures of the function.
|
||||||
|
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,90 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// This is the TF eager runtime implementation of SignatureDefFunction (separate
|
||||||
|
// from the TFRT implementation). The user-facing API of SignatureDefFunctions
|
||||||
|
// and their semantic differences from ConcreteFunction are described here:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/cc/saved_model/experimental/public/signature_def_function.h#L30-L59
|
||||||
|
// Additional implementation notes are available here:
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/c/experimental/saved_model/core/signature_def_function.h#L31-L48
|
||||||
|
class TFSignatureDefFunction : public SignatureDefFunction {
|
||||||
|
public:
|
||||||
|
// Factory function for creating a TFSignatureDefFunction.
|
||||||
|
//
|
||||||
|
// Params:
|
||||||
|
// function_def - The function_def associated with the created
|
||||||
|
// TFSignatureDefFunction. TFSignatureDefFunction will
|
||||||
|
// register this function_def with `ctx` on creation, and
|
||||||
|
// de-register it on destruction. function_def must be
|
||||||
|
// non-null, but otherwise has no lifetime requirements.
|
||||||
|
// captures - The captured TensorHandles associated with this
|
||||||
|
// TFConcreteFunction.
|
||||||
|
// metadata - FunctionMetadata associated with this TFSignatureDefFunction.
|
||||||
|
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
|
||||||
|
// outlive TFSignatureDefFunction.
|
||||||
|
// out - The output TFSignatureDefFunction.
|
||||||
|
static Status Create(const FunctionDef* function_def,
|
||||||
|
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||||
|
SignatureDefFunctionMetadata metadata,
|
||||||
|
ImmediateExecutionContext* ctx,
|
||||||
|
std::unique_ptr<TFSignatureDefFunction>* out);
|
||||||
|
|
||||||
|
// This method creates a "Call" Op used to execute the function.
|
||||||
|
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
ImmediateOpPtr* out) const override;
|
||||||
|
|
||||||
|
const SignatureDefFunctionMetadata& GetFunctionMetadata() const override;
|
||||||
|
|
||||||
|
~TFSignatureDefFunction() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
TFSignatureDefFunction(const std::string& name,
|
||||||
|
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||||
|
SignatureDefFunctionMetadata metadata,
|
||||||
|
ImmediateExecutionContext* ctx);
|
||||||
|
|
||||||
|
TFSignatureDefFunction(const TFSignatureDefFunction&) = delete;
|
||||||
|
TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete;
|
||||||
|
|
||||||
|
// Name of the FunctionDef corresponding to this TFSignatureDefFunction
|
||||||
|
std::string name_;
|
||||||
|
std::vector<ImmediateExecutionTensorHandle*> captures_;
|
||||||
|
SignatureDefFunctionMetadata metadata_;
|
||||||
|
ImmediateExecutionContext* ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
|
Loading…
Reference in New Issue
Block a user