A very basic start on some op handler infrastructure
Does not include a handler's tensor representation (and so no copy-on etc.), and almost all of the hooks are missing. My medium-term goal is to get the parallel device working with function replay so TPU collectives work inside functions. That will also get us a replication primitive for use with the eager/graph agnostic C API, and I'll plan to call it from the existing custom device to start. PiperOrigin-RevId: 340253840 Change-Id: Ic9a5acca7bf42ceb9cb54aca635a9861daca3b38
This commit is contained in:
parent
aeafc6a567
commit
e8c62ab31d
tensorflow/c
eager
experimental
@ -32,7 +32,7 @@ namespace tensorflow {
|
||||
// environment, a traced representation etc.
|
||||
class AbstractContext {
|
||||
protected:
|
||||
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
||||
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape, kOpHandler };
|
||||
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractContext() {}
|
||||
|
||||
|
@ -30,7 +30,14 @@ namespace tensorflow {
|
||||
// tracing or immediate execution mode.
|
||||
class AbstractOperation {
|
||||
protected:
|
||||
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
||||
enum AbstractOperationKind {
|
||||
kGraph,
|
||||
kMlir,
|
||||
kEager,
|
||||
kTfrt,
|
||||
kTape,
|
||||
kOpHandler
|
||||
};
|
||||
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractOperation() {}
|
||||
|
||||
|
@ -25,7 +25,7 @@ TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
|
||||
parent_op_(parent_op),
|
||||
tape_(tape),
|
||||
registry_(registry) {
|
||||
// TODO(srbs): Make AbstractOperation RefCounted.
|
||||
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||
// parent_op_->Ref();
|
||||
}
|
||||
void TapeOperation::Release() {
|
||||
@ -33,7 +33,7 @@ void TapeOperation::Release() {
|
||||
delete this;
|
||||
}
|
||||
TapeOperation::~TapeOperation() {
|
||||
// TODO(srbs): Make AbstractOperation RefCounted.
|
||||
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||
// parent_op->Unref();
|
||||
}
|
||||
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
|
||||
|
43
tensorflow/c/experimental/op_handler/BUILD
Normal file
43
tensorflow/c/experimental/op_handler/BUILD
Normal file
@ -0,0 +1,43 @@
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "internal_test",
|
||||
srcs = ["internal_test.cc"],
|
||||
deps = [
|
||||
":internal",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "internal",
|
||||
srcs = ["internal.cc"],
|
||||
hdrs = ["internal.h"],
|
||||
deps = [
|
||||
":wrapper_operation",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core/platform:refcount",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "wrapper_operation",
|
||||
srcs = ["wrapper_operation.cc"],
|
||||
hdrs = ["wrapper_operation.h"],
|
||||
deps = ["//tensorflow/c/eager:abstract_operation"],
|
||||
)
|
79
tensorflow/c/experimental/op_handler/internal.cc
Normal file
79
tensorflow/c/experimental/op_handler/internal.cc
Normal file
@ -0,0 +1,79 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
|
||||
|
||||
#include "tensorflow/c/experimental/op_handler/internal.h"
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
OpHandlerContext::OpHandlerContext(AbstractContext* parent_ctx)
|
||||
: AbstractContext(kOpHandler), parent_ctx_(parent_ctx) {}
|
||||
OpHandlerContext::~OpHandlerContext() {}
|
||||
void OpHandlerContext::Release() { delete this; }
|
||||
Status OpHandlerContext::RegisterFunction(AbstractFunction* function) {
|
||||
return parent_ctx_->RegisterFunction(function);
|
||||
}
|
||||
|
||||
Status OpHandlerContext::RemoveFunction(const string& function) {
|
||||
return parent_ctx_->RemoveFunction(function);
|
||||
}
|
||||
|
||||
void OpHandlerContext::set_default_handler(OpHandler* handler) {
|
||||
handler->Ref();
|
||||
default_handler_.reset(handler);
|
||||
}
|
||||
|
||||
OpHandlerOperation* OpHandlerContext::CreateOperation() {
|
||||
OpHandlerOperation* result =
|
||||
new OpHandlerOperation(parent_ctx_->CreateOperation());
|
||||
if (default_handler_ != nullptr) {
|
||||
result->set_handler(default_handler_.get());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
OpHandlerOperation::OpHandlerOperation(AbstractOperation* parent_op)
|
||||
: WrapperOperation(parent_op, kOpHandler) {}
|
||||
|
||||
OpHandler* OpHandlerOperation::get_handler() { return handler_.get(); }
|
||||
|
||||
void OpHandlerOperation::set_handler(OpHandler* handler) {
|
||||
if (handler != nullptr) {
|
||||
handler->Ref();
|
||||
}
|
||||
handler_.reset(handler);
|
||||
}
|
||||
|
||||
Status OpHandlerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) {
|
||||
if (handler_ == nullptr) {
|
||||
return WrapperOperation::Execute(retvals, num_retvals);
|
||||
} else {
|
||||
return handler_->Execute(this, retvals, num_retvals);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
117
tensorflow/c/experimental/op_handler/internal.h
Normal file
117
tensorflow/c/experimental/op_handler/internal.h
Normal file
@ -0,0 +1,117 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpHandlerOperation;
|
||||
|
||||
// Op handlers are a convenient way to intercept and transform computation.
|
||||
//
|
||||
// The implementation is currently experimental and incomplete, but aims
|
||||
// eventually to support tracing and replay of function bodies, gradients
|
||||
// through copy operations, and a variety of hooks for things like debug
|
||||
// strings. A public C API for op handlers is planned.
|
||||
class OpHandler : public core::RefCounted {
|
||||
public:
|
||||
// Called on operation->Execute when operation->get_handler() == this.
|
||||
//
|
||||
// Allows the handler to customize or inspect `operation`'s execution.
|
||||
virtual Status Execute(OpHandlerOperation* operation,
|
||||
absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) = 0;
|
||||
// Creates a new handler by merging this handler with `next_handler`.
|
||||
//
|
||||
// The new handler is expected to transform operations first with this handler
|
||||
// and then execute the resulting operations on `next_handler` (by calling
|
||||
// `OpHandlerOperation::set_handler` and passing `next_handler`). If this is
|
||||
// not possible then the merge operation should fail.
|
||||
virtual Status Merge(OpHandler* next_handler,
|
||||
core::RefCountPtr<OpHandler>& merged_handler) = 0;
|
||||
};
|
||||
|
||||
// Keeps some handler-specific metadata, but otherwise wraps a single
|
||||
// AbstractOperation in the underlying context. The operation is created, its
|
||||
// attributes set, etc., and at execution time it is presented to its handler,
|
||||
// which may choose to execute it or simply inspect it and do something else.
|
||||
//
|
||||
// This is somewhat different than the Context approach, where the operation's
|
||||
// construction is streamed through each layered Context. The streaming approach
|
||||
// would require a much larger op handler public API, one function pointer per
|
||||
// attribute type, and there is some ambiguity before an op is finalized about
|
||||
// whether it should be presented as-is to handlers (regular operations) or
|
||||
// replayed (function calls and control flow operations).
|
||||
class OpHandlerOperation : public WrapperOperation {
|
||||
public:
|
||||
explicit OpHandlerOperation(AbstractOperation*);
|
||||
OpHandler* get_handler();
|
||||
void set_handler(OpHandler* handler);
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override;
|
||||
|
||||
protected:
|
||||
core::RefCountPtr<OpHandler> handler_;
|
||||
};
|
||||
|
||||
// A context which allows a default handler to be set for new operations. It
|
||||
// otherwise defers to the context it wraps.
|
||||
//
|
||||
// TODO(allenl): A stack of contexts and a stack of handlers look pretty similar
|
||||
// in some ways. Having each handler be its own context seems almost doable,
|
||||
// with things like copy operations and function/control flow replay being
|
||||
// somewhat tricky (since they should be generated at the top of the handler
|
||||
// stack and "caught" at the bottom). After handlers have evolved for a bit we
|
||||
// should re-evaluate whether the handler+context concepts can be merged.
|
||||
class OpHandlerContext : public AbstractContext {
|
||||
public:
|
||||
explicit OpHandlerContext(AbstractContext*);
|
||||
void Release() override;
|
||||
OpHandlerOperation* CreateOperation() override;
|
||||
Status RegisterFunction(AbstractFunction*) override;
|
||||
Status RemoveFunction(const string&) override;
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractContext* ptr) {
|
||||
return ptr->getKind() == kOpHandler;
|
||||
}
|
||||
~OpHandlerContext() override;
|
||||
|
||||
void set_default_handler(OpHandler* handler);
|
||||
|
||||
private:
|
||||
AbstractContext* parent_ctx_; // Not owned.
|
||||
core::RefCountPtr<OpHandler> default_handler_;
|
||||
};
|
||||
|
||||
class ReleaseOpHandlerOperation {
|
||||
public:
|
||||
void operator()(OpHandlerOperation* operation) { operation->Release(); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<OpHandlerOperation, ReleaseOpHandlerOperation>
|
||||
OpHandlerOperationPtr;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
|
102
tensorflow/c/experimental/op_handler/internal_test.cc
Normal file
102
tensorflow/c/experimental/op_handler/internal_test.cc
Normal file
@ -0,0 +1,102 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/op_handler/internal.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TestOpHandler : public OpHandler {
|
||||
public:
|
||||
TestOpHandler() : last_operation_(new std::string("")) {}
|
||||
Status Execute(OpHandlerOperation* operation,
|
||||
absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override {
|
||||
CHECK(operation->get_handler() == this);
|
||||
*last_operation_ = operation->Name();
|
||||
operation->set_handler(next_handler_.get());
|
||||
return operation->Execute(retvals, num_retvals);
|
||||
}
|
||||
Status Merge(OpHandler* next_handler,
|
||||
core::RefCountPtr<OpHandler>& merged_handler) override {
|
||||
merged_handler.reset(new TestOpHandler(next_handler, last_operation_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
core::RefCountPtr<OpHandler> next_handler_ = nullptr;
|
||||
// Shared between merged handlers of this type.
|
||||
std::shared_ptr<std::string> last_operation_;
|
||||
|
||||
private:
|
||||
TestOpHandler(OpHandler* next_handler,
|
||||
std::shared_ptr<std::string> last_operation)
|
||||
: next_handler_(next_handler), last_operation_(last_operation) {
|
||||
next_handler->Ref();
|
||||
}
|
||||
};
|
||||
|
||||
TEST(INTERNAL_TEST, UseOpHandler) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_ExecutionContext, decltype(&TF_DeleteExecutionContext)>
|
||||
c_ctx(TF_NewEagerExecutionContext(opts.get(), status.get()),
|
||||
TF_DeleteExecutionContext);
|
||||
OpHandlerContext ctx(unwrap(c_ctx.get()));
|
||||
core::RefCountPtr<TestOpHandler> outer_handler(new TestOpHandler());
|
||||
core::RefCountPtr<TestOpHandler> inner_handler(new TestOpHandler());
|
||||
ctx.set_default_handler(outer_handler.get());
|
||||
OpHandlerOperationPtr op(ctx.CreateOperation());
|
||||
Status s = op->Reset("NoOp", "");
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
std::vector<AbstractTensorHandle*> retvals;
|
||||
int num_retvals = 0;
|
||||
EXPECT_EQ("", *outer_handler->last_operation_);
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
|
||||
*outer_handler->last_operation_ = "";
|
||||
EXPECT_EQ("", *inner_handler->last_operation_);
|
||||
|
||||
// This op executes on both handlers, changing the state of `inner_handler`
|
||||
// since the handler has decided to preserve that state across merges.
|
||||
core::RefCountPtr<OpHandler> merged;
|
||||
s = inner_handler->Merge(outer_handler.get(), merged);
|
||||
ctx.set_default_handler(merged.get());
|
||||
op.reset(ctx.CreateOperation());
|
||||
s = op->Reset("NoOp", "");
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
EXPECT_EQ("NoOp", *inner_handler->last_operation_);
|
||||
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
|
||||
|
||||
inner_handler.reset();
|
||||
outer_handler.reset();
|
||||
op.reset(ctx.CreateOperation());
|
||||
s = op->Reset("NoOp", "");
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
120
tensorflow/c/experimental/op_handler/wrapper_operation.cc
Normal file
120
tensorflow/c/experimental/op_handler/wrapper_operation.cc
Normal file
@ -0,0 +1,120 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
WrapperOperation::WrapperOperation(AbstractOperation* parent_op,
|
||||
AbstractOperationKind kind)
|
||||
: AbstractOperation(kind), parent_op_(parent_op) {
|
||||
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||
// parent_op_->Ref();
|
||||
}
|
||||
void WrapperOperation::Release() {
|
||||
parent_op_->Release();
|
||||
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
|
||||
delete this;
|
||||
}
|
||||
|
||||
Status WrapperOperation::Reset(const char* op, const char* raw_device_name) {
|
||||
return parent_op_->Reset(op, raw_device_name);
|
||||
}
|
||||
const string& WrapperOperation::Name() const { return parent_op_->Name(); }
|
||||
const string& WrapperOperation::DeviceName() const {
|
||||
return parent_op_->DeviceName();
|
||||
}
|
||||
Status WrapperOperation::SetDeviceName(const char* name) {
|
||||
return parent_op_->SetDeviceName(name);
|
||||
}
|
||||
Status WrapperOperation::AddInput(AbstractTensorHandle* input) {
|
||||
return parent_op_->AddInput(input);
|
||||
}
|
||||
Status WrapperOperation::AddInputList(
|
||||
absl::Span<AbstractTensorHandle* const> inputs) {
|
||||
return parent_op_->AddInputList(inputs);
|
||||
}
|
||||
Status WrapperOperation::SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) {
|
||||
return parent_op_->SetAttrString(attr_name, data, length);
|
||||
}
|
||||
Status WrapperOperation::SetAttrInt(const char* attr_name, int64_t value) {
|
||||
return parent_op_->SetAttrInt(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFloat(const char* attr_name, float value) {
|
||||
return parent_op_->SetAttrFloat(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrBool(const char* attr_name, bool value) {
|
||||
return parent_op_->SetAttrBool(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrType(const char* attr_name, DataType value) {
|
||||
return parent_op_->SetAttrType(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrShape(const char* attr_name,
|
||||
const int64_t* dims, const int num_dims) {
|
||||
return parent_op_->SetAttrShape(attr_name, dims, num_dims);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperation* value) {
|
||||
return parent_op_->SetAttrFunction(attr_name, value);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFunctionName(const char* attr_name,
|
||||
const char* value, size_t length) {
|
||||
return parent_op_->SetAttrFunctionName(attr_name, value, length);
|
||||
}
|
||||
Status WrapperOperation::SetAttrTensor(const char* attr_name,
|
||||
AbstractTensorInterface* tensor) {
|
||||
return parent_op_->SetAttrTensor(attr_name, tensor);
|
||||
}
|
||||
Status WrapperOperation::SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths,
|
||||
int num_values) {
|
||||
return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFloatList(const char* attr_name,
|
||||
const float* values, int num_values) {
|
||||
return parent_op_->SetAttrFloatList(attr_name, values, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrIntList(const char* attr_name,
|
||||
const int64_t* values, int num_values) {
|
||||
return parent_op_->SetAttrIntList(attr_name, values, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrTypeList(const char* attr_name,
|
||||
const DataType* values,
|
||||
int num_values) {
|
||||
return parent_op_->SetAttrTypeList(attr_name, values, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) {
|
||||
return parent_op_->SetAttrBoolList(attr_name, values, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrShapeList(const char* attr_name,
|
||||
const int64_t** dims,
|
||||
const int* num_dims, int num_values) {
|
||||
return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
|
||||
}
|
||||
Status WrapperOperation::SetAttrFunctionList(
|
||||
const char* attr_name, absl::Span<const AbstractOperation*> values) {
|
||||
return parent_op_->SetAttrFunctionList(attr_name, values);
|
||||
}
|
||||
AbstractOperation* WrapperOperation::GetBackingOperation() {
|
||||
return parent_op_;
|
||||
}
|
||||
Status WrapperOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) {
|
||||
return parent_op_->Execute(retvals, num_retvals);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
74
tensorflow/c/experimental/op_handler/wrapper_operation.h
Normal file
74
tensorflow/c/experimental/op_handler/wrapper_operation.h
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Forwards all of the AbstractOperation's methods to its wrapped operation.
|
||||
//
|
||||
// Useful as a base class to default to forwarding while adding some
|
||||
// customization.
|
||||
class WrapperOperation : public AbstractOperation {
|
||||
public:
|
||||
explicit WrapperOperation(AbstractOperation*, AbstractOperationKind kind);
|
||||
void Release() override;
|
||||
Status Reset(const char* op, const char* raw_device_name) override;
|
||||
const string& Name() const override;
|
||||
const string& DeviceName() const override;
|
||||
Status SetDeviceName(const char* name) override;
|
||||
Status AddInput(AbstractTensorHandle* input) override;
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override;
|
||||
Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||
Status SetAttrType(const char* attr_name, DataType value) override;
|
||||
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) override;
|
||||
Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperation* value) override;
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) override;
|
||||
Status SetAttrTensor(const char* attr_name,
|
||||
AbstractTensorInterface* tensor) override;
|
||||
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||
const size_t* lengths, int num_values) override;
|
||||
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
int num_values) override;
|
||||
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) override;
|
||||
Status SetAttrTypeList(const char* attr_name, const DataType* values,
|
||||
int num_values) override;
|
||||
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||
int num_values) override;
|
||||
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) override;
|
||||
Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperation*> values) override;
|
||||
AbstractOperation* GetBackingOperation();
|
||||
|
||||
private:
|
||||
AbstractOperation* parent_op_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
|
Loading…
Reference in New Issue
Block a user