A very basic start on some op handler infrastructure
Does not include a handler's tensor representation (and so no copy-on etc.), and almost all of the hooks are missing. My medium-term goal is to get the parallel device working with function replay so TPU collectives work inside functions. That will also get us a replication primitive for use with the eager/graph agnostic C API, and I'll plan to call it from the existing custom device to start. PiperOrigin-RevId: 340253840 Change-Id: Ic9a5acca7bf42ceb9cb54aca635a9861daca3b38
This commit is contained in:
parent
aeafc6a567
commit
e8c62ab31d
@ -32,7 +32,7 @@ namespace tensorflow {
|
|||||||
// environment, a traced representation etc.
|
// environment, a traced representation etc.
|
||||||
class AbstractContext {
|
class AbstractContext {
|
||||||
protected:
|
protected:
|
||||||
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape, kOpHandler };
|
||||||
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
|
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
|
||||||
virtual ~AbstractContext() {}
|
virtual ~AbstractContext() {}
|
||||||
|
|
||||||
|
@ -30,7 +30,14 @@ namespace tensorflow {
|
|||||||
// tracing or immediate execution mode.
|
// tracing or immediate execution mode.
|
||||||
class AbstractOperation {
|
class AbstractOperation {
|
||||||
protected:
|
protected:
|
||||||
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
enum AbstractOperationKind {
|
||||||
|
kGraph,
|
||||||
|
kMlir,
|
||||||
|
kEager,
|
||||||
|
kTfrt,
|
||||||
|
kTape,
|
||||||
|
kOpHandler
|
||||||
|
};
|
||||||
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
|
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
|
||||||
virtual ~AbstractOperation() {}
|
virtual ~AbstractOperation() {}
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
|
|||||||
parent_op_(parent_op),
|
parent_op_(parent_op),
|
||||||
tape_(tape),
|
tape_(tape),
|
||||||
registry_(registry) {
|
registry_(registry) {
|
||||||
// TODO(srbs): Make AbstractOperation RefCounted.
|
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||||
// parent_op_->Ref();
|
// parent_op_->Ref();
|
||||||
}
|
}
|
||||||
void TapeOperation::Release() {
|
void TapeOperation::Release() {
|
||||||
@ -33,7 +33,7 @@ void TapeOperation::Release() {
|
|||||||
delete this;
|
delete this;
|
||||||
}
|
}
|
||||||
TapeOperation::~TapeOperation() {
|
TapeOperation::~TapeOperation() {
|
||||||
// TODO(srbs): Make AbstractOperation RefCounted.
|
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||||
// parent_op->Unref();
|
// parent_op->Unref();
|
||||||
}
|
}
|
||||||
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
|
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
|
||||||
|
43
tensorflow/c/experimental/op_handler/BUILD
Normal file
43
tensorflow/c/experimental/op_handler/BUILD
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||||
|
|
||||||
|
package(
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "internal_test",
|
||||||
|
srcs = ["internal_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":internal",
|
||||||
|
"//tensorflow/c/eager:c_api_experimental",
|
||||||
|
"//tensorflow/c/eager:c_api_unified_internal",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:errors",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "internal",
|
||||||
|
srcs = ["internal.cc"],
|
||||||
|
hdrs = ["internal.h"],
|
||||||
|
deps = [
|
||||||
|
":wrapper_operation",
|
||||||
|
"//tensorflow/c:conversion_macros",
|
||||||
|
"//tensorflow/c/eager:abstract_context",
|
||||||
|
"//tensorflow/c/eager:abstract_operation",
|
||||||
|
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||||
|
"//tensorflow/c/eager:c_api_experimental",
|
||||||
|
"//tensorflow/core/platform:refcount",
|
||||||
|
"//tensorflow/core/platform:types",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "wrapper_operation",
|
||||||
|
srcs = ["wrapper_operation.cc"],
|
||||||
|
hdrs = ["wrapper_operation.h"],
|
||||||
|
deps = ["//tensorflow/c/eager:abstract_operation"],
|
||||||
|
)
|
79
tensorflow/c/experimental/op_handler/internal.cc
Normal file
79
tensorflow/c/experimental/op_handler/internal.cc
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/op_handler/internal.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_context.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_operation.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
OpHandlerContext::OpHandlerContext(AbstractContext* parent_ctx)
|
||||||
|
: AbstractContext(kOpHandler), parent_ctx_(parent_ctx) {}
|
||||||
|
OpHandlerContext::~OpHandlerContext() {}
|
||||||
|
void OpHandlerContext::Release() { delete this; }
|
||||||
|
Status OpHandlerContext::RegisterFunction(AbstractFunction* function) {
|
||||||
|
return parent_ctx_->RegisterFunction(function);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OpHandlerContext::RemoveFunction(const string& function) {
|
||||||
|
return parent_ctx_->RemoveFunction(function);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpHandlerContext::set_default_handler(OpHandler* handler) {
|
||||||
|
handler->Ref();
|
||||||
|
default_handler_.reset(handler);
|
||||||
|
}
|
||||||
|
|
||||||
|
OpHandlerOperation* OpHandlerContext::CreateOperation() {
|
||||||
|
OpHandlerOperation* result =
|
||||||
|
new OpHandlerOperation(parent_ctx_->CreateOperation());
|
||||||
|
if (default_handler_ != nullptr) {
|
||||||
|
result->set_handler(default_handler_.get());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
OpHandlerOperation::OpHandlerOperation(AbstractOperation* parent_op)
|
||||||
|
: WrapperOperation(parent_op, kOpHandler) {}
|
||||||
|
|
||||||
|
OpHandler* OpHandlerOperation::get_handler() { return handler_.get(); }
|
||||||
|
|
||||||
|
void OpHandlerOperation::set_handler(OpHandler* handler) {
|
||||||
|
if (handler != nullptr) {
|
||||||
|
handler->Ref();
|
||||||
|
}
|
||||||
|
handler_.reset(handler);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OpHandlerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||||
|
int* num_retvals) {
|
||||||
|
if (handler_ == nullptr) {
|
||||||
|
return WrapperOperation::Execute(retvals, num_retvals);
|
||||||
|
} else {
|
||||||
|
return handler_->Execute(this, retvals, num_retvals);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
117
tensorflow/c/experimental/op_handler/internal.h
Normal file
117
tensorflow/c/experimental/op_handler/internal.h
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_context.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_operation.h"
|
||||||
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
|
||||||
|
#include "tensorflow/core/platform/refcount.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class OpHandlerOperation;
|
||||||
|
|
||||||
|
// Op handlers are a convenient way to intercept and transform computation.
|
||||||
|
//
|
||||||
|
// The implementation is currently experimental and incomplete, but aims
|
||||||
|
// eventually to support tracing and replay of function bodies, gradients
|
||||||
|
// through copy operations, and a variety of hooks for things like debug
|
||||||
|
// strings. A public C API for op handlers is planned.
|
||||||
|
class OpHandler : public core::RefCounted {
|
||||||
|
public:
|
||||||
|
// Called on operation->Execute when operation->get_handler() == this.
|
||||||
|
//
|
||||||
|
// Allows the handler to customize or inspect `operation`'s execution.
|
||||||
|
virtual Status Execute(OpHandlerOperation* operation,
|
||||||
|
absl::Span<AbstractTensorHandle*> retvals,
|
||||||
|
int* num_retvals) = 0;
|
||||||
|
// Creates a new handler by merging this handler with `next_handler`.
|
||||||
|
//
|
||||||
|
// The new handler is expected to transform operations first with this handler
|
||||||
|
// and then execute the resulting operations on `next_handler` (by calling
|
||||||
|
// `OpHandlerOperation::set_handler` and passing `next_handler`). If this is
|
||||||
|
// not possible then the merge operation should fail.
|
||||||
|
virtual Status Merge(OpHandler* next_handler,
|
||||||
|
core::RefCountPtr<OpHandler>& merged_handler) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Keeps some handler-specific metadata, but otherwise wraps a single
|
||||||
|
// AbstractOperation in the underlying context. The operation is created, its
|
||||||
|
// attributes set, etc., and at execution time it is presented to its handler,
|
||||||
|
// which may choose to execute it or simply inspect it and do something else.
|
||||||
|
//
|
||||||
|
// This is somewhat different than the Context approach, where the operation's
|
||||||
|
// construction is streamed through each layered Context. The streaming approach
|
||||||
|
// would require a much larger op handler public API, one function pointer per
|
||||||
|
// attribute type, and there is some ambiguity before an op is finalized about
|
||||||
|
// whether it should be presented as-is to handlers (regular operations) or
|
||||||
|
// replayed (function calls and control flow operations).
|
||||||
|
class OpHandlerOperation : public WrapperOperation {
|
||||||
|
public:
|
||||||
|
explicit OpHandlerOperation(AbstractOperation*);
|
||||||
|
OpHandler* get_handler();
|
||||||
|
void set_handler(OpHandler* handler);
|
||||||
|
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||||
|
int* num_retvals) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
core::RefCountPtr<OpHandler> handler_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A context which allows a default handler to be set for new operations. It
|
||||||
|
// otherwise defers to the context it wraps.
|
||||||
|
//
|
||||||
|
// TODO(allenl): A stack of contexts and a stack of handlers look pretty similar
|
||||||
|
// in some ways. Having each handler be its own context seems almost doable,
|
||||||
|
// with things like copy operations and function/control flow replay being
|
||||||
|
// somewhat tricky (since they should be generated at the top of the handler
|
||||||
|
// stack and "caught" at the bottom). After handlers have evolved for a bit we
|
||||||
|
// should re-evaluate whether the handler+context concepts can be merged.
|
||||||
|
class OpHandlerContext : public AbstractContext {
|
||||||
|
public:
|
||||||
|
explicit OpHandlerContext(AbstractContext*);
|
||||||
|
void Release() override;
|
||||||
|
OpHandlerOperation* CreateOperation() override;
|
||||||
|
Status RegisterFunction(AbstractFunction*) override;
|
||||||
|
Status RemoveFunction(const string&) override;
|
||||||
|
// For LLVM style RTTI.
|
||||||
|
static bool classof(const AbstractContext* ptr) {
|
||||||
|
return ptr->getKind() == kOpHandler;
|
||||||
|
}
|
||||||
|
~OpHandlerContext() override;
|
||||||
|
|
||||||
|
void set_default_handler(OpHandler* handler);
|
||||||
|
|
||||||
|
private:
|
||||||
|
AbstractContext* parent_ctx_; // Not owned.
|
||||||
|
core::RefCountPtr<OpHandler> default_handler_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReleaseOpHandlerOperation {
|
||||||
|
public:
|
||||||
|
void operator()(OpHandlerOperation* operation) { operation->Release(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::unique_ptr<OpHandlerOperation, ReleaseOpHandlerOperation>
|
||||||
|
OpHandlerOperationPtr;
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
102
tensorflow/c/experimental/op_handler/internal_test.cc
Normal file
102
tensorflow/c/experimental/op_handler/internal_test.cc
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/op_handler/internal.h"
|
||||||
|
|
||||||
|
#include "absl/types/span.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class TestOpHandler : public OpHandler {
|
||||||
|
public:
|
||||||
|
TestOpHandler() : last_operation_(new std::string("")) {}
|
||||||
|
Status Execute(OpHandlerOperation* operation,
|
||||||
|
absl::Span<AbstractTensorHandle*> retvals,
|
||||||
|
int* num_retvals) override {
|
||||||
|
CHECK(operation->get_handler() == this);
|
||||||
|
*last_operation_ = operation->Name();
|
||||||
|
operation->set_handler(next_handler_.get());
|
||||||
|
return operation->Execute(retvals, num_retvals);
|
||||||
|
}
|
||||||
|
Status Merge(OpHandler* next_handler,
|
||||||
|
core::RefCountPtr<OpHandler>& merged_handler) override {
|
||||||
|
merged_handler.reset(new TestOpHandler(next_handler, last_operation_));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
core::RefCountPtr<OpHandler> next_handler_ = nullptr;
|
||||||
|
// Shared between merged handlers of this type.
|
||||||
|
std::shared_ptr<std::string> last_operation_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
TestOpHandler(OpHandler* next_handler,
|
||||||
|
std::shared_ptr<std::string> last_operation)
|
||||||
|
: next_handler_(next_handler), last_operation_(last_operation) {
|
||||||
|
next_handler->Ref();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(INTERNAL_TEST, UseOpHandler) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||||
|
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||||
|
std::unique_ptr<TF_ExecutionContext, decltype(&TF_DeleteExecutionContext)>
|
||||||
|
c_ctx(TF_NewEagerExecutionContext(opts.get(), status.get()),
|
||||||
|
TF_DeleteExecutionContext);
|
||||||
|
OpHandlerContext ctx(unwrap(c_ctx.get()));
|
||||||
|
core::RefCountPtr<TestOpHandler> outer_handler(new TestOpHandler());
|
||||||
|
core::RefCountPtr<TestOpHandler> inner_handler(new TestOpHandler());
|
||||||
|
ctx.set_default_handler(outer_handler.get());
|
||||||
|
OpHandlerOperationPtr op(ctx.CreateOperation());
|
||||||
|
Status s = op->Reset("NoOp", "");
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
std::vector<AbstractTensorHandle*> retvals;
|
||||||
|
int num_retvals = 0;
|
||||||
|
EXPECT_EQ("", *outer_handler->last_operation_);
|
||||||
|
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
|
||||||
|
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
|
||||||
|
*outer_handler->last_operation_ = "";
|
||||||
|
EXPECT_EQ("", *inner_handler->last_operation_);
|
||||||
|
|
||||||
|
// This op executes on both handlers, changing the state of `inner_handler`
|
||||||
|
// since the handler has decided to preserve that state across merges.
|
||||||
|
core::RefCountPtr<OpHandler> merged;
|
||||||
|
s = inner_handler->Merge(outer_handler.get(), merged);
|
||||||
|
ctx.set_default_handler(merged.get());
|
||||||
|
op.reset(ctx.CreateOperation());
|
||||||
|
s = op->Reset("NoOp", "");
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
EXPECT_EQ("NoOp", *inner_handler->last_operation_);
|
||||||
|
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
|
||||||
|
|
||||||
|
inner_handler.reset();
|
||||||
|
outer_handler.reset();
|
||||||
|
op.reset(ctx.CreateOperation());
|
||||||
|
s = op->Reset("NoOp", "");
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
120
tensorflow/c/experimental/op_handler/wrapper_operation.cc
Normal file
120
tensorflow/c/experimental/op_handler/wrapper_operation.cc
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
WrapperOperation::WrapperOperation(AbstractOperation* parent_op,
|
||||||
|
AbstractOperationKind kind)
|
||||||
|
: AbstractOperation(kind), parent_op_(parent_op) {
|
||||||
|
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||||
|
// parent_op_->Ref();
|
||||||
|
}
|
||||||
|
void WrapperOperation::Release() {
|
||||||
|
parent_op_->Release();
|
||||||
|
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||||
|
delete this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status WrapperOperation::Reset(const char* op, const char* raw_device_name) {
|
||||||
|
return parent_op_->Reset(op, raw_device_name);
|
||||||
|
}
|
||||||
|
const string& WrapperOperation::Name() const { return parent_op_->Name(); }
|
||||||
|
const string& WrapperOperation::DeviceName() const {
|
||||||
|
return parent_op_->DeviceName();
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetDeviceName(const char* name) {
|
||||||
|
return parent_op_->SetDeviceName(name);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::AddInput(AbstractTensorHandle* input) {
|
||||||
|
return parent_op_->AddInput(input);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::AddInputList(
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs) {
|
||||||
|
return parent_op_->AddInputList(inputs);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrString(const char* attr_name, const char* data,
|
||||||
|
size_t length) {
|
||||||
|
return parent_op_->SetAttrString(attr_name, data, length);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrInt(const char* attr_name, int64_t value) {
|
||||||
|
return parent_op_->SetAttrInt(attr_name, value);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrFloat(const char* attr_name, float value) {
|
||||||
|
return parent_op_->SetAttrFloat(attr_name, value);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrBool(const char* attr_name, bool value) {
|
||||||
|
return parent_op_->SetAttrBool(attr_name, value);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrType(const char* attr_name, DataType value) {
|
||||||
|
return parent_op_->SetAttrType(attr_name, value);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrShape(const char* attr_name,
|
||||||
|
const int64_t* dims, const int num_dims) {
|
||||||
|
return parent_op_->SetAttrShape(attr_name, dims, num_dims);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrFunction(const char* attr_name,
|
||||||
|
const AbstractOperation* value) {
|
||||||
|
return parent_op_->SetAttrFunction(attr_name, value);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrFunctionName(const char* attr_name,
|
||||||
|
const char* value, size_t length) {
|
||||||
|
return parent_op_->SetAttrFunctionName(attr_name, value, length);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrTensor(const char* attr_name,
|
||||||
|
AbstractTensorInterface* tensor) {
|
||||||
|
return parent_op_->SetAttrTensor(attr_name, tensor);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrStringList(const char* attr_name,
|
||||||
|
const void* const* values,
|
||||||
|
const size_t* lengths,
|
||||||
|
int num_values) {
|
||||||
|
return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrFloatList(const char* attr_name,
|
||||||
|
const float* values, int num_values) {
|
||||||
|
return parent_op_->SetAttrFloatList(attr_name, values, num_values);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrIntList(const char* attr_name,
|
||||||
|
const int64_t* values, int num_values) {
|
||||||
|
return parent_op_->SetAttrIntList(attr_name, values, num_values);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrTypeList(const char* attr_name,
|
||||||
|
const DataType* values,
|
||||||
|
int num_values) {
|
||||||
|
return parent_op_->SetAttrTypeList(attr_name, values, num_values);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrBoolList(const char* attr_name,
|
||||||
|
const unsigned char* values,
|
||||||
|
int num_values) {
|
||||||
|
return parent_op_->SetAttrBoolList(attr_name, values, num_values);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrShapeList(const char* attr_name,
|
||||||
|
const int64_t** dims,
|
||||||
|
const int* num_dims, int num_values) {
|
||||||
|
return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||||
|
}
|
||||||
|
Status WrapperOperation::SetAttrFunctionList(
|
||||||
|
const char* attr_name, absl::Span<const AbstractOperation*> values) {
|
||||||
|
return parent_op_->SetAttrFunctionList(attr_name, values);
|
||||||
|
}
|
||||||
|
AbstractOperation* WrapperOperation::GetBackingOperation() {
|
||||||
|
return parent_op_;
|
||||||
|
}
|
||||||
|
Status WrapperOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||||
|
int* num_retvals) {
|
||||||
|
return parent_op_->Execute(retvals, num_retvals);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
74
tensorflow/c/experimental/op_handler/wrapper_operation.h
Normal file
74
tensorflow/c/experimental/op_handler/wrapper_operation.h
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/abstract_operation.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Forwards all of the AbstractOperation's methods to its wrapped operation.
|
||||||
|
//
|
||||||
|
// Useful as a base class to default to forwarding while adding some
|
||||||
|
// customization.
|
||||||
|
class WrapperOperation : public AbstractOperation {
|
||||||
|
public:
|
||||||
|
explicit WrapperOperation(AbstractOperation*, AbstractOperationKind kind);
|
||||||
|
void Release() override;
|
||||||
|
Status Reset(const char* op, const char* raw_device_name) override;
|
||||||
|
const string& Name() const override;
|
||||||
|
const string& DeviceName() const override;
|
||||||
|
Status SetDeviceName(const char* name) override;
|
||||||
|
Status AddInput(AbstractTensorHandle* input) override;
|
||||||
|
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
|
||||||
|
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||||
|
int* num_retvals) override;
|
||||||
|
Status SetAttrString(const char* attr_name, const char* data,
|
||||||
|
size_t length) override;
|
||||||
|
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||||
|
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||||
|
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||||
|
Status SetAttrType(const char* attr_name, DataType value) override;
|
||||||
|
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||||
|
const int num_dims) override;
|
||||||
|
Status SetAttrFunction(const char* attr_name,
|
||||||
|
const AbstractOperation* value) override;
|
||||||
|
Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||||
|
size_t length) override;
|
||||||
|
Status SetAttrTensor(const char* attr_name,
|
||||||
|
AbstractTensorInterface* tensor) override;
|
||||||
|
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||||
|
const size_t* lengths, int num_values) override;
|
||||||
|
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrTypeList(const char* attr_name, const DataType* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||||
|
int num_values) override;
|
||||||
|
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||||
|
const int* num_dims, int num_values) override;
|
||||||
|
Status SetAttrFunctionList(
|
||||||
|
const char* attr_name,
|
||||||
|
absl::Span<const AbstractOperation*> values) override;
|
||||||
|
AbstractOperation* GetBackingOperation();
|
||||||
|
|
||||||
|
private:
|
||||||
|
AbstractOperation* parent_op_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
|
Loading…
Reference in New Issue
Block a user