diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index 934fa6d2bda..48a20ef7768 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -43,8 +43,8 @@ class ConcreteFunction { virtual ~ConcreteFunction() = default; // This method returns the "Call" Op used to execute the function. - virtual Status GetCallOp(absl::Span inputs, - ImmediateOpPtr* out) = 0; + virtual Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const = 0; virtual const FunctionMetadata& GetFunctionMetadata() const = 0; }; diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 3e826571e82..662f63e0133 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -88,7 +88,7 @@ cc_library( "tf_concrete_function.h", ], deps = [ - ":tensorhandle_convertible", + ":flat_tensor_function", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_operation", diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc index f734f9eca66..d9773a4520f 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/errors.h" @@ -33,32 +33,20 @@ limitations under the License. namespace tensorflow { -TFConcreteFunction::TFConcreteFunction( - const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx) - : name_(name), - captures_(std::move(captures)), - metadata_(std::move(metadata)), - ctx_(ctx) {} - -TFConcreteFunction::~TFConcreteFunction() { - Status status = ctx_->RemoveFunction(name_); - if (!status.ok()) { - LOG(ERROR) << "Failed to remove functiondef " << name_ << ". " - << status.error_message(); - } -} +TFConcreteFunction::TFConcreteFunction(std::unique_ptr func, + FunctionMetadata metadata) + : func_(std::move(func)), metadata_(std::move(metadata)) {} Status TFConcreteFunction::Create( const FunctionDef* function_def, std::vector captures, FunctionMetadata metadata, ImmediateExecutionContext* ctx, std::unique_ptr* out) { - TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); - out->reset(new TFConcreteFunction(function_def->signature().name(), - std::move(captures), std::move(metadata), - ctx)); + std::unique_ptr func; + TF_RETURN_IF_ERROR(FlatTensorFunction::Create( + function_def, std::move(captures), ctx, &func)); + + out->reset(new TFConcreteFunction(std::move(func), std::move(metadata))); return Status(); } @@ -66,30 +54,9 @@ const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const { return metadata_; } -Status TFConcreteFunction::GetCallOp( - absl::Span inputs, ImmediateOpPtr* out) { - out->reset(ctx_->CreateOperation()); - // In eager mode, TF2 python executes functions by constructing an op with - // the name of the functiondef: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 - // In graph mode, we create a PartitionedCallOp instead: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 - - // TODO(bmzhao): After discussing with Allen, we should execute this via a - // PartitionedCallOp for compatibility with "tooling that assumes functions in - // graphs are PartitionedCallOps". - TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); - - // Adding the user-provided inputs to the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); - - absl::Span captures( - reinterpret_cast(captures_.data()), - captures_.size()); - - // Adding the captures of the function. - TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); - return Status(); +Status TFConcreteFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + return func_->MakeCallOp(inputs, out); } } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h index d38f3546f91..edc26f4d5aa 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" @@ -58,26 +58,22 @@ class TFConcreteFunction : public ConcreteFunction { std::unique_ptr* out); // This method returns the "Call" Op used to execute the function. - Status GetCallOp(absl::Span inputs, - ImmediateOpPtr* out) override; + Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const override; const FunctionMetadata& GetFunctionMetadata() const override; - ~TFConcreteFunction() override; + ~TFConcreteFunction() override = default; private: - TFConcreteFunction(const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx); + TFConcreteFunction(std::unique_ptr func, + FunctionMetadata metadata); TFConcreteFunction(const TFConcreteFunction&) = delete; TFConcreteFunction& operator=(const TFConcreteFunction&) = delete; - // Name of the FunctionDef corresponding to this TFConcreteFunction - std::string name_; - std::vector captures_; + std::unique_ptr func_; FunctionMetadata metadata_; - ImmediateExecutionContext* ctx_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index 65c6eca5623..2beed8f4119 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -34,15 +34,15 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { &tensorflow::unwrap(func)->GetFunctionMetadata())); } -TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, - TFE_TensorHandle** inputs, int num_inputs, - TF_Status* status) { +TFE_Op* TF_ConcreteFunctionMakeCallOp(TF_ConcreteFunction* func, + TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status) { tensorflow::ImmediateOpPtr call_op; absl::Span input_span( reinterpret_cast( tensorflow::unwrap(inputs)), static_cast(num_inputs)); - status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op); + status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op); if (!status->status.ok()) { return nullptr; } diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index e58b232f9c9..df998fcf6cd 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -107,7 +107,7 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) { compute_fn_inputs.push_back(input_a); compute_fn_inputs.push_back(input_b); - TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp( + TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp( compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status); EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index 0fd0f70cf16..ff8a245961a 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -47,7 +47,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( // high-level API here. A strawman for what this interface could look like: // TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value* // inputs, int num_inputs, TF_Status* status); -TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( +TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp( TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status);