diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 934fffabea9..5d5c93130dc 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -242,6 +242,7 @@ cc_library( ":context", ":copy_to_device_node", ":eager_executor", + ":eager_op_rewrite_registry", ":eager_operation", ":kernel_and_device", ":tensor_handle", @@ -264,6 +265,24 @@ cc_library( }), ) +cc_library( + name = "eager_op_rewrite_registry", + srcs = ["eager_op_rewrite_registry.cc"], + hdrs = ["eager_op_rewrite_registry.h"], + deps = [":eager_operation"], +) + +tf_cc_test( + name = "eager_op_rewrite_registry_test", + srcs = ["eager_op_rewrite_registry_test.cc"], + deps = [ + ":eager_op_rewrite_registry", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cuda_library( name = "attr_builder", srcs = ["attr_builder.cc"], diff --git a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.cc b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.cc new file mode 100644 index 00000000000..e910136a4f4 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.cc @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" + +namespace tensorflow { + +EagerOpRewriteRegistry* EagerOpRewriteRegistry::Global() { + static EagerOpRewriteRegistry* global_rewrite_registry = + new EagerOpRewriteRegistry; + return global_rewrite_registry; +} + +void EagerOpRewriteRegistry::Register(Phase phase, + std::unique_ptr pass) { + if (rewrites_.find(phase) == rewrites_.end()) { + rewrites_[phase] = std::move(pass); + } else { + TF_CHECK_OK(errors::AlreadyExists(pass->GetDebugInfo().name, + " is already registered as" + " EagerOpRewrite for this phase in ", + pass->GetDebugInfo().file, ":", + pass->GetDebugInfo().line)); + } +} + +Status EagerOpRewriteRegistry::RunRewrite( + Phase phase, EagerOperation* orig_op, + std::unique_ptr* out_op) { + auto rewrite = rewrites_.find(phase); + if (rewrite != rewrites_.end()) { + Status s = rewrite->second->Run(orig_op, out_op); + if (!s.ok()) return s; + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h new file mode 100644 index 00000000000..58757fecea8 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h @@ -0,0 +1,103 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/eager/eager_operation.h" + +namespace tensorflow { + +// Eager op rewrites should inherit from this class and +// implement the Run method. +class EagerOpRewrite { + public: + EagerOpRewrite(string name, string file, string line) { + debug_info_.name = name; + debug_info_.file = file; + debug_info_.line = line; + } + + virtual ~EagerOpRewrite() {} + + // To be implemnted by an Eager op rewrite pass. + virtual Status Run(EagerOperation* orig_op, + std::unique_ptr* out_op) = 0; + + // Holds information about the rewrite registration. + struct DebugInfo { + string name, file, line; + }; + + // Returns information about the registered Eager op rewrite. + DebugInfo GetDebugInfo() const { return debug_info_; } + + private: + DebugInfo debug_info_; +}; + +class EagerOpRewriteRegistry { + public: + // Phases at which the Eager op rewrite pass should run. + // For now we only added PRE_EXECUTION. Expand as needed. + enum Phase { + PRE_EXECUTION // right before executing an eager op + }; + + // Add a rewrite pass to the registry. + // Only one rewrite pass is allowed per phase. + void Register(Phase phase, std::unique_ptr pass); + + // Run the rewrite pass registered for a given phase. + Status RunRewrite(Phase phase, EagerOperation* orig_op, + std::unique_ptr* out_op); + + // Returns the global registry of rewrite passes. + static EagerOpRewriteRegistry* Global(); + + private: + // Holds all the registered Eager op rewrites. + std::map> rewrites_; +}; + +namespace eager_rewrite_registration { + +// This class is used to register a new Eager Op rewrite. +class EagerRewriteRegistration { + public: + EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase, + std::unique_ptr pass) { + EagerOpRewriteRegistry::Global()->Register(phase, std::move(pass)); + } +}; + +} // namespace eager_rewrite_registration + +#define REGISTER_REWRITE(phase, rewrite) \ + REGISTER_REWRITE_UNIQ_HELPER(__COUNTER__, __FILE__, __LINE__, phase, rewrite) + +#define REGISTER_REWRITE_UNIQ_HELPER(ctr, file, line, phase, rewrite) \ + REGISTER_REWRITE_UNIQ(ctr, file, line, phase, rewrite) + +#define REGISTER_REWRITE_UNIQ(ctr, file, line, phase, rewrite) \ + static ::tensorflow::eager_rewrite_registration::EagerRewriteRegistration \ + register_rewrite_##ctr(phase, \ + ::std::unique_ptr<::tensorflow::EagerOpRewrite>( \ + new rewrite(#rewrite, file, #line))) + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ diff --git a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc new file mode 100644 index 00000000000..cfb11d870f5 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class TestEagerOpRewrite : public EagerOpRewrite { + public: + TestEagerOpRewrite(string name, string file, string line) + : EagerOpRewrite(name, file, line) {} + static int count_; + Status Run(EagerOperation* orig_op, + std::unique_ptr* out_op) override { + ++count_; + const tensorflow::AttrTypeMap* types; + bool is_function = false; + const string kNewOp = "NoOp"; + TF_RETURN_IF_ERROR( + tensorflow::AttrTypeMapForOp(kNewOp.c_str(), &types, &is_function)); + // Create a new NoOp Eager operation. + out_op->reset(new tensorflow::EagerOperation(nullptr, kNewOp.c_str(), + is_function, types)); + return Status::OK(); + } +}; + +int TestEagerOpRewrite::count_ = 0; + +REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, TestEagerOpRewrite); + +TEST(EagerOpRewriteRegistryTest, RegisterRewritePass) { + EXPECT_EQ(0, TestEagerOpRewrite::count_); + EagerOperation* orig_op = nullptr; + std::unique_ptr out_op; + EXPECT_EQ(Status::OK(), + EagerOpRewriteRegistry::Global()->RunRewrite( + EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, &out_op)); + EXPECT_EQ(1, TestEagerOpRewrite::count_); + EXPECT_EQ("NoOp", out_op->Name()); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index c83eed6c7aa..4294801e342 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -57,6 +57,7 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/ptr_util.h" +#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" namespace tensorflow { @@ -1033,8 +1034,16 @@ Status EagerExecute(EagerOperation* op, bool op_is_local = op->EagerContext()->IsLocalDeviceName(op->GetDeviceName()); + std::unique_ptr out_op; + TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite( + EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op)); + if (op_is_local) { - return EagerLocalExecute(op, retvals, num_retvals); + if (out_op) { + return EagerLocalExecute(out_op.get(), retvals, num_retvals); + } else { + return EagerLocalExecute(op, retvals, num_retvals); + } } if (op->EagerContext()->LogDevicePlacement() || VLOG_IS_ON(1)) { @@ -1050,7 +1059,11 @@ Status EagerExecute(EagerOperation* op, return errors::Unimplemented( "Eager's remote execution is not available on mobile devices."); #else // !IS_MOBILE_PLATFORM - return EagerRemoteExecute(op, retvals->data(), num_retvals); + if (out_op) { + return EagerRemoteExecute(out_op.get(), retvals->data(), num_retvals); + } else { + return EagerRemoteExecute(op, retvals->data(), num_retvals); + } #endif // !IS_MOBILE_PLATFORM }