diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 662f63e0133..25cac39daa0 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -111,7 +111,7 @@ cc_library( "tf_signature_def_function.h", ], deps = [ - ":tensorhandle_convertible", + ":flat_tensor_function", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_operation", diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc index 2a9f23b63cd..ab1745dcd47 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/errors.h" @@ -34,31 +34,20 @@ limitations under the License. namespace tensorflow { TFSignatureDefFunction::TFSignatureDefFunction( - const std::string& name, - std::vector captures, - SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx) - : name_(name), - captures_(std::move(captures)), - metadata_(std::move(metadata)), - ctx_(ctx) {} - -TFSignatureDefFunction::~TFSignatureDefFunction() { - Status status = ctx_->RemoveFunction(name_); - if (!status.ok()) { - LOG(ERROR) << "Failed to remove functiondef " << name_ << ". " - << status.error_message(); - } -} + std::unique_ptr func, + SignatureDefFunctionMetadata metadata) + : func_(std::move(func)), metadata_(std::move(metadata)) {} Status TFSignatureDefFunction::Create( const FunctionDef* function_def, std::vector captures, SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx, std::unique_ptr* out) { - TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); - out->reset(new TFSignatureDefFunction(function_def->signature().name(), - std::move(captures), - std::move(metadata), ctx)); + std::unique_ptr func; + TF_RETURN_IF_ERROR(FlatTensorFunction::Create( + function_def, std::move(captures), ctx, &func)); + + out->reset(new TFSignatureDefFunction(std::move(func), std::move(metadata))); return Status(); } @@ -69,28 +58,7 @@ TFSignatureDefFunction::GetFunctionMetadata() const { Status TFSignatureDefFunction::MakeCallOp( absl::Span inputs, ImmediateOpPtr* out) const { - out->reset(ctx_->CreateOperation()); - // In eager mode, TF2 python executes functions by constructing an op with - // the name of the functiondef: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 - // In graph mode, we create a PartitionedCallOp instead: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 - - // TODO(bmzhao): After discussing with Allen, we should execute this via a - // PartitionedCallOp for compatibility with "tooling that assumes functions in - // graphs are PartitionedCallOps". - TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); - - // Adding the user-provided inputs to the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); - - absl::Span captures( - reinterpret_cast(captures_.data()), - captures_.size()); - - // Adding the captures of the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); - return Status(); + return func_->MakeCallOp(inputs, out); } } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h index 523560685c3..7b564185b8b 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" #include "tensorflow/core/framework/function.pb.h" @@ -67,22 +67,17 @@ class TFSignatureDefFunction : public SignatureDefFunction { const SignatureDefFunctionMetadata& GetFunctionMetadata() const override; - ~TFSignatureDefFunction() override; + ~TFSignatureDefFunction() override = default; private: - TFSignatureDefFunction(const std::string& name, - std::vector captures, - SignatureDefFunctionMetadata metadata, - ImmediateExecutionContext* ctx); + TFSignatureDefFunction(std::unique_ptr func, + SignatureDefFunctionMetadata metadata); TFSignatureDefFunction(const TFSignatureDefFunction&) = delete; TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete; - // Name of the FunctionDef corresponding to this TFSignatureDefFunction - std::string name_; - std::vector captures_; + std::unique_ptr func_; SignatureDefFunctionMetadata metadata_; - ImmediateExecutionContext* ctx_; }; } // namespace tensorflow