Refactor the guts of TFConcreteFunction into FlatTensorFunction, which can be shared between TFConcreteFunction and SignatureDefFunction. FlatTensorFunction handles wrapping a FunctionDef, it's name, any captures, and Call op creation. TFConcreteFunction and SignatureDefFunction can present different high level apis for metadata and function execution based on this lower level primitive.
PiperOrigin-RevId: 328683652 Change-Id: I91a49cf90b1c78fe2156a3c9fecd0250c03d5f30
This commit is contained in:
parent
ecc6fe9066
commit
8674a3ce83
@ -28,6 +28,26 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "flat_tensor_function",
|
||||
srcs = [
|
||||
"flat_tensor_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"flat_tensor_function.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable",
|
||||
srcs = [
|
||||
|
@ -0,0 +1,85 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
FlatTensorFunction::FlatTensorFunction(
|
||||
const std::string& name,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
ImmediateExecutionContext* ctx)
|
||||
: name_(name), captures_(std::move(captures)), ctx_(ctx) {}
|
||||
|
||||
FlatTensorFunction::~FlatTensorFunction() {
|
||||
Status status = ctx_->RemoveFunction(name_);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
|
||||
<< status.error_message();
|
||||
}
|
||||
}
|
||||
|
||||
Status FlatTensorFunction::Create(
|
||||
const FunctionDef* function_def,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out) {
|
||||
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
|
||||
out->reset(new FlatTensorFunction(function_def->signature().name(),
|
||||
std::move(captures), ctx));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status FlatTensorFunction::MakeCallOp(
|
||||
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
|
||||
out->reset(ctx_->CreateOperation());
|
||||
// In eager mode, TF2 python executes functions by constructing an op with
|
||||
// the name of the functiondef:
|
||||
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
|
||||
// In graph mode, we create a PartitionedCallOp instead:
|
||||
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
|
||||
|
||||
// TODO(bmzhao): After discussing with Allen, we should execute this via a
|
||||
// PartitionedCallOp for compatibility with "tooling that assumes functions in
|
||||
// graphs are PartitionedCallOps".
|
||||
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
|
||||
|
||||
// Adding the user-provided inputs to the function.
|
||||
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
|
||||
|
||||
absl::Span<AbstractTensorHandle* const> captures(
|
||||
reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
|
||||
captures_.size());
|
||||
|
||||
// Adding the captures of the function.
|
||||
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,84 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// FlatTensorFunction models a TF2 eager runtime view of a callable function,
|
||||
// taking + returning flat lists of tensors, including any captures.
|
||||
// Effectively, it is a thin wrapper around a FunctionDef owned by the
|
||||
// EagerContext, and any TensorHandle captures associated with the function. The
|
||||
// MakeCallOp method handles the logic of marshaling captures after the user
|
||||
// provided inputs automatically.
|
||||
// Note(bmzhao): This class is mainly intended to house low-level reusable
|
||||
// function logic between SignatureDefFunction and ConcreteFunction, which
|
||||
// present higher level interfaces. This type does *not* hold any "function
|
||||
// metadata".
|
||||
class FlatTensorFunction {
|
||||
public:
|
||||
// Factory for creating a FlatTensorFunction.
|
||||
//
|
||||
// Params:
|
||||
// function_def - The function_def associated with the created
|
||||
// FlatTensorFunction. FlatTensorFunction will register this
|
||||
// function_def with `ctx` on creation, and de-register it on
|
||||
// destruction. function_def must be non-null, but
|
||||
// otherwise has no lifetime requirements.
|
||||
// captures - The captured TensorHandles associated with this
|
||||
// FlatTensorFunction.
|
||||
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
|
||||
// outlive TFConcreteFunction.
|
||||
// out - The output FlatTensorFunction.
|
||||
static Status Create(const FunctionDef* function_def,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
ImmediateExecutionContext* ctx,
|
||||
std::unique_ptr<FlatTensorFunction>* out);
|
||||
|
||||
// This method creates a "Call" Op used to execute the function.
|
||||
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
|
||||
ImmediateOpPtr* out) const;
|
||||
|
||||
~FlatTensorFunction();
|
||||
|
||||
private:
|
||||
FlatTensorFunction(const std::string& name,
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures,
|
||||
ImmediateExecutionContext* ctx);
|
||||
|
||||
FlatTensorFunction(const FlatTensorFunction&) = delete;
|
||||
FlatTensorFunction& operator=(const FlatTensorFunction&) = delete;
|
||||
|
||||
// Name of the FunctionDef corresponding to this TFConcreteFunction
|
||||
std::string name_;
|
||||
std::vector<ImmediateExecutionTensorHandle*> captures_;
|
||||
ImmediateExecutionContext* ctx_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
|
Loading…
Reference in New Issue
Block a user