Merge pull request #29754 from Intel-tensorflow:eager_op_rewrite_registration
PiperOrigin-RevId: 256443805
This commit is contained in:
commit
01abb79f05
@ -242,6 +242,7 @@ cc_library(
|
||||
":context",
|
||||
":copy_to_device_node",
|
||||
":eager_executor",
|
||||
":eager_op_rewrite_registry",
|
||||
":eager_operation",
|
||||
":kernel_and_device",
|
||||
":tensor_handle",
|
||||
@ -264,6 +265,24 @@ cc_library(
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "eager_op_rewrite_registry",
|
||||
srcs = ["eager_op_rewrite_registry.cc"],
|
||||
hdrs = ["eager_op_rewrite_registry.h"],
|
||||
deps = [":eager_operation"],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "eager_op_rewrite_registry_test",
|
||||
srcs = ["eager_op_rewrite_registry_test.cc"],
|
||||
deps = [
|
||||
":eager_op_rewrite_registry",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "attr_builder",
|
||||
srcs = ["attr_builder.cc"],
|
||||
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
EagerOpRewriteRegistry* EagerOpRewriteRegistry::Global() {
|
||||
static EagerOpRewriteRegistry* global_rewrite_registry =
|
||||
new EagerOpRewriteRegistry;
|
||||
return global_rewrite_registry;
|
||||
}
|
||||
|
||||
void EagerOpRewriteRegistry::Register(Phase phase,
|
||||
std::unique_ptr<EagerOpRewrite> pass) {
|
||||
if (rewrites_.find(phase) == rewrites_.end()) {
|
||||
rewrites_[phase] = std::move(pass);
|
||||
} else {
|
||||
TF_CHECK_OK(errors::AlreadyExists(pass->GetDebugInfo().name,
|
||||
" is already registered as"
|
||||
" EagerOpRewrite for this phase in ",
|
||||
pass->GetDebugInfo().file, ":",
|
||||
pass->GetDebugInfo().line));
|
||||
}
|
||||
}
|
||||
|
||||
Status EagerOpRewriteRegistry::RunRewrite(
|
||||
Phase phase, EagerOperation* orig_op,
|
||||
std::unique_ptr<tensorflow::EagerOperation>* out_op) {
|
||||
auto rewrite = rewrites_.find(phase);
|
||||
if (rewrite != rewrites_.end()) {
|
||||
Status s = rewrite->second->Run(orig_op, out_op);
|
||||
if (!s.ok()) return s;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
103
tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h
Normal file
103
tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h
Normal file
@ -0,0 +1,103 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Eager op rewrites should inherit from this class and
|
||||
// implement the Run method.
|
||||
class EagerOpRewrite {
|
||||
public:
|
||||
EagerOpRewrite(string name, string file, string line) {
|
||||
debug_info_.name = name;
|
||||
debug_info_.file = file;
|
||||
debug_info_.line = line;
|
||||
}
|
||||
|
||||
virtual ~EagerOpRewrite() {}
|
||||
|
||||
// To be implemnted by an Eager op rewrite pass.
|
||||
virtual Status Run(EagerOperation* orig_op,
|
||||
std::unique_ptr<tensorflow::EagerOperation>* out_op) = 0;
|
||||
|
||||
// Holds information about the rewrite registration.
|
||||
struct DebugInfo {
|
||||
string name, file, line;
|
||||
};
|
||||
|
||||
// Returns information about the registered Eager op rewrite.
|
||||
DebugInfo GetDebugInfo() const { return debug_info_; }
|
||||
|
||||
private:
|
||||
DebugInfo debug_info_;
|
||||
};
|
||||
|
||||
class EagerOpRewriteRegistry {
|
||||
public:
|
||||
// Phases at which the Eager op rewrite pass should run.
|
||||
// For now we only added PRE_EXECUTION. Expand as needed.
|
||||
enum Phase {
|
||||
PRE_EXECUTION // right before executing an eager op
|
||||
};
|
||||
|
||||
// Add a rewrite pass to the registry.
|
||||
// Only one rewrite pass is allowed per phase.
|
||||
void Register(Phase phase, std::unique_ptr<EagerOpRewrite> pass);
|
||||
|
||||
// Run the rewrite pass registered for a given phase.
|
||||
Status RunRewrite(Phase phase, EagerOperation* orig_op,
|
||||
std::unique_ptr<tensorflow::EagerOperation>* out_op);
|
||||
|
||||
// Returns the global registry of rewrite passes.
|
||||
static EagerOpRewriteRegistry* Global();
|
||||
|
||||
private:
|
||||
// Holds all the registered Eager op rewrites.
|
||||
std::map<Phase, std::unique_ptr<EagerOpRewrite>> rewrites_;
|
||||
};
|
||||
|
||||
namespace eager_rewrite_registration {
|
||||
|
||||
// This class is used to register a new Eager Op rewrite.
|
||||
class EagerRewriteRegistration {
|
||||
public:
|
||||
EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase,
|
||||
std::unique_ptr<EagerOpRewrite> pass) {
|
||||
EagerOpRewriteRegistry::Global()->Register(phase, std::move(pass));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace eager_rewrite_registration
|
||||
|
||||
#define REGISTER_REWRITE(phase, rewrite) \
|
||||
REGISTER_REWRITE_UNIQ_HELPER(__COUNTER__, __FILE__, __LINE__, phase, rewrite)
|
||||
|
||||
#define REGISTER_REWRITE_UNIQ_HELPER(ctr, file, line, phase, rewrite) \
|
||||
REGISTER_REWRITE_UNIQ(ctr, file, line, phase, rewrite)
|
||||
|
||||
#define REGISTER_REWRITE_UNIQ(ctr, file, line, phase, rewrite) \
|
||||
static ::tensorflow::eager_rewrite_registration::EagerRewriteRegistration \
|
||||
register_rewrite_##ctr(phase, \
|
||||
::std::unique_ptr<::tensorflow::EagerOpRewrite>( \
|
||||
new rewrite(#rewrite, file, #line)))
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
|
@ -0,0 +1,57 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TestEagerOpRewrite : public EagerOpRewrite {
|
||||
public:
|
||||
TestEagerOpRewrite(string name, string file, string line)
|
||||
: EagerOpRewrite(name, file, line) {}
|
||||
static int count_;
|
||||
Status Run(EagerOperation* orig_op,
|
||||
std::unique_ptr<tensorflow::EagerOperation>* out_op) override {
|
||||
++count_;
|
||||
const tensorflow::AttrTypeMap* types;
|
||||
bool is_function = false;
|
||||
const string kNewOp = "NoOp";
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::AttrTypeMapForOp(kNewOp.c_str(), &types, &is_function));
|
||||
// Create a new NoOp Eager operation.
|
||||
out_op->reset(new tensorflow::EagerOperation(nullptr, kNewOp.c_str(),
|
||||
is_function, types));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
int TestEagerOpRewrite::count_ = 0;
|
||||
|
||||
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, TestEagerOpRewrite);
|
||||
|
||||
TEST(EagerOpRewriteRegistryTest, RegisterRewritePass) {
|
||||
EXPECT_EQ(0, TestEagerOpRewrite::count_);
|
||||
EagerOperation* orig_op = nullptr;
|
||||
std::unique_ptr<tensorflow::EagerOperation> out_op;
|
||||
EXPECT_EQ(Status::OK(),
|
||||
EagerOpRewriteRegistry::Global()->RunRewrite(
|
||||
EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, &out_op));
|
||||
EXPECT_EQ(1, TestEagerOpRewrite::count_);
|
||||
EXPECT_EQ("NoOp", out_op->Name());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -57,6 +57,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -1033,8 +1034,16 @@ Status EagerExecute(EagerOperation* op,
|
||||
|
||||
bool op_is_local = op->EagerContext()->IsLocalDeviceName(op->GetDeviceName());
|
||||
|
||||
std::unique_ptr<tensorflow::EagerOperation> out_op;
|
||||
TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
|
||||
EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
|
||||
|
||||
if (op_is_local) {
|
||||
return EagerLocalExecute(op, retvals, num_retvals);
|
||||
if (out_op) {
|
||||
return EagerLocalExecute(out_op.get(), retvals, num_retvals);
|
||||
} else {
|
||||
return EagerLocalExecute(op, retvals, num_retvals);
|
||||
}
|
||||
}
|
||||
|
||||
if (op->EagerContext()->LogDevicePlacement() || VLOG_IS_ON(1)) {
|
||||
@ -1050,7 +1059,11 @@ Status EagerExecute(EagerOperation* op,
|
||||
return errors::Unimplemented(
|
||||
"Eager's remote execution is not available on mobile devices.");
|
||||
#else // !IS_MOBILE_PLATFORM
|
||||
return EagerRemoteExecute(op, retvals->data(), num_retvals);
|
||||
if (out_op) {
|
||||
return EagerRemoteExecute(out_op.get(), retvals->data(), num_retvals);
|
||||
} else {
|
||||
return EagerRemoteExecute(op, retvals->data(), num_retvals);
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user