Refactor SignatureDefFunction to use FlatTensorFunction.

PiperOrigin-RevId: 328684886
Change-Id: I9b68f54008c8effcd0b41d41ad3e8cabda88e617
This commit is contained in:
Brian Zhao 2020-08-26 23:56:54 -07:00 committed by TensorFlower Gardener
parent 1ca6da51de
commit e9c732e847
3 changed files with 16 additions and 53 deletions

View File

@ -111,7 +111,7 @@ cc_library(
"tf_signature_def_function.h", "tf_signature_def_function.h",
], ],
deps = [ deps = [
":tensorhandle_convertible", ":flat_tensor_function",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
@ -34,31 +34,20 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
TFSignatureDefFunction::TFSignatureDefFunction( TFSignatureDefFunction::TFSignatureDefFunction(
const std::string& name, std::unique_ptr<FlatTensorFunction> func,
std::vector<ImmediateExecutionTensorHandle*> captures, SignatureDefFunctionMetadata metadata)
SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx) : func_(std::move(func)), metadata_(std::move(metadata)) {}
: name_(name),
captures_(std::move(captures)),
metadata_(std::move(metadata)),
ctx_(ctx) {}
TFSignatureDefFunction::~TFSignatureDefFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
Status TFSignatureDefFunction::Create( Status TFSignatureDefFunction::Create(
const FunctionDef* function_def, const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures, std::vector<ImmediateExecutionTensorHandle*> captures,
SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx, SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx,
std::unique_ptr<TFSignatureDefFunction>* out) { std::unique_ptr<TFSignatureDefFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); std::unique_ptr<FlatTensorFunction> func;
out->reset(new TFSignatureDefFunction(function_def->signature().name(), TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
std::move(captures), function_def, std::move(captures), ctx, &func));
std::move(metadata), ctx));
out->reset(new TFSignatureDefFunction(std::move(func), std::move(metadata)));
return Status(); return Status();
} }
@ -69,28 +58,7 @@ TFSignatureDefFunction::GetFunctionMetadata() const {
Status TFSignatureDefFunction::MakeCallOp( Status TFSignatureDefFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const { absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
out->reset(ctx_->CreateOperation()); return func_->MakeCallOp(inputs, out);
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status();
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function.pb.h"
@ -67,22 +67,17 @@ class TFSignatureDefFunction : public SignatureDefFunction {
const SignatureDefFunctionMetadata& GetFunctionMetadata() const override; const SignatureDefFunctionMetadata& GetFunctionMetadata() const override;
~TFSignatureDefFunction() override; ~TFSignatureDefFunction() override = default;
private: private:
TFSignatureDefFunction(const std::string& name, TFSignatureDefFunction(std::unique_ptr<FlatTensorFunction> func,
std::vector<ImmediateExecutionTensorHandle*> captures, SignatureDefFunctionMetadata metadata);
SignatureDefFunctionMetadata metadata,
ImmediateExecutionContext* ctx);
TFSignatureDefFunction(const TFSignatureDefFunction&) = delete; TFSignatureDefFunction(const TFSignatureDefFunction&) = delete;
TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete; TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete;
// Name of the FunctionDef corresponding to this TFSignatureDefFunction std::unique_ptr<FlatTensorFunction> func_;
std::string name_;
std::vector<ImmediateExecutionTensorHandle*> captures_;
SignatureDefFunctionMetadata metadata_; SignatureDefFunctionMetadata metadata_;
ImmediateExecutionContext* ctx_;
}; };
} // namespace tensorflow } // namespace tensorflow