Refactor the guts of TFConcreteFunction into FlatTensorFunction, which can be shared between TFConcreteFunction and SignatureDefFunction. FlatTensorFunction handles wrapping a FunctionDef, it's name, any captures, and Call op creation. TFConcreteFunction and SignatureDefFunction can present different high level apis for metadata and function execution based on this lower level primitive.

PiperOrigin-RevId: 328683652
Change-Id: I91a49cf90b1c78fe2156a3c9fecd0250c03d5f30
This commit is contained in:
Brian Zhao 2020-08-26 23:41:40 -07:00 committed by TensorFlower Gardener
parent ecc6fe9066
commit 8674a3ce83
3 changed files with 189 additions and 0 deletions

View File

@ -28,6 +28,26 @@ cc_library(
],
)
cc_library(
name = "flat_tensor_function",
srcs = [
"flat_tensor_function.cc",
],
hdrs = [
"flat_tensor_function.h",
],
deps = [
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "variable",
srcs = [

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include <memory>
#include <string>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
FlatTensorFunction::FlatTensorFunction(
const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx)
: name_(name), captures_(std::move(captures)), ctx_(ctx) {}
FlatTensorFunction::~FlatTensorFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
Status FlatTensorFunction::Create(
const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
out->reset(new FlatTensorFunction(function_def->signature().name(),
std::move(captures), ctx));
return Status();
}
Status FlatTensorFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
out->reset(ctx_->CreateOperation());
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
// FlatTensorFunction models a TF2 eager runtime view of a callable function,
// taking + returning flat lists of tensors, including any captures.
// Effectively, it is a thin wrapper around a FunctionDef owned by the
// EagerContext, and any TensorHandle captures associated with the function. The
// MakeCallOp method handles the logic of marshaling captures after the user
// provided inputs automatically.
// Note(bmzhao): This class is mainly intended to house low-level reusable
// function logic between SignatureDefFunction and ConcreteFunction, which
// present higher level interfaces. This type does *not* hold any "function
// metadata".
class FlatTensorFunction {
public:
// Factory for creating a FlatTensorFunction.
//
// Params:
// function_def - The function_def associated with the created
// FlatTensorFunction. FlatTensorFunction will register this
// function_def with `ctx` on creation, and de-register it on
// destruction. function_def must be non-null, but
// otherwise has no lifetime requirements.
// captures - The captured TensorHandles associated with this
// FlatTensorFunction.
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
// outlive TFConcreteFunction.
// out - The output FlatTensorFunction.
static Status Create(const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx,
std::unique_ptr<FlatTensorFunction>* out);
// This method creates a "Call" Op used to execute the function.
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const;
~FlatTensorFunction();
private:
FlatTensorFunction(const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx);
FlatTensorFunction(const FlatTensorFunction&) = delete;
FlatTensorFunction& operator=(const FlatTensorFunction&) = delete;
// Name of the FunctionDef corresponding to this TFConcreteFunction
std::string name_;
std::vector<ImmediateExecutionTensorHandle*> captures_;
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_