From 8674a3ce834b5858c9bcf446a584d95983f64244 Mon Sep 17 00:00:00 2001 From: Brian Zhao Date: Wed, 26 Aug 2020 23:41:40 -0700 Subject: [PATCH] Refactor the guts of TFConcreteFunction into FlatTensorFunction, which can be shared between TFConcreteFunction and SignatureDefFunction. FlatTensorFunction handles wrapping a FunctionDef, it's name, any captures, and Call op creation. TFConcreteFunction and SignatureDefFunction can present different high level apis for metadata and function execution based on this lower level primitive. PiperOrigin-RevId: 328683652 Change-Id: I91a49cf90b1c78fe2156a3c9fecd0250c03d5f30 --- .../saved_model/core/revived_types/BUILD | 20 +++++ .../revived_types/flat_tensor_function.cc | 85 +++++++++++++++++++ .../core/revived_types/flat_tensor_function.h | 84 ++++++++++++++++++ 3 files changed, 189 insertions(+) create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 21a8dae3d93..3e826571e82 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -28,6 +28,26 @@ cc_library( ], ) +cc_library( + name = "flat_tensor_function", + srcs = [ + "flat_tensor_function.cc", + ], + hdrs = [ + "flat_tensor_function.h", + ], + deps = [ + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "variable", srcs = [ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc new file mode 100644 index 00000000000..ad9f896f43d --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +FlatTensorFunction::FlatTensorFunction( + const std::string& name, + std::vector captures, + ImmediateExecutionContext* ctx) + : name_(name), captures_(std::move(captures)), ctx_(ctx) {} + +FlatTensorFunction::~FlatTensorFunction() { + Status status = ctx_->RemoveFunction(name_); + if (!status.ok()) { + LOG(ERROR) << "Failed to remove functiondef " << name_ << ". " + << status.error_message(); + } +} + +Status FlatTensorFunction::Create( + const FunctionDef* function_def, + std::vector captures, + ImmediateExecutionContext* ctx, std::unique_ptr* out) { + TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); + out->reset(new FlatTensorFunction(function_def->signature().name(), + std::move(captures), ctx)); + return Status(); +} + +Status FlatTensorFunction::MakeCallOp( + absl::Span inputs, ImmediateOpPtr* out) const { + out->reset(ctx_->CreateOperation()); + // In eager mode, TF2 python executes functions by constructing an op with + // the name of the functiondef: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 + // In graph mode, we create a PartitionedCallOp instead: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 + + // TODO(bmzhao): After discussing with Allen, we should execute this via a + // PartitionedCallOp for compatibility with "tooling that assumes functions in + // graphs are PartitionedCallOps". + TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); + + // Adding the user-provided inputs to the function. + TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); + + absl::Span captures( + reinterpret_cast(captures_.data()), + captures_.size()); + + // Adding the captures of the function. + TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h new file mode 100644 index 00000000000..e6bcdec7e3a --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// FlatTensorFunction models a TF2 eager runtime view of a callable function, +// taking + returning flat lists of tensors, including any captures. +// Effectively, it is a thin wrapper around a FunctionDef owned by the +// EagerContext, and any TensorHandle captures associated with the function. The +// MakeCallOp method handles the logic of marshaling captures after the user +// provided inputs automatically. +// Note(bmzhao): This class is mainly intended to house low-level reusable +// function logic between SignatureDefFunction and ConcreteFunction, which +// present higher level interfaces. This type does *not* hold any "function +// metadata". +class FlatTensorFunction { + public: + // Factory for creating a FlatTensorFunction. + // + // Params: + // function_def - The function_def associated with the created + // FlatTensorFunction. FlatTensorFunction will register this + // function_def with `ctx` on creation, and de-register it on + // destruction. function_def must be non-null, but + // otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // FlatTensorFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFConcreteFunction. + // out - The output FlatTensorFunction. + static Status Create(const FunctionDef* function_def, + std::vector captures, + ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method creates a "Call" Op used to execute the function. + Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const; + + ~FlatTensorFunction(); + + private: + FlatTensorFunction(const std::string& name, + std::vector captures, + ImmediateExecutionContext* ctx); + + FlatTensorFunction(const FlatTensorFunction&) = delete; + FlatTensorFunction& operator=(const FlatTensorFunction&) = delete; + + // Name of the FunctionDef corresponding to this TFConcreteFunction + std::string name_; + std::vector captures_; + ImmediateExecutionContext* ctx_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_