Refactoring TFConcreteFunction to depend on FlatTensorFunction for its implementation. (Also rename "GetCallOp" -> "MakeCallOp")

PiperOrigin-RevId: 328684297
Change-Id: I99d4c1a687057779769328e3503a9c0fb92d6978
This commit is contained in:
Brian Zhao 2020-08-26 23:49:24 -07:00 committed by TensorFlower Gardener
parent 8674a3ce83
commit 1ca6da51de
7 changed files with 28 additions and 65 deletions

View File

@ -43,8 +43,8 @@ class ConcreteFunction {
virtual ~ConcreteFunction() = default; virtual ~ConcreteFunction() = default;
// This method returns the "Call" Op used to execute the function. // This method returns the "Call" Op used to execute the function.
virtual Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs, virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) = 0; ImmediateOpPtr* out) const = 0;
virtual const FunctionMetadata& GetFunctionMetadata() const = 0; virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
}; };

View File

@ -88,7 +88,7 @@ cc_library(
"tf_concrete_function.h", "tf_concrete_function.h",
], ],
deps = [ deps = [
":tensorhandle_convertible", ":flat_tensor_function",
"//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
@ -33,32 +33,20 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
TFConcreteFunction::TFConcreteFunction( TFConcreteFunction::TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
const std::string& name, FunctionMetadata metadata)
std::vector<ImmediateExecutionTensorHandle*> captures, : func_(std::move(func)), metadata_(std::move(metadata)) {}
FunctionMetadata metadata, ImmediateExecutionContext* ctx)
: name_(name),
captures_(std::move(captures)),
metadata_(std::move(metadata)),
ctx_(ctx) {}
TFConcreteFunction::~TFConcreteFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
Status TFConcreteFunction::Create( Status TFConcreteFunction::Create(
const FunctionDef* function_def, const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures, std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx, FunctionMetadata metadata, ImmediateExecutionContext* ctx,
std::unique_ptr<TFConcreteFunction>* out) { std::unique_ptr<TFConcreteFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); std::unique_ptr<FlatTensorFunction> func;
out->reset(new TFConcreteFunction(function_def->signature().name(), TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
std::move(captures), std::move(metadata), function_def, std::move(captures), ctx, &func));
ctx));
out->reset(new TFConcreteFunction(std::move(func), std::move(metadata)));
return Status(); return Status();
} }
@ -66,30 +54,9 @@ const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const {
return metadata_; return metadata_;
} }
Status TFConcreteFunction::GetCallOp( Status TFConcreteFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) { absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
out->reset(ctx_->CreateOperation()); return func_->MakeCallOp(inputs, out);
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle**>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status();
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
@ -58,26 +58,22 @@ class TFConcreteFunction : public ConcreteFunction {
std::unique_ptr<TFConcreteFunction>* out); std::unique_ptr<TFConcreteFunction>* out);
// This method returns the "Call" Op used to execute the function. // This method returns the "Call" Op used to execute the function.
Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs, Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) override; ImmediateOpPtr* out) const override;
const FunctionMetadata& GetFunctionMetadata() const override; const FunctionMetadata& GetFunctionMetadata() const override;
~TFConcreteFunction() override; ~TFConcreteFunction() override = default;
private: private:
TFConcreteFunction(const std::string& name, TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
std::vector<ImmediateExecutionTensorHandle*> captures, FunctionMetadata metadata);
FunctionMetadata metadata, ImmediateExecutionContext* ctx);
TFConcreteFunction(const TFConcreteFunction&) = delete; TFConcreteFunction(const TFConcreteFunction&) = delete;
TFConcreteFunction& operator=(const TFConcreteFunction&) = delete; TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;
// Name of the FunctionDef corresponding to this TFConcreteFunction std::unique_ptr<FlatTensorFunction> func_;
std::string name_;
std::vector<ImmediateExecutionTensorHandle*> captures_;
FunctionMetadata metadata_; FunctionMetadata metadata_;
ImmediateExecutionContext* ctx_;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -34,15 +34,15 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
&tensorflow::unwrap(func)->GetFunctionMetadata())); &tensorflow::unwrap(func)->GetFunctionMetadata()));
} }
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, TFE_Op* TF_ConcreteFunctionMakeCallOp(TF_ConcreteFunction* func,
TFE_TensorHandle** inputs, int num_inputs, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) { TF_Status* status) {
tensorflow::ImmediateOpPtr call_op; tensorflow::ImmediateOpPtr call_op;
absl::Span<tensorflow::AbstractTensorHandle* const> input_span( absl::Span<tensorflow::AbstractTensorHandle* const> input_span(
reinterpret_cast<tensorflow::AbstractTensorHandle**>( reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(inputs)), tensorflow::unwrap(inputs)),
static_cast<size_t>(num_inputs)); static_cast<size_t>(num_inputs));
status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op); status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op);
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }

View File

@ -107,7 +107,7 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
compute_fn_inputs.push_back(input_a); compute_fn_inputs.push_back(input_a);
compute_fn_inputs.push_back(input_b); compute_fn_inputs.push_back(input_b);
TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp( TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp(
compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status); compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

View File

@ -47,7 +47,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
// high-level API here. A strawman for what this interface could look like: // high-level API here. A strawman for what this interface could look like:
// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value* // TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value*
// inputs, int num_inputs, TF_Status* status); // inputs, int num_inputs, TF_Status* status);
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp(
TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs, TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status); TF_Status* status);