Refactoring TFConcreteFunction to depend on FlatTensorFunction for its implementation. (Also rename "GetCallOp" -> "MakeCallOp")
PiperOrigin-RevId: 328684297 Change-Id: I99d4c1a687057779769328e3503a9c0fb92d6978
This commit is contained in:
parent
8674a3ce83
commit
1ca6da51de
@ -43,8 +43,8 @@ class ConcreteFunction {
|
|||||||
virtual ~ConcreteFunction() = default;
|
virtual ~ConcreteFunction() = default;
|
||||||
|
|
||||||
// This method returns the "Call" Op used to execute the function.
|
// This method returns the "Call" Op used to execute the function.
|
||||||
virtual Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
ImmediateOpPtr* out) = 0;
|
ImmediateOpPtr* out) const = 0;
|
||||||
|
|
||||||
virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
|
virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
|
||||||
};
|
};
|
||||||
|
@ -88,7 +88,7 @@ cc_library(
|
|||||||
"tf_concrete_function.h",
|
"tf_concrete_function.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":tensorhandle_convertible",
|
":flat_tensor_function",
|
||||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||||
"//tensorflow/c/eager:immediate_execution_context",
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
"//tensorflow/c/eager:immediate_execution_operation",
|
"//tensorflow/c/eager:immediate_execution_operation",
|
||||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
#include "tensorflow/core/framework/function.pb.h"
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
@ -33,32 +33,20 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
TFConcreteFunction::TFConcreteFunction(
|
TFConcreteFunction::TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
|
||||||
const std::string& name,
|
FunctionMetadata metadata)
|
||||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
: func_(std::move(func)), metadata_(std::move(metadata)) {}
|
||||||
FunctionMetadata metadata, ImmediateExecutionContext* ctx)
|
|
||||||
: name_(name),
|
|
||||||
captures_(std::move(captures)),
|
|
||||||
metadata_(std::move(metadata)),
|
|
||||||
ctx_(ctx) {}
|
|
||||||
|
|
||||||
TFConcreteFunction::~TFConcreteFunction() {
|
|
||||||
Status status = ctx_->RemoveFunction(name_);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
|
|
||||||
<< status.error_message();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status TFConcreteFunction::Create(
|
Status TFConcreteFunction::Create(
|
||||||
const FunctionDef* function_def,
|
const FunctionDef* function_def,
|
||||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||||
FunctionMetadata metadata, ImmediateExecutionContext* ctx,
|
FunctionMetadata metadata, ImmediateExecutionContext* ctx,
|
||||||
std::unique_ptr<TFConcreteFunction>* out) {
|
std::unique_ptr<TFConcreteFunction>* out) {
|
||||||
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
|
std::unique_ptr<FlatTensorFunction> func;
|
||||||
out->reset(new TFConcreteFunction(function_def->signature().name(),
|
TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
|
||||||
std::move(captures), std::move(metadata),
|
function_def, std::move(captures), ctx, &func));
|
||||||
ctx));
|
|
||||||
|
out->reset(new TFConcreteFunction(std::move(func), std::move(metadata)));
|
||||||
return Status();
|
return Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,30 +54,9 @@ const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const {
|
|||||||
return metadata_;
|
return metadata_;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TFConcreteFunction::GetCallOp(
|
Status TFConcreteFunction::MakeCallOp(
|
||||||
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) {
|
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
|
||||||
out->reset(ctx_->CreateOperation());
|
return func_->MakeCallOp(inputs, out);
|
||||||
// In eager mode, TF2 python executes functions by constructing an op with
|
|
||||||
// the name of the functiondef:
|
|
||||||
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
|
|
||||||
// In graph mode, we create a PartitionedCallOp instead:
|
|
||||||
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
|
|
||||||
|
|
||||||
// TODO(bmzhao): After discussing with Allen, we should execute this via a
|
|
||||||
// PartitionedCallOp for compatibility with "tooling that assumes functions in
|
|
||||||
// graphs are PartitionedCallOps".
|
|
||||||
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
|
|
||||||
|
|
||||||
// Adding the user-provided inputs to the function.
|
|
||||||
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
|
|
||||||
|
|
||||||
absl::Span<AbstractTensorHandle* const> captures(
|
|
||||||
reinterpret_cast<AbstractTensorHandle**>(captures_.data()),
|
|
||||||
captures_.size());
|
|
||||||
|
|
||||||
// Adding the captures of the function.
|
|
||||||
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
|
|
||||||
return Status();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||||
#include "tensorflow/core/framework/function.pb.h"
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
|
||||||
@ -58,26 +58,22 @@ class TFConcreteFunction : public ConcreteFunction {
|
|||||||
std::unique_ptr<TFConcreteFunction>* out);
|
std::unique_ptr<TFConcreteFunction>* out);
|
||||||
|
|
||||||
// This method returns the "Call" Op used to execute the function.
|
// This method returns the "Call" Op used to execute the function.
|
||||||
Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
ImmediateOpPtr* out) override;
|
ImmediateOpPtr* out) const override;
|
||||||
|
|
||||||
const FunctionMetadata& GetFunctionMetadata() const override;
|
const FunctionMetadata& GetFunctionMetadata() const override;
|
||||||
|
|
||||||
~TFConcreteFunction() override;
|
~TFConcreteFunction() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TFConcreteFunction(const std::string& name,
|
TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
|
||||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
FunctionMetadata metadata);
|
||||||
FunctionMetadata metadata, ImmediateExecutionContext* ctx);
|
|
||||||
|
|
||||||
TFConcreteFunction(const TFConcreteFunction&) = delete;
|
TFConcreteFunction(const TFConcreteFunction&) = delete;
|
||||||
TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;
|
TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;
|
||||||
|
|
||||||
// Name of the FunctionDef corresponding to this TFConcreteFunction
|
std::unique_ptr<FlatTensorFunction> func_;
|
||||||
std::string name_;
|
|
||||||
std::vector<ImmediateExecutionTensorHandle*> captures_;
|
|
||||||
FunctionMetadata metadata_;
|
FunctionMetadata metadata_;
|
||||||
ImmediateExecutionContext* ctx_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -34,7 +34,7 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
|
|||||||
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func,
|
TFE_Op* TF_ConcreteFunctionMakeCallOp(TF_ConcreteFunction* func,
|
||||||
TFE_TensorHandle** inputs, int num_inputs,
|
TFE_TensorHandle** inputs, int num_inputs,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::ImmediateOpPtr call_op;
|
tensorflow::ImmediateOpPtr call_op;
|
||||||
@ -42,7 +42,7 @@ TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func,
|
|||||||
reinterpret_cast<tensorflow::AbstractTensorHandle**>(
|
reinterpret_cast<tensorflow::AbstractTensorHandle**>(
|
||||||
tensorflow::unwrap(inputs)),
|
tensorflow::unwrap(inputs)),
|
||||||
static_cast<size_t>(num_inputs));
|
static_cast<size_t>(num_inputs));
|
||||||
status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op);
|
status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op);
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,7 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
|||||||
compute_fn_inputs.push_back(input_a);
|
compute_fn_inputs.push_back(input_a);
|
||||||
compute_fn_inputs.push_back(input_b);
|
compute_fn_inputs.push_back(input_b);
|
||||||
|
|
||||||
TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(
|
TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp(
|
||||||
compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
|
compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
|
||||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
|
|||||||
// high-level API here. A strawman for what this interface could look like:
|
// high-level API here. A strawman for what this interface could look like:
|
||||||
// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value*
|
// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value*
|
||||||
// inputs, int num_inputs, TF_Status* status);
|
// inputs, int num_inputs, TF_Status* status);
|
||||||
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
|
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp(
|
||||||
TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs,
|
TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user