Merge pull request #29754 from Intel-tensorflow:eager_op_rewrite_registration

PiperOrigin-RevId: 256443805
This commit is contained in:
TensorFlower Gardener 2019-07-03 17:51:08 -07:00
commit 01abb79f05
5 changed files with 243 additions and 2 deletions

View File

@ -242,6 +242,7 @@ cc_library(
":context", ":context",
":copy_to_device_node", ":copy_to_device_node",
":eager_executor", ":eager_executor",
":eager_op_rewrite_registry",
":eager_operation", ":eager_operation",
":kernel_and_device", ":kernel_and_device",
":tensor_handle", ":tensor_handle",
@ -264,6 +265,24 @@ cc_library(
}), }),
) )
cc_library(
name = "eager_op_rewrite_registry",
srcs = ["eager_op_rewrite_registry.cc"],
hdrs = ["eager_op_rewrite_registry.h"],
deps = [":eager_operation"],
)
tf_cc_test(
name = "eager_op_rewrite_registry_test",
srcs = ["eager_op_rewrite_registry_test.cc"],
deps = [
":eager_op_rewrite_registry",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_library( tf_cuda_library(
name = "attr_builder", name = "attr_builder",
srcs = ["attr_builder.cc"], srcs = ["attr_builder.cc"],

View File

@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
namespace tensorflow {
EagerOpRewriteRegistry* EagerOpRewriteRegistry::Global() {
static EagerOpRewriteRegistry* global_rewrite_registry =
new EagerOpRewriteRegistry;
return global_rewrite_registry;
}
void EagerOpRewriteRegistry::Register(Phase phase,
std::unique_ptr<EagerOpRewrite> pass) {
if (rewrites_.find(phase) == rewrites_.end()) {
rewrites_[phase] = std::move(pass);
} else {
TF_CHECK_OK(errors::AlreadyExists(pass->GetDebugInfo().name,
" is already registered as"
" EagerOpRewrite for this phase in ",
pass->GetDebugInfo().file, ":",
pass->GetDebugInfo().line));
}
}
Status EagerOpRewriteRegistry::RunRewrite(
Phase phase, EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>* out_op) {
auto rewrite = rewrites_.find(phase);
if (rewrite != rewrites_.end()) {
Status s = rewrite->second->Run(orig_op, out_op);
if (!s.ok()) return s;
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,103 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
#include <map>
#include <vector>
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
namespace tensorflow {
// Eager op rewrites should inherit from this class and
// implement the Run method.
class EagerOpRewrite {
public:
EagerOpRewrite(string name, string file, string line) {
debug_info_.name = name;
debug_info_.file = file;
debug_info_.line = line;
}
virtual ~EagerOpRewrite() {}
// To be implemnted by an Eager op rewrite pass.
virtual Status Run(EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>* out_op) = 0;
// Holds information about the rewrite registration.
struct DebugInfo {
string name, file, line;
};
// Returns information about the registered Eager op rewrite.
DebugInfo GetDebugInfo() const { return debug_info_; }
private:
DebugInfo debug_info_;
};
class EagerOpRewriteRegistry {
public:
// Phases at which the Eager op rewrite pass should run.
// For now we only added PRE_EXECUTION. Expand as needed.
enum Phase {
PRE_EXECUTION // right before executing an eager op
};
// Add a rewrite pass to the registry.
// Only one rewrite pass is allowed per phase.
void Register(Phase phase, std::unique_ptr<EagerOpRewrite> pass);
// Run the rewrite pass registered for a given phase.
Status RunRewrite(Phase phase, EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>* out_op);
// Returns the global registry of rewrite passes.
static EagerOpRewriteRegistry* Global();
private:
// Holds all the registered Eager op rewrites.
std::map<Phase, std::unique_ptr<EagerOpRewrite>> rewrites_;
};
namespace eager_rewrite_registration {
// This class is used to register a new Eager Op rewrite.
class EagerRewriteRegistration {
public:
EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase,
std::unique_ptr<EagerOpRewrite> pass) {
EagerOpRewriteRegistry::Global()->Register(phase, std::move(pass));
}
};
} // namespace eager_rewrite_registration
#define REGISTER_REWRITE(phase, rewrite) \
REGISTER_REWRITE_UNIQ_HELPER(__COUNTER__, __FILE__, __LINE__, phase, rewrite)
#define REGISTER_REWRITE_UNIQ_HELPER(ctr, file, line, phase, rewrite) \
REGISTER_REWRITE_UNIQ(ctr, file, line, phase, rewrite)
#define REGISTER_REWRITE_UNIQ(ctr, file, line, phase, rewrite) \
static ::tensorflow::eager_rewrite_registration::EagerRewriteRegistration \
register_rewrite_##ctr(phase, \
::std::unique_ptr<::tensorflow::EagerOpRewrite>( \
new rewrite(#rewrite, file, #line)))
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_

View File

@ -0,0 +1,57 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
class TestEagerOpRewrite : public EagerOpRewrite {
public:
TestEagerOpRewrite(string name, string file, string line)
: EagerOpRewrite(name, file, line) {}
static int count_;
Status Run(EagerOperation* orig_op,
std::unique_ptr<tensorflow::EagerOperation>* out_op) override {
++count_;
const tensorflow::AttrTypeMap* types;
bool is_function = false;
const string kNewOp = "NoOp";
TF_RETURN_IF_ERROR(
tensorflow::AttrTypeMapForOp(kNewOp.c_str(), &types, &is_function));
// Create a new NoOp Eager operation.
out_op->reset(new tensorflow::EagerOperation(nullptr, kNewOp.c_str(),
is_function, types));
return Status::OK();
}
};
int TestEagerOpRewrite::count_ = 0;
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, TestEagerOpRewrite);
TEST(EagerOpRewriteRegistryTest, RegisterRewritePass) {
EXPECT_EQ(0, TestEagerOpRewrite::count_);
EagerOperation* orig_op = nullptr;
std::unique_ptr<tensorflow::EagerOperation> out_op;
EXPECT_EQ(Status::OK(),
EagerOpRewriteRegistry::Global()->RunRewrite(
EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, &out_op));
EXPECT_EQ(1, TestEagerOpRewrite::count_);
EXPECT_EQ("NoOp", out_op->Name());
}
} // namespace tensorflow

View File

@ -57,6 +57,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
namespace tensorflow { namespace tensorflow {
@ -1033,8 +1034,16 @@ Status EagerExecute(EagerOperation* op,
bool op_is_local = op->EagerContext()->IsLocalDeviceName(op->GetDeviceName()); bool op_is_local = op->EagerContext()->IsLocalDeviceName(op->GetDeviceName());
std::unique_ptr<tensorflow::EagerOperation> out_op;
TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
if (op_is_local) { if (op_is_local) {
return EagerLocalExecute(op, retvals, num_retvals); if (out_op) {
return EagerLocalExecute(out_op.get(), retvals, num_retvals);
} else {
return EagerLocalExecute(op, retvals, num_retvals);
}
} }
if (op->EagerContext()->LogDevicePlacement() || VLOG_IS_ON(1)) { if (op->EagerContext()->LogDevicePlacement() || VLOG_IS_ON(1)) {
@ -1050,7 +1059,11 @@ Status EagerExecute(EagerOperation* op,
return errors::Unimplemented( return errors::Unimplemented(
"Eager's remote execution is not available on mobile devices."); "Eager's remote execution is not available on mobile devices.");
#else // !IS_MOBILE_PLATFORM #else // !IS_MOBILE_PLATFORM
return EagerRemoteExecute(op, retvals->data(), num_retvals); if (out_op) {
return EagerRemoteExecute(out_op.get(), retvals->data(), num_retvals);
} else {
return EagerRemoteExecute(op, retvals->data(), num_retvals);
}
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM
} }