Cleanup: Ran clang-format on files in tensorflow/core/.../*.{cc,h}.
PiperOrigin-RevId: 183848459
This commit is contained in:
		
							parent
							
								
									88eb6c61ef
								
							
						
					
					
						commit
						7149a2e2e2
					
				| @ -13,11 +13,9 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| #include "tensorflow/core/common_runtime/optimization_registry.h" | #include "tensorflow/core/common_runtime/optimization_registry.h" | ||||||
| #include "tensorflow/core/graph/node_builder.h" | #include "tensorflow/core/graph/node_builder.h" | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| @ -44,7 +42,6 @@ Tensor make_zeros(const DataType& dtype, const TensorShapeProto& shape) { | |||||||
| // third-party libraries aren't currently supported.
 | // third-party libraries aren't currently supported.
 | ||||||
| class AccumulateNV2RemovePass : public GraphOptimizationPass { | class AccumulateNV2RemovePass : public GraphOptimizationPass { | ||||||
|  public: |  public: | ||||||
| 
 |  | ||||||
|   Status Run(const GraphOptimizationPassOptions& options) override { |   Status Run(const GraphOptimizationPassOptions& options) override { | ||||||
|     // TODO(freiss.oss@gmail.com): Substantial shared code with
 |     // TODO(freiss.oss@gmail.com): Substantial shared code with
 | ||||||
|     // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes
 |     // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes
 | ||||||
|  | |||||||
| @ -127,10 +127,10 @@ class BFCAllocator : public VisitableAllocator { | |||||||
|     string DebugString(BFCAllocator* a, |     string DebugString(BFCAllocator* a, | ||||||
|                        bool recurse) NO_THREAD_SAFETY_ANALYSIS { |                        bool recurse) NO_THREAD_SAFETY_ANALYSIS { | ||||||
|       string dbg; |       string dbg; | ||||||
|       strings::StrAppend(&dbg, "  Size: ", strings::HumanReadableNumBytes(size), |       strings::StrAppend( | ||||||
|                          " | Requested Size: ", |           &dbg, "  Size: ", strings::HumanReadableNumBytes(size), | ||||||
|                          strings::HumanReadableNumBytes(requested_size), |           " | Requested Size: ", strings::HumanReadableNumBytes(requested_size), | ||||||
|                          " | in_use: ", in_use()); |           " | in_use: ", in_use()); | ||||||
|       if (recurse && prev != BFCAllocator::kInvalidChunkHandle) { |       if (recurse && prev != BFCAllocator::kInvalidChunkHandle) { | ||||||
|         Chunk* p = a->ChunkFromHandle(prev); |         Chunk* p = a->ChunkFromHandle(prev); | ||||||
|         strings::StrAppend(&dbg, ", prev: ", p->DebugString(a, false)); |         strings::StrAppend(&dbg, ", prev: ", p->DebugString(a, false)); | ||||||
|  | |||||||
| @ -88,7 +88,9 @@ TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) { | |||||||
|   // D3 is prioritized below D1.
 |   // D3 is prioritized below D1.
 | ||||||
|   AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0"); |   AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0"); | ||||||
|   EXPECT_EQ((std::vector<DeviceType>{ |   EXPECT_EQ((std::vector<DeviceType>{ | ||||||
|                 DeviceType("d2"), DeviceType("d1"), DeviceType("d3"), |                 DeviceType("d2"), | ||||||
|  |                 DeviceType("d1"), | ||||||
|  |                 DeviceType("d3"), | ||||||
|             }), |             }), | ||||||
|             types()); |             types()); | ||||||
| } | } | ||||||
|  | |||||||
| @ -61,7 +61,6 @@ limitations under the License. | |||||||
| #include "tensorflow/core/util/device_name_utils.h" | #include "tensorflow/core/util/device_name_utils.h" | ||||||
| #include "tensorflow/core/util/env_var.h" | #include "tensorflow/core/util/env_var.h" | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| @ -472,9 +471,9 @@ Status DirectSession::Run(const RunOptions& run_options, | |||||||
|   Executor::Args args; |   Executor::Args args; | ||||||
|   args.step_id = step_id_counter_.fetch_add(1); |   args.step_id = step_id_counter_.fetch_add(1); | ||||||
| 
 | 
 | ||||||
|   TF_RETURN_IF_ERROR( |   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, | ||||||
|       GetOrCreateExecutors(input_tensor_names, output_names, target_nodes, |                                           target_nodes, &executors_and_keys, | ||||||
|                            &executors_and_keys, &run_state_args)); |                                           &run_state_args)); | ||||||
|   const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1); |   const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1); | ||||||
| 
 | 
 | ||||||
|   std::unique_ptr<DebuggerStateInterface> debugger_state; |   std::unique_ptr<DebuggerStateInterface> debugger_state; | ||||||
|  | |||||||
| @ -436,10 +436,7 @@ TEST(DirectSessionTest, FetchMultipleTimes) { | |||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| REGISTER_OP("Darth") | REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc( | ||||||
|     .Input("x: float") |  | ||||||
|     .Output("y: float") |  | ||||||
|     .Doc(R"doc( |  | ||||||
| Darth promises one return value. | Darth promises one return value. | ||||||
| 
 | 
 | ||||||
| x: float | x: float | ||||||
| @ -972,39 +969,38 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib, | |||||||
| 
 | 
 | ||||||
|   std::atomic<int32> num_done(0); |   std::atomic<int32> num_done(0); | ||||||
|   // Runs session to compute <node>:0 using inter_op thread pool <pool>.
 |   // Runs session to compute <node>:0 using inter_op thread pool <pool>.
 | ||||||
|   auto add_session_run_call = [use_global_pools, &def, &options, &sessions, |   auto add_session_run_call = | ||||||
|                                &sessions_mu, |       [use_global_pools, &def, &options, &sessions, &sessions_mu, &num_done]( | ||||||
|                                &num_done](thread::ThreadPool* tp, Node* node, |           thread::ThreadPool* tp, Node* node, int inter_op_pool) { | ||||||
|                                           int inter_op_pool) { |         auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu, | ||||||
|     auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu, |                    inter_op_pool, node, &num_done]() { | ||||||
|                inter_op_pool, node, &num_done]() { |           RunOptions run_options; | ||||||
|       RunOptions run_options; |           run_options.set_inter_op_thread_pool(inter_op_pool); | ||||||
|       run_options.set_inter_op_thread_pool(inter_op_pool); |           std::vector<Tensor> outputs; | ||||||
|       std::vector<Tensor> outputs; |  | ||||||
| 
 | 
 | ||||||
|       Session* session; |           Session* session; | ||||||
|       if (use_global_pools) { |           if (use_global_pools) { | ||||||
|         std::unique_ptr<Session> s(NewSession(options)); |             std::unique_ptr<Session> s(NewSession(options)); | ||||||
|         TF_ASSERT_OK(s->Create(def)); |             TF_ASSERT_OK(s->Create(def)); | ||||||
|         session = s.get(); |             session = s.get(); | ||||||
| 
 | 
 | ||||||
|         mutex_lock l(sessions_mu); |             mutex_lock l(sessions_mu); | ||||||
|         sessions.emplace_back(std::move(s)); |             sessions.emplace_back(std::move(s)); | ||||||
|       } else { |           } else { | ||||||
|         session = sessions[0].get(); |             session = sessions[0].get(); | ||||||
|       } |           } | ||||||
| 
 | 
 | ||||||
|       Status s = session->Run(run_options, {} /* inputs */, |           Status s = session->Run(run_options, {} /* inputs */, | ||||||
|                               {node->name() + ":0"} /* output_names */, {}, |                                   {node->name() + ":0"} /* output_names */, {}, | ||||||
|                               &outputs, nullptr /* run_metadata */); |                                   &outputs, nullptr /* run_metadata */); | ||||||
|       TF_CHECK_OK(s); |           TF_CHECK_OK(s); | ||||||
|       ASSERT_EQ(1, outputs.size()); |           ASSERT_EQ(1, outputs.size()); | ||||||
|       auto flat = outputs[0].flat<float>(); |           auto flat = outputs[0].flat<float>(); | ||||||
|       EXPECT_FLOAT_EQ(1.2, flat(0)); |           EXPECT_FLOAT_EQ(1.2, flat(0)); | ||||||
|       num_done.fetch_add(1); |           num_done.fetch_add(1); | ||||||
|     }; |         }; | ||||||
|     tp->Schedule(fn); |         tp->Schedule(fn); | ||||||
|   }; |       }; | ||||||
| 
 | 
 | ||||||
|   // For blocking states:
 |   // For blocking states:
 | ||||||
|   // - Starts at 0, BlockingOp::Compute will move to 1.
 |   // - Starts at 0, BlockingOp::Compute will move to 1.
 | ||||||
|  | |||||||
| @ -161,14 +161,14 @@ static void TestHWAccelerator(bool enableHWTrace) { | |||||||
|   x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); |   x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); | ||||||
| #ifdef TENSORFLOW_USE_SYCL | #ifdef TENSORFLOW_USE_SYCL | ||||||
|   x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); |   x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); | ||||||
| #endif // TENSORFLOW_USE_SYCL
 | #endif  // TENSORFLOW_USE_SYCL
 | ||||||
| 
 | 
 | ||||||
|   // y = A * x
 |   // y = A * x
 | ||||||
|   Node* y = test::graph::Matmul(&graph, a, x, false, false); |   Node* y = test::graph::Matmul(&graph, a, x, false, false); | ||||||
|   y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); |   y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0"); | ||||||
| #ifdef TENSORFLOW_USE_SYCL | #ifdef TENSORFLOW_USE_SYCL | ||||||
| y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); |   y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); | ||||||
| #endif // TENSORFLOW_USE_SYCL
 | #endif  // TENSORFLOW_USE_SYCL
 | ||||||
| 
 | 
 | ||||||
|   Node* y_neg = test::graph::Unary(&graph, "Neg", y); |   Node* y_neg = test::graph::Unary(&graph, "Neg", y); | ||||||
|   y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); |   y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0"); | ||||||
| @ -181,7 +181,7 @@ y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0"); | |||||||
|   (*options.config.mutable_device_count())["GPU"] = 1; |   (*options.config.mutable_device_count())["GPU"] = 1; | ||||||
| #ifdef TENSORFLOW_USE_SYCL | #ifdef TENSORFLOW_USE_SYCL | ||||||
|   (*options.config.mutable_device_count())["SYCL"] = 1; |   (*options.config.mutable_device_count())["SYCL"] = 1; | ||||||
| #endif // TENSORFLOW_USE_SYCL
 | #endif  // TENSORFLOW_USE_SYCL
 | ||||||
|   options.config.set_allow_soft_placement(true); |   options.config.set_allow_soft_placement(true); | ||||||
|   options.config.mutable_graph_options()->set_build_cost_model(1); |   options.config.mutable_graph_options()->set_build_cost_model(1); | ||||||
|   std::unique_ptr<Session> session(NewSession(options)); |   std::unique_ptr<Session> session(NewSession(options)); | ||||||
|  | |||||||
| @ -1609,7 +1609,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { | |||||||
|         auto done = [this, state]() { |         auto done = [this, state]() { | ||||||
|           Device* device = impl_->params_.device; |           Device* device = impl_->params_.device; | ||||||
|           NodeExecStatsWrapper* stats = state->stats;  // Shorthand
 |           NodeExecStatsWrapper* stats = state->stats;  // Shorthand
 | ||||||
|           Entry* first_input = state->first_input;  // Shorthand
 |           Entry* first_input = state->first_input;     // Shorthand
 | ||||||
| 
 | 
 | ||||||
|           nodestats::SetOpEnd(stats); |           nodestats::SetOpEnd(stats); | ||||||
|           EntryVector outputs; |           EntryVector outputs; | ||||||
|  | |||||||
| @ -205,7 +205,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { | |||||||
|   // The instantiated and transformed function is encoded as a Graph
 |   // The instantiated and transformed function is encoded as a Graph
 | ||||||
|   // object, and an executor is created for the graph.
 |   // object, and an executor is created for the graph.
 | ||||||
|   struct Item : public core::RefCounted { |   struct Item : public core::RefCounted { | ||||||
|     const Graph* graph = nullptr;  // Owned by exec.
 |     const Graph* graph = nullptr;                            // Owned by exec.
 | ||||||
|     const FunctionLibraryDefinition* overlay_lib = nullptr;  // Not owned.
 |     const FunctionLibraryDefinition* overlay_lib = nullptr;  // Not owned.
 | ||||||
|     FunctionBody* func_graph = nullptr; |     FunctionBody* func_graph = nullptr; | ||||||
|     Executor* exec = nullptr; |     Executor* exec = nullptr; | ||||||
|  | |||||||
| @ -154,8 +154,9 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) { | |||||||
|     a.DeallocateRaw(t3); |     a.DeallocateRaw(t3); | ||||||
|     a.DeallocateRaw(t4); |     a.DeallocateRaw(t4); | ||||||
|   } |   } | ||||||
|   CheckStats(&a, 4097, 0, 1024 * sizeof(float) + 1048576 * sizeof(int64) + |   CheckStats(&a, 4097, 0, | ||||||
|                               2048 * sizeof(double) + 10485760 * sizeof(float), |              1024 * sizeof(float) + 1048576 * sizeof(int64) + | ||||||
|  |                  2048 * sizeof(double) + 10485760 * sizeof(float), | ||||||
|              10485760 * sizeof(float)); |              10485760 * sizeof(float)); | ||||||
| 
 | 
 | ||||||
|   // At the end, we should have coalesced all memory into one region
 |   // At the end, we should have coalesced all memory into one region
 | ||||||
|  | |||||||
| @ -763,8 +763,9 @@ int64 MinSystemMemory(int64 available_memory) { | |||||||
|   min_system_memory *= 2; |   min_system_memory *= 2; | ||||||
| #endif | #endif | ||||||
| #if defined(NVIDIA_TEGRA) | #if defined(NVIDIA_TEGRA) | ||||||
|   // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM and Video RAM
 |   // 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM
 | ||||||
|   min_system_memory = 1<<30; |   // and Video RAM
 | ||||||
|  |   min_system_memory = 1 << 30; | ||||||
| #endif | #endif | ||||||
|   return min_system_memory; |   return min_system_memory; | ||||||
| } | } | ||||||
|  | |||||||
| @ -108,7 +108,8 @@ TEST_F(GpuStreamUtilTest, StreamOverrides) { | |||||||
|   ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0, |   ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0, | ||||||
|              "/device:GPU:0"); |              "/device:GPU:0"); | ||||||
|   Output n = ops::MatMul(root, {}, {}); |   Output n = ops::MatMul(root, {}, {}); | ||||||
|   ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, "/cpu:0"); |   ops::_Send(root.WithOpName("output"), n, "output", "/device:GPU:0", 0, | ||||||
|  |              "/cpu:0"); | ||||||
|   Graph g(OpRegistry::Global()); |   Graph g(OpRegistry::Global()); | ||||||
|   TF_ASSERT_OK(root.ToGraph(&g)); |   TF_ASSERT_OK(root.ToGraph(&g)); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -88,8 +88,8 @@ ProcessState::~ProcessState() { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| string ProcessState::MemDesc::DebugString() { | string ProcessState::MemDesc::DebugString() { | ||||||
|   return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, ", dma: ", |   return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, | ||||||
|                          gpu_registered, ", nic: ", nic_registered); |                          ", dma: ", gpu_registered, ", nic: ", nic_registered); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ProcessState::MemDesc ProcessState::PtrType(const void* ptr) { | ProcessState::MemDesc ProcessState::PtrType(const void* ptr) { | ||||||
|  | |||||||
| @ -139,9 +139,7 @@ class GraphExecutionState { | |||||||
| 
 | 
 | ||||||
|   // The graph returned by BuildGraph may contain only the pruned
 |   // The graph returned by BuildGraph may contain only the pruned
 | ||||||
|   // graph, whereas some clients may want access to the full graph.
 |   // graph, whereas some clients may want access to the full graph.
 | ||||||
|   const Graph* full_graph() { |   const Graph* full_graph() { return graph_; } | ||||||
|     return graph_; |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   // Returns the node with the given name, or null if it does not exist.
 |   // Returns the node with the given name, or null if it does not exist.
 | ||||||
|   const Node* get_node_by_name(const string& name) const { |   const Node* get_node_by_name(const string& name) const { | ||||||
|  | |||||||
| @ -47,7 +47,7 @@ struct EndpointEq { | |||||||
| static Status ProcessMemoryTypes( | static Status ProcessMemoryTypes( | ||||||
|     const DeviceType& device_type, const Graph* g, |     const DeviceType& device_type, const Graph* g, | ||||||
|     const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) { |     const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) { | ||||||
|   if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL ) { |   if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL) { | ||||||
|     // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always
 |     // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always
 | ||||||
|     // compatible.
 |     // compatible.
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|  | |||||||
| @ -36,7 +36,7 @@ TEST(MemoryTypeChecker, Int32OK) { | |||||||
| #endif  // GOOGLE_CUDA
 | #endif  // GOOGLE_CUDA
 | ||||||
| #ifdef TENSORFLOW_USE_SYCL | #ifdef TENSORFLOW_USE_SYCL | ||||||
|   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g)); |   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g)); | ||||||
| #endif // TENSORFLOW_USE_SYCL
 | #endif  // TENSORFLOW_USE_SYCL
 | ||||||
|   delete g; |   delete g; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -64,7 +64,7 @@ TEST(MemoryTypeChecker, Int32NotOk) { | |||||||
|   // But we can insert _HostSend/_HostRecv to ensure the invariant.
 |   // But we can insert _HostSend/_HostRecv to ensure the invariant.
 | ||||||
|   TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_SYCL, "/device:SYCL:0", g)); |   TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_SYCL, "/device:SYCL:0", g)); | ||||||
|   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g)); |   TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g)); | ||||||
| #endif // TENSORFLOW_USE_SYCL
 | #endif  // TENSORFLOW_USE_SYCL
 | ||||||
|   delete g; |   delete g; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -91,7 +91,7 @@ TEST(MemoryTypeChecker, MemoryTypeForOutput) { | |||||||
|   TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type)); |   TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type)); | ||||||
|   // int Switch's output on GPU has HOST_MEMORY constraint.
 |   // int Switch's output on GPU has HOST_MEMORY constraint.
 | ||||||
|   EXPECT_EQ(memory_type, HOST_MEMORY); |   EXPECT_EQ(memory_type, HOST_MEMORY); | ||||||
| #endif // TENSORFLOW_USE_SYCL
 | #endif  // TENSORFLOW_USE_SYCL
 | ||||||
|   delete g; |   delete g; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -88,9 +88,9 @@ class Placer { | |||||||
|   void AssignAndLog(int assigned_device, Node* node) const; |   void AssignAndLog(int assigned_device, Node* node) const; | ||||||
|   void LogDeviceAssignment(const Node* node) const; |   void LogDeviceAssignment(const Node* node) const; | ||||||
| 
 | 
 | ||||||
|   Graph* const graph_;                           // Not owned.
 |   Graph* const graph_;              // Not owned.
 | ||||||
|   const DeviceSet* const devices_;               // Not owned.
 |   const DeviceSet* const devices_;  // Not owned.
 | ||||||
|   const SessionOptions* options_;                // Not owned.
 |   const SessionOptions* options_;   // Not owned.
 | ||||||
|   const bool log_device_placement_; |   const bool log_device_placement_; | ||||||
| 
 | 
 | ||||||
|   TF_DISALLOW_COPY_AND_ASSIGN(Placer); |   TF_DISALLOW_COPY_AND_ASSIGN(Placer); | ||||||
|  | |||||||
| @ -619,9 +619,9 @@ TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) { | |||||||
|     Node* input = ops::SourceOp( |     Node* input = ops::SourceOp( | ||||||
|         "TestDevice", |         "TestDevice", | ||||||
|         b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0")); |         b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0")); | ||||||
|     Node* var = ops::SourceOp("TestVariable", |     Node* var = | ||||||
|                               b.opts().WithName("var_0").WithDevice( |         ops::SourceOp("TestVariable", b.opts().WithName("var_0").WithDevice( | ||||||
|                                   "/job:a/task:0/device:fakegpu:0")); |                                           "/job:a/task:0/device:fakegpu:0")); | ||||||
| 
 | 
 | ||||||
|     // This op is specified on CPU, but in practice will be ignored,
 |     // This op is specified on CPU, but in practice will be ignored,
 | ||||||
|     // because the reference edges forces it on GPU.
 |     // because the reference edges forces it on GPU.
 | ||||||
|  | |||||||
| @ -60,8 +60,8 @@ const string RegisteredFactoriesErrorMessageLocked() { | |||||||
|                          str_util::Join(factory_types, ", "), "}."); |                          str_util::Join(factory_types, ", "), "}."); | ||||||
| } | } | ||||||
| string SessionOptionsToString(const SessionOptions& options) { | string SessionOptionsToString(const SessionOptions& options) { | ||||||
|   return strings::StrCat("target: \"", options.target, "\" config: ", |   return strings::StrCat("target: \"", options.target, | ||||||
|                          ProtoShortDebugString(options.config)); |                          "\" config: ", ProtoShortDebugString(options.config)); | ||||||
| } | } | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -226,22 +226,23 @@ void StepStatsCollector::BuildCostModel( | |||||||
|       if (node) { |       if (node) { | ||||||
|         for (int i = 0; i < stats.output_size(); ++i) { |         for (int i = 0; i < stats.output_size(); ++i) { | ||||||
|           const auto& output = stats.output(i); |           const auto& output = stats.output(i); | ||||||
|           cm->RecordMaxMemorySize(node, i, Bytes(output.tensor_description() |           cm->RecordMaxMemorySize(node, i, | ||||||
|                                                      .allocation_description() |                                   Bytes(output.tensor_description() | ||||||
|                                                      .allocated_bytes()), |                                             .allocation_description() | ||||||
|  |                                             .allocated_bytes()), | ||||||
|                                   stats.output(i).tensor_description().shape(), |                                   stats.output(i).tensor_description().shape(), | ||||||
|                                   node->output_types()[i]); |                                   node->output_types()[i]); | ||||||
|           cm->RecordAllocationId(node, i, output.tensor_description() |           cm->RecordAllocationId(node, i, | ||||||
|                                               .allocation_description() |                                  output.tensor_description() | ||||||
|                                               .allocation_id()); |                                      .allocation_description() | ||||||
|  |                                      .allocation_id()); | ||||||
|         } |         } | ||||||
|         cm->RecordMemoryStats(node, stats.memory_stats()); |         cm->RecordMemoryStats(node, stats.memory_stats()); | ||||||
|         // Use hardware stats to record the execution time if they're available,
 |         // Use hardware stats to record the execution time if they're available,
 | ||||||
|         // otherwise use the regular (less accurate) stats
 |         // otherwise use the regular (less accurate) stats
 | ||||||
|         string node_name = dev_stats.regular_stats->node_stats(i).node_name(); |         string node_name = dev_stats.regular_stats->node_stats(i).node_name(); | ||||||
|         if (dev_stats.hardware_stats && |         if (dev_stats.hardware_stats && name_to_hw_node_stats.find(node_name) != | ||||||
|             name_to_hw_node_stats.find(node_name) != |                                             name_to_hw_node_stats.end()) { | ||||||
|                 name_to_hw_node_stats.end()) { |  | ||||||
|           const NodeExecStats& hw_stats = name_to_hw_node_stats[node_name]; |           const NodeExecStats& hw_stats = name_to_hw_node_stats[node_name]; | ||||||
|           cm->RecordMaxExecutionTime( |           cm->RecordMaxExecutionTime( | ||||||
|               node, Microseconds(hw_stats.op_end_rel_micros())); |               node, Microseconds(hw_stats.op_end_rel_micros())); | ||||||
|  | |||||||
| @ -80,7 +80,7 @@ void SYCLAllocator::ClearStats() override { | |||||||
| 
 | 
 | ||||||
| size_t SYCLAllocator::RequestedSize(void* ptr) { | size_t SYCLAllocator::RequestedSize(void* ptr) { | ||||||
|   mutex_lock lock(mu_); |   mutex_lock lock(mu_); | ||||||
|   if(!sycl_device_) { |   if (!sycl_device_) { | ||||||
|     return 0; |     return 0; | ||||||
|   } |   } | ||||||
|   const auto& buffer = sycl_device_->get_sycl_buffer(ptr); |   const auto& buffer = sycl_device_->get_sycl_buffer(ptr); | ||||||
|  | |||||||
| @ -20,10 +20,10 @@ limitations under the License. | |||||||
| #ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ | #ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ | ||||||
| #define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ | #define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ | ||||||
| 
 | 
 | ||||||
|  | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" | ||||||
| #include "tensorflow/core/framework/allocator.h" | #include "tensorflow/core/framework/allocator.h" | ||||||
| #include "tensorflow/core/platform/mutex.h" | #include "tensorflow/core/platform/mutex.h" | ||||||
| #include "tensorflow/core/platform/types.h" | #include "tensorflow/core/platform/types.h" | ||||||
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |  | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| 
 | 
 | ||||||
| @ -56,14 +56,13 @@ class SYCLAllocator : public Allocator { | |||||||
|   // Clear the SYCL device used by the Allocator
 |   // Clear the SYCL device used by the Allocator
 | ||||||
|   void ClearSYCLDevice() { |   void ClearSYCLDevice() { | ||||||
|     mutex_lock lock(mu_); |     mutex_lock lock(mu_); | ||||||
|     if(sycl_device_) { |     if (sycl_device_) { | ||||||
|       delete sycl_device_; |       delete sycl_device_; | ||||||
|       sycl_device_ = nullptr; |       sycl_device_ = nullptr; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
| 
 |  | ||||||
|   mutable mutex mu_; |   mutable mutex mu_; | ||||||
|   Eigen::SyclDevice* sycl_device_ GUARDED_BY(mu_);  // owned
 |   Eigen::SyclDevice* sycl_device_ GUARDED_BY(mu_);  // owned
 | ||||||
|   AllocatorStats stats_ GUARDED_BY(mu_); |   AllocatorStats stats_ GUARDED_BY(mu_); | ||||||
|  | |||||||
| @ -187,9 +187,9 @@ class GSYCLInterface { | |||||||
|       type = "Unknown"; |       type = "Unknown"; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     return strings::StrCat("id: ", device_id, ", type: ", type, ", name: ", |     return strings::StrCat( | ||||||
|                            name.c_str(), ", vendor: ", vendor.c_str(), |         "id: ", device_id, ", type: ", type, ", name: ", name.c_str(), | ||||||
|                            ", profile: ", profile.c_str()); |         ", vendor: ", vendor.c_str(), ", profile: ", profile.c_str()); | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -26,7 +26,6 @@ class SYCLDeviceFactory : public DeviceFactory { | |||||||
|  public: |  public: | ||||||
|   Status CreateDevices(const SessionOptions &options, const string &name_prefix, |   Status CreateDevices(const SessionOptions &options, const string &name_prefix, | ||||||
|                        std::vector<Device *> *devices) override { |                        std::vector<Device *> *devices) override { | ||||||
| 
 |  | ||||||
|     auto syclInterface = GSYCLInterface::instance(); |     auto syclInterface = GSYCLInterface::instance(); | ||||||
| 
 | 
 | ||||||
|     size_t n = 1; |     size_t n = 1; | ||||||
| @ -37,13 +36,11 @@ class SYCLDeviceFactory : public DeviceFactory { | |||||||
| 
 | 
 | ||||||
|     for (int i = 0; i < n; i++) { |     for (int i = 0; i < n; i++) { | ||||||
|       string name = strings::StrCat(name_prefix, "/device:SYCL:", i); |       string name = strings::StrCat(name_prefix, "/device:SYCL:", i); | ||||||
|       devices->push_back( |       devices->push_back(new SYCLDevice( | ||||||
|           new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality() |           options, name, Bytes(256 << 20), DeviceLocality(), | ||||||
|                          , syclInterface->GetShortDeviceDescription(i) |           syclInterface->GetShortDeviceDescription(i), | ||||||
|                          , syclInterface->GetSYCLAllocator(i) |           syclInterface->GetSYCLAllocator(i), syclInterface->GetCPUAllocator(i), | ||||||
|                          , syclInterface->GetCPUAllocator(i) |           syclInterface->GetSYCLContext(i))); | ||||||
|                          , syclInterface->GetSYCLContext(i)) |  | ||||||
|                        ); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| @ -51,6 +48,6 @@ class SYCLDeviceFactory : public DeviceFactory { | |||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200); | REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200); | ||||||
| } | }  // namespace tensorflow
 | ||||||
| 
 | 
 | ||||||
| #endif  // TENSORFLOW_USE_SYCL
 | #endif  // TENSORFLOW_USE_SYCL
 | ||||||
|  | |||||||
| @ -20,8 +20,8 @@ limitations under the License. | |||||||
| #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ | ||||||
| #define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ | #define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/common_runtime/device.h" |  | ||||||
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" | ||||||
|  | #include "tensorflow/core/common_runtime/device.h" | ||||||
| // For DMA helper
 | // For DMA helper
 | ||||||
| #include "tensorflow/core/common_runtime/dma_helper.h" | #include "tensorflow/core/common_runtime/dma_helper.h" | ||||||
| #include "tensorflow/core/framework/tensor.h" | #include "tensorflow/core/framework/tensor.h" | ||||||
|  | |||||||
| @ -24,31 +24,31 @@ limitations under the License. | |||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| 
 | 
 | ||||||
| DebugGateway::DebugGateway(DirectSession* session) : session_(session) { | DebugGateway::DebugGateway(DirectSession* session) : session_(session) { | ||||||
|   session_->node_outputs_callback_ = [this]( |   session_->node_outputs_callback_ = | ||||||
|       const string& node_name, const int output_slot, const Tensor* tensor, |       [this](const string& node_name, const int output_slot, | ||||||
|       const bool is_ref, OpKernelContext* ctx) { |              const Tensor* tensor, const bool is_ref, OpKernelContext* ctx) { | ||||||
|     if (comp_cb_ != nullptr && output_slot <= 0) { |         if (comp_cb_ != nullptr && output_slot <= 0) { | ||||||
|       // The node completion callback is invoked once for a node regardless
 |           // The node completion callback is invoked once for a node regardless
 | ||||||
|       // of whether the node has zero, one or more outputs.
 |           // of whether the node has zero, one or more outputs.
 | ||||||
|       // The output_slot can be negative (-1, or kControlSlot) if
 |           // The output_slot can be negative (-1, or kControlSlot) if
 | ||||||
|       // node_outputs_callback_ is invoked for a node with no output. If that
 |           // node_outputs_callback_ is invoked for a node with no output. If
 | ||||||
|       // is the case, notify the callback that the node in question has no
 |           // that is the case, notify the callback that the node in question has
 | ||||||
|       // output.
 |           // no output.
 | ||||||
|       comp_cb_(node_name, output_slot == 0); |           comp_cb_(node_name, output_slot == 0); | ||||||
|     } |         } | ||||||
| 
 | 
 | ||||||
|     // Copy tensor values (e.g., from GPU to host) only if the
 |         // Copy tensor values (e.g., from GPU to host) only if the
 | ||||||
|     // value callback is not nullptr.
 |         // value callback is not nullptr.
 | ||||||
|     if (val_cb_ != nullptr && output_slot >= 0) { |         if (val_cb_ != nullptr && output_slot >= 0) { | ||||||
|       CopyTensor( |           CopyTensor(node_name, output_slot, tensor, ctx, | ||||||
|           node_name, output_slot, tensor, ctx, |                      [this, node_name, output_slot, | ||||||
|           [this, node_name, output_slot, is_ref](const Tensor* copied_tensor) { |                       is_ref](const Tensor* copied_tensor) { | ||||||
|             val_cb_(node_name, output_slot, *copied_tensor, is_ref); |                        val_cb_(node_name, output_slot, *copied_tensor, is_ref); | ||||||
|           }); |                      }); | ||||||
|     } |         } | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |         return Status::OK(); | ||||||
|   }; |       }; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| DebugGateway::~DebugGateway() { | DebugGateway::~DebugGateway() { | ||||||
| @ -86,7 +86,8 @@ void DebugGateway::CopyTensor(const string& node_name, const int output_slot, | |||||||
|     // Determine if the tensor is on device (GPU) or host (CPU).
 |     // Determine if the tensor is on device (GPU) or host (CPU).
 | ||||||
|     // The second part of the check is necessary because even an OpKernel on
 |     // The second part of the check is necessary because even an OpKernel on
 | ||||||
|     // such a device may have output tensors allocated on CPU.
 |     // such a device may have output tensors allocated on CPU.
 | ||||||
|     if ((device->name().find("GPU:") != string::npos || device->name().find("SYCL:") != string::npos) && |     if ((device->name().find("GPU:") != string::npos || | ||||||
|  |          device->name().find("SYCL:") != string::npos) && | ||||||
|         !ctx->output_alloc_attr(output_slot).on_host()) { |         !ctx->output_alloc_attr(output_slot).on_host()) { | ||||||
|       // GPU tensors: Copy it to host (CPU).
 |       // GPU tensors: Copy it to host (CPU).
 | ||||||
|       DeviceContext* device_ctxt = ctx->op_device_context(); |       DeviceContext* device_ctxt = ctx->op_device_context(); | ||||||
|  | |||||||
| @ -390,9 +390,9 @@ TEST_F(SessionDebugMinusAXTest, | |||||||
|   debug_gateway.SetNodeValueCallback( |   debug_gateway.SetNodeValueCallback( | ||||||
|       [this, &mu, &val_callback_count, &a_debug_identity_node_name, |       [this, &mu, &val_callback_count, &a_debug_identity_node_name, | ||||||
|        &x_debug_identity_node_name, &y_debug_identity_node_name, |        &x_debug_identity_node_name, &y_debug_identity_node_name, | ||||||
|        &debug_identity_tensor_vals, &callbacks_done, &kConcurrentRuns]( |        &debug_identity_tensor_vals, &callbacks_done, | ||||||
|            const string& node_name, const int output_slot, |        &kConcurrentRuns](const string& node_name, const int output_slot, | ||||||
|            const Tensor& tensor_value, const bool is_ref) { |                          const Tensor& tensor_value, const bool is_ref) { | ||||||
|         mutex_lock l(mu); |         mutex_lock l(mu); | ||||||
| 
 | 
 | ||||||
|         if (node_name == a_debug_identity_node_name && output_slot == 0) { |         if (node_name == a_debug_identity_node_name && output_slot == 0) { | ||||||
| @ -560,21 +560,21 @@ TEST_F(SessionDebugOutputSlotWithoutOutgoingEdgeTest, | |||||||
|   Notification callbacks_done; |   Notification callbacks_done; | ||||||
| 
 | 
 | ||||||
|   std::vector<Tensor> debug_identity_tensor_vals; |   std::vector<Tensor> debug_identity_tensor_vals; | ||||||
|   debug_gateway.SetNodeValueCallback([this, &mu, &callbacks_done, |   debug_gateway.SetNodeValueCallback( | ||||||
|                                       &debug_identity_node_name, |       [this, &mu, &callbacks_done, &debug_identity_node_name, | ||||||
|                                       &debug_identity_tensor_vals]( |        &debug_identity_tensor_vals]( | ||||||
|       const string& node_name, const int output_slot, |           const string& node_name, const int output_slot, | ||||||
|       const Tensor& tensor_value, const bool is_ref) { |           const Tensor& tensor_value, const bool is_ref) { | ||||||
|     mutex_lock l(mu); |         mutex_lock l(mu); | ||||||
| 
 | 
 | ||||||
|     if (node_name == debug_identity_node_name && output_slot == 0) { |         if (node_name == debug_identity_node_name && output_slot == 0) { | ||||||
|       debug_identity_tensor_vals.push_back(tensor_value); |           debug_identity_tensor_vals.push_back(tensor_value); | ||||||
| 
 | 
 | ||||||
|       if (!callbacks_done.HasBeenNotified()) { |           if (!callbacks_done.HasBeenNotified()) { | ||||||
|         callbacks_done.Notify(); |             callbacks_done.Notify(); | ||||||
|       } |           } | ||||||
|     } |         } | ||||||
|   }); |       }); | ||||||
| 
 | 
 | ||||||
|   // Add DebugIdentity watch on c:0, which does not have an outgoing edge.
 |   // Add DebugIdentity watch on c:0, which does not have an outgoing edge.
 | ||||||
|   RunOptions run_opts; |   RunOptions run_opts; | ||||||
|  | |||||||
| @ -30,7 +30,7 @@ namespace test { | |||||||
| 
 | 
 | ||||||
| ::grpc::Status TestEventListenerImpl::SendEvents( | ::grpc::Status TestEventListenerImpl::SendEvents( | ||||||
|     ::grpc::ServerContext* context, |     ::grpc::ServerContext* context, | ||||||
|     ::grpc::ServerReaderWriter< ::tensorflow::EventReply, ::tensorflow::Event>* |     ::grpc::ServerReaderWriter<::tensorflow::EventReply, ::tensorflow::Event>* | ||||||
|         stream) { |         stream) { | ||||||
|   Event event; |   Event event; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -57,7 +57,8 @@ class DebugIOUtilsTest : public ::testing::Test { | |||||||
| TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) { | TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) { | ||||||
|   DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/device:GPU:2", |   DebugNodeKey debug_node_key("/job:worker/replica:1/task:0/device:GPU:2", | ||||||
|                               "hidden_1/MatMul", 0, "DebugIdentity"); |                               "hidden_1/MatMul", 0, "DebugIdentity"); | ||||||
|   EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", debug_node_key.device_name); |   EXPECT_EQ("/job:worker/replica:1/task:0/device:GPU:2", | ||||||
|  |             debug_node_key.device_name); | ||||||
|   EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name); |   EXPECT_EQ("hidden_1/MatMul", debug_node_key.node_name); | ||||||
|   EXPECT_EQ(0, debug_node_key.output_slot); |   EXPECT_EQ(0, debug_node_key.output_slot); | ||||||
|   EXPECT_EQ("DebugIdentity", debug_node_key.debug_op); |   EXPECT_EQ("DebugIdentity", debug_node_key.debug_op); | ||||||
|  | |||||||
| @ -140,7 +140,7 @@ class GraphMgr { | |||||||
|     GraphMgr* graph_mgr; |     GraphMgr* graph_mgr; | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
|   const WorkerEnv* worker_env_;             // Not owned.
 |   const WorkerEnv* worker_env_;  // Not owned.
 | ||||||
|   DeviceMgr* device_mgr_; |   DeviceMgr* device_mgr_; | ||||||
| 
 | 
 | ||||||
|   CostModelManager cost_model_manager_; |   CostModelManager cost_model_manager_; | ||||||
|  | |||||||
| @ -528,8 +528,8 @@ void Master::ListDevices(const ListDevicesRequest* req, | |||||||
|       auto session = FindMasterSession(req->session_handle()); |       auto session = FindMasterSession(req->session_handle()); | ||||||
|       if (session == nullptr) { |       if (session == nullptr) { | ||||||
|         done(errors::InvalidArgument( |         done(errors::InvalidArgument( | ||||||
|              "Session ", req->session_handle(), |             "Session ", req->session_handle(), | ||||||
|              " is not found. Possibly, this master has restarted.")); |             " is not found. Possibly, this master has restarted.")); | ||||||
|         return; |         return; | ||||||
|       } |       } | ||||||
|       core::ScopedUnref ref(session); |       core::ScopedUnref ref(session); | ||||||
|  | |||||||
| @ -61,7 +61,7 @@ class MasterTest : public ::testing::Test { | |||||||
|   // rpc calls.
 |   // rpc calls.
 | ||||||
| 
 | 
 | ||||||
|   Status CreateSession(const GraphDef& def, string* handle, |   Status CreateSession(const GraphDef& def, string* handle, | ||||||
|                             int64* initial_version) { |                        int64* initial_version) { | ||||||
|     ::grpc::ClientContext ctx; |     ::grpc::ClientContext ctx; | ||||||
|     CreateSessionRequest req; |     CreateSessionRequest req; | ||||||
|     *(req.mutable_graph_def()) = def; |     *(req.mutable_graph_def()) = def; | ||||||
| @ -77,7 +77,7 @@ class MasterTest : public ::testing::Test { | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   Status ExtendSession(const string& handle, const GraphDef& def, |   Status ExtendSession(const string& handle, const GraphDef& def, | ||||||
|                             int64 current_version, int64* new_version) { |                        int64 current_version, int64* new_version) { | ||||||
|     ::grpc::ClientContext ctx; |     ::grpc::ClientContext ctx; | ||||||
|     ExtendSessionRequest req; |     ExtendSessionRequest req; | ||||||
|     req.set_session_handle(handle); |     req.set_session_handle(handle); | ||||||
|  | |||||||
| @ -185,23 +185,22 @@ class GrpcMasterService : public AsyncServiceInterface { | |||||||
|     MutableRunStepResponseWrapper* wrapped_response = |     MutableRunStepResponseWrapper* wrapped_response = | ||||||
|         new NonOwnedProtoRunStepResponse(&call->response); |         new NonOwnedProtoRunStepResponse(&call->response); | ||||||
|     call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); |     call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); | ||||||
|     master_impl_->RunStep(call_opts, wrapped_request, wrapped_response, |     master_impl_->RunStep( | ||||||
|                           [call, call_opts, wrapped_request, wrapped_response, |         call_opts, wrapped_request, wrapped_response, | ||||||
|                            trace](const Status& status) { |         [call, call_opts, wrapped_request, wrapped_response, | ||||||
|                             call->ClearCancelCallback(); |          trace](const Status& status) { | ||||||
|                             delete call_opts; |           call->ClearCancelCallback(); | ||||||
|                             delete wrapped_request; |           delete call_opts; | ||||||
|                             delete trace; |           delete wrapped_request; | ||||||
|                             if (call->request.store_errors_in_response_body() && |           delete trace; | ||||||
|                                 !status.ok()) { |           if (call->request.store_errors_in_response_body() && !status.ok()) { | ||||||
|                               call->response.set_status_code(status.code()); |             call->response.set_status_code(status.code()); | ||||||
|                               call->response.set_status_error_message( |             call->response.set_status_error_message(status.error_message()); | ||||||
|                                   status.error_message()); |             call->SendResponse(ToGrpcStatus(Status::OK())); | ||||||
|                               call->SendResponse(ToGrpcStatus(Status::OK())); |           } else { | ||||||
|                             } else { |             call->SendResponse(ToGrpcStatus(status)); | ||||||
|                               call->SendResponse(ToGrpcStatus(status)); |           } | ||||||
|                             } |         }); | ||||||
|                           }); |  | ||||||
|     ENQUEUE_REQUEST(RunStep, true); |     ENQUEUE_REQUEST(RunStep, true); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -89,9 +89,9 @@ class MasterService final { | |||||||
|     ::grpc::Status ExtendSession(::grpc::ClientContext* context, |     ::grpc::Status ExtendSession(::grpc::ClientContext* context, | ||||||
|                                  const ExtendSessionRequest& request, |                                  const ExtendSessionRequest& request, | ||||||
|                                  ExtendSessionResponse* response) override; |                                  ExtendSessionResponse* response) override; | ||||||
|     ::grpc::Status PartialRunSetup( |     ::grpc::Status PartialRunSetup(::grpc::ClientContext* context, | ||||||
|         ::grpc::ClientContext* context, const PartialRunSetupRequest& request, |                                    const PartialRunSetupRequest& request, | ||||||
|         PartialRunSetupResponse* response) override; |                                    PartialRunSetupResponse* response) override; | ||||||
|     ::grpc::Status RunStep(::grpc::ClientContext* context, |     ::grpc::Status RunStep(::grpc::ClientContext* context, | ||||||
|                            const RunStepRequest& request, |                            const RunStepRequest& request, | ||||||
|                            RunStepResponse* response) override; |                            RunStepResponse* response) override; | ||||||
|  | |||||||
| @ -69,8 +69,7 @@ class GrpcRemoteMaster : public MasterInterface { | |||||||
|     ::grpc::ClientContext ctx; |     ::grpc::ClientContext ctx; | ||||||
|     auto trace = TraceRpc("RunStep/Client", &ctx); |     auto trace = TraceRpc("RunStep/Client", &ctx); | ||||||
|     return Call(&ctx, call_options, &request->ToProto(), |     return Call(&ctx, call_options, &request->ToProto(), | ||||||
|                 get_proto_from_wrapper(response), |                 get_proto_from_wrapper(response), &MasterServiceStub::RunStep); | ||||||
|                 &MasterServiceStub::RunStep); |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   Status CloseSession(CallOptions* call_options, |   Status CloseSession(CallOptions* call_options, | ||||||
| @ -114,8 +113,9 @@ class GrpcRemoteMaster : public MasterInterface { | |||||||
|   template <typename Request, typename Response> |   template <typename Request, typename Response> | ||||||
|   Status Call(::grpc::ClientContext* ctx, CallOptions* call_options, |   Status Call(::grpc::ClientContext* ctx, CallOptions* call_options, | ||||||
|               const Request* request, Response* response, |               const Request* request, Response* response, | ||||||
|               ::grpc::Status (MasterServiceStub::*pfunc)( |               ::grpc::Status (MasterServiceStub::*pfunc)(::grpc::ClientContext*, | ||||||
|                   ::grpc::ClientContext*, const Request&, Response*)) { |                                                          const Request&, | ||||||
|  |                                                          Response*)) { | ||||||
|     ctx->set_fail_fast(false); |     ctx->set_fail_fast(false); | ||||||
|     SetDeadline(ctx, call_options->GetTimeout()); |     SetDeadline(ctx, call_options->GetTimeout()); | ||||||
|     return FromGrpcStatus((stub_.get()->*pfunc)(ctx, *request, response)); |     return FromGrpcStatus((stub_.get()->*pfunc)(ctx, *request, response)); | ||||||
|  | |||||||
| @ -21,11 +21,8 @@ namespace tensorflow { | |||||||
| namespace test { | namespace test { | ||||||
| 
 | 
 | ||||||
| // ErrorOp::Compute returns an error.
 | // ErrorOp::Compute returns an error.
 | ||||||
| REGISTER_OP("Error") | REGISTER_OP("Error").Input("in: T").Output("out: T").Attr("T: type").Attr( | ||||||
|     .Input("in: T") |     "message: string"); | ||||||
|     .Output("out: T") |  | ||||||
|     .Attr("T: type") |  | ||||||
|     .Attr("message: string"); |  | ||||||
| class ErrorOp : public OpKernel { | class ErrorOp : public OpKernel { | ||||||
|  public: |  public: | ||||||
|   explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |   explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { | ||||||
| @ -66,11 +63,8 @@ REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU), | |||||||
| 
 | 
 | ||||||
| // DelayOp::AsyncCompute sleeps for "micros" microseconds and then returns
 | // DelayOp::AsyncCompute sleeps for "micros" microseconds and then returns
 | ||||||
| // its input.
 | // its input.
 | ||||||
| REGISTER_OP("Delay") | REGISTER_OP("Delay").Input("in: T").Output("out: T").Attr("T: type").Attr( | ||||||
|     .Input("in: T") |     "micros: int"); | ||||||
|     .Output("out: T") |  | ||||||
|     .Attr("T: type") |  | ||||||
|     .Attr("micros: int"); |  | ||||||
| class DelayOp : public AsyncOpKernel { | class DelayOp : public AsyncOpKernel { | ||||||
|  public: |  public: | ||||||
|   explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { |   explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { | ||||||
|  | |||||||
| @ -184,8 +184,8 @@ static void BM_Helper(int iters, int width, int num_stages, int tensor_size, | |||||||
| 
 | 
 | ||||||
|   testing::SetLabel( |   testing::SetLabel( | ||||||
|       strings::StrCat(def.node_size(), " nodes; ", |       strings::StrCat(def.node_size(), " nodes; ", | ||||||
|              use_multiple_devices ? "Multi device" : "Single device", |                       use_multiple_devices ? "Multi device" : "Single device", | ||||||
|              "; tensor bytes/send: ", tensor_size * sizeof(float))); |                       "; tensor bytes/send: ", tensor_size * sizeof(float))); | ||||||
| 
 | 
 | ||||||
|   std::vector<Tensor> outputs; |   std::vector<Tensor> outputs; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -17,9 +17,9 @@ limitations under the License. | |||||||
| 
 | 
 | ||||||
| #include <queue> | #include <queue> | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/graph/graph.h" |  | ||||||
| #include "tensorflow/core/common_runtime/device.h" | #include "tensorflow/core/common_runtime/device.h" | ||||||
| #include "tensorflow/core/common_runtime/device_set.h" | #include "tensorflow/core/common_runtime/device_set.h" | ||||||
|  | #include "tensorflow/core/graph/graph.h" | ||||||
| #include "tensorflow/core/util/util.h" | #include "tensorflow/core/util/util.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
|  | |||||||
| @ -16,15 +16,15 @@ limitations under the License. | |||||||
| #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ | ||||||
| #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ | ||||||
| 
 | 
 | ||||||
| #include <functional> |  | ||||||
| #include <deque> | #include <deque> | ||||||
|  | #include <functional> | ||||||
| #include <map> | #include <map> | ||||||
| #include <unordered_map> | #include <unordered_map> | ||||||
| #include <vector> | #include <vector> | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/graph/costmodel.h" |  | ||||||
| #include "tensorflow/core/common_runtime/device.h" | #include "tensorflow/core/common_runtime/device.h" | ||||||
| #include "tensorflow/core/common_runtime/device_set.h" | #include "tensorflow/core/common_runtime/device_set.h" | ||||||
|  | #include "tensorflow/core/graph/costmodel.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -97,9 +97,8 @@ void WorkerCacheLogger::RecordDataTransfer(int64 step_id, int64 start_usecs, | |||||||
|                                            const string& tensor_name, |                                            const string& tensor_name, | ||||||
|                                            const string& src_device, |                                            const string& src_device, | ||||||
|                                            const string& dst_device, |                                            const string& dst_device, | ||||||
|                                            int64 bytes, |                                            int64 bytes, const string& details, | ||||||
|                                            const string& details, |                                            const string& transfer_method_name) { | ||||||
|                                            const string& transfer_method_name){ |  | ||||||
|   NodeExecStats* ns = new NodeExecStats; |   NodeExecStats* ns = new NodeExecStats; | ||||||
|   ns->set_node_name(transfer_method_name); |   ns->set_node_name(transfer_method_name); | ||||||
|   if (details.empty()) { |   if (details.empty()) { | ||||||
|  | |||||||
| @ -158,8 +158,8 @@ void CostModel::SetNumOutputs(const Node* node, int num_outputs) { | |||||||
|   Ensure(id, 0); |   Ensure(id, 0); | ||||||
|   auto perslot = &slot_bytes_[id]; |   auto perslot = &slot_bytes_[id]; | ||||||
|   if (!perslot->empty()) { |   if (!perslot->empty()) { | ||||||
|     CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node=" |     CHECK_EQ(num_outputs, perslot->size()) | ||||||
|                                            << node->name(); |         << "Cannot resize slot_bytes, node=" << node->name(); | ||||||
|   } |   } | ||||||
|   Ensure(id, num_outputs); |   Ensure(id, num_outputs); | ||||||
| } | } | ||||||
|  | |||||||
| @ -198,7 +198,7 @@ class CostModel { | |||||||
|   // Cumulative execution time.
 |   // Cumulative execution time.
 | ||||||
|   std::vector<Microseconds> time_; |   std::vector<Microseconds> time_; | ||||||
|   // Cumulative Bytes output on each channel.
 |   // Cumulative Bytes output on each channel.
 | ||||||
|   std::vector<gtl::InlinedVector<Bytes, 2> > slot_bytes_; |   std::vector<gtl::InlinedVector<Bytes, 2>> slot_bytes_; | ||||||
| 
 | 
 | ||||||
|   // Maximum execution time
 |   // Maximum execution time
 | ||||||
|   std::vector<Microseconds> max_exec_time_; |   std::vector<Microseconds> max_exec_time_; | ||||||
| @ -217,7 +217,7 @@ class CostModel { | |||||||
|   }; |   }; | ||||||
|   std::vector<MemUsage> max_mem_usage_; |   std::vector<MemUsage> max_mem_usage_; | ||||||
| 
 | 
 | ||||||
|   std::vector<gtl::InlinedVector<int64, 2> > output_port_alloc_ids_; |   std::vector<gtl::InlinedVector<int64, 2>> output_port_alloc_ids_; | ||||||
| 
 | 
 | ||||||
|   std::set<int64> persistent_alloc_ids_; |   std::set<int64> persistent_alloc_ids_; | ||||||
|   std::map<string, std::set<int64>> persistent_alloc_ids_by_devices_; |   std::map<string, std::set<int64>> persistent_alloc_ids_by_devices_; | ||||||
|  | |||||||
| @ -62,8 +62,8 @@ class Node; | |||||||
| class VersionDef; | class VersionDef; | ||||||
| class WhileContext; | class WhileContext; | ||||||
| 
 | 
 | ||||||
| class NeighborIter;  // Declared below
 | class NeighborIter;    // Declared below
 | ||||||
| class NodeIter;      // Declared below
 | class NodeIter;        // Declared below
 | ||||||
| class NodeProperties;  // Defined in .cc
 | class NodeProperties;  // Defined in .cc
 | ||||||
| 
 | 
 | ||||||
| class Node { | class Node { | ||||||
|  | |||||||
| @ -26,7 +26,6 @@ namespace tensorflow { | |||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| TEST(GraphDefBuilderTest, Version) { | TEST(GraphDefBuilderTest, Version) { | ||||||
| 
 |  | ||||||
|   // Verify that our assertions will be nontrivial
 |   // Verify that our assertions will be nontrivial
 | ||||||
|   ASSERT_LT(0, TF_GRAPH_DEF_VERSION); |   ASSERT_LT(0, TF_GRAPH_DEF_VERSION); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -21,102 +21,101 @@ limitations under the License. | |||||||
| #include "tensorflow/core/framework/op_kernel.h" | #include "tensorflow/core/framework/op_kernel.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
|   // Since our ops are going to produce and also consume N additional tensors
 | // Since our ops are going to produce and also consume N additional tensors
 | ||||||
|   // (Mkl) for N Tensorflow tensors, we can have following different
 | // (Mkl) for N Tensorflow tensors, we can have following different
 | ||||||
|   // orderings among these 2N tensors.
 | // orderings among these 2N tensors.
 | ||||||
|   //
 | //
 | ||||||
|   // E.g., for Tensorflow tensors A, B, and C, our ops will produce and
 | // E.g., for Tensorflow tensors A, B, and C, our ops will produce and
 | ||||||
|   // consume A_m, B_m, and C_m additionally.
 | // consume A_m, B_m, and C_m additionally.
 | ||||||
|   //
 | //
 | ||||||
|   // INTERLEAVED: in this case 2N tensors are interleaved. So for above
 | // INTERLEAVED: in this case 2N tensors are interleaved. So for above
 | ||||||
|   //              example, the ordering looks like: A, A_m, B, B_m, C, C_m.
 | //              example, the ordering looks like: A, A_m, B, B_m, C, C_m.
 | ||||||
|   //
 | //
 | ||||||
|   // CONTIGUOUS: in this case N Tensorflow tensors are contiguous followed
 | // CONTIGUOUS: in this case N Tensorflow tensors are contiguous followed
 | ||||||
|   //             by N Mkl tensors. So for above example, the ordering looks
 | //             by N Mkl tensors. So for above example, the ordering looks
 | ||||||
|   //             like: A, B, C, A_m, B_m, C_m
 | //             like: A, B, C, A_m, B_m, C_m
 | ||||||
|   //
 | //
 | ||||||
|   // Following APIs map index of original Tensorflow tensors to their
 | // Following APIs map index of original Tensorflow tensors to their
 | ||||||
|   // appropriate position based on selected ordering. For contiguous ordering,
 | // appropriate position based on selected ordering. For contiguous ordering,
 | ||||||
|   // we need to know the total number of tensors (parameter total).
 | // we need to know the total number of tensors (parameter total).
 | ||||||
|   //
 | //
 | ||||||
|   typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; | typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; | ||||||
|   // NOTE: Currently, we use contiguous ordering. If you change this, then you
 | // NOTE: Currently, we use contiguous ordering. If you change this, then you
 | ||||||
|   // would need to change Mkl op definitions in nn_ops.cc.
 | // would need to change Mkl op definitions in nn_ops.cc.
 | ||||||
|   static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; | static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; | ||||||
| 
 | 
 | ||||||
|   // Get index of MetaData tensor from index 'n' of Data tensor.
 | // Get index of MetaData tensor from index 'n' of Data tensor.
 | ||||||
|   inline int DataIndexToMetaDataIndex(int n, int total_tensors) { | inline int DataIndexToMetaDataIndex(int n, int total_tensors) { | ||||||
|     if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { |   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { | ||||||
|       // For interleaved ordering, Mkl tensor follows immediately after
 |     // For interleaved ordering, Mkl tensor follows immediately after
 | ||||||
|       // Tensorflow tensor.
 |     // Tensorflow tensor.
 | ||||||
|       return n + 1; |     return n + 1; | ||||||
|     } else { |   } else { | ||||||
|       CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); |     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); | ||||||
|       // For contiguous ordering, Mkl tensor is total_tensors / 2 away.
 |     // For contiguous ordering, Mkl tensor is total_tensors / 2 away.
 | ||||||
|       return n + total_tensors / 2; |     return n + total_tensors / 2; | ||||||
|     } |  | ||||||
|   } |   } | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
|   int inline GetTensorDataIndex(int n, int total_tensors) { | int inline GetTensorDataIndex(int n, int total_tensors) { | ||||||
|       if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { |   if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { | ||||||
|         return 2 * n;  // index corresponding to nth input/output tensor
 |     return 2 * n;  // index corresponding to nth input/output tensor
 | ||||||
|       } else { |   } else { | ||||||
|         CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); |     CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); | ||||||
|         return n; |     return n; | ||||||
|       } |   } | ||||||
|     } | } | ||||||
| 
 | 
 | ||||||
|   int inline GetTensorMetaDataIndex(int n, int total_tensors) { | int inline GetTensorMetaDataIndex(int n, int total_tensors) { | ||||||
|       // Get index for TensorData first and then use mapping function
 |   // Get index for TensorData first and then use mapping function
 | ||||||
|       // to get TensorMetaData index from TensorData index.
 |   // to get TensorMetaData index from TensorData index.
 | ||||||
|       int tidx = GetTensorDataIndex(n, total_tensors); |   int tidx = GetTensorDataIndex(n, total_tensors); | ||||||
|       return DataIndexToMetaDataIndex(tidx, total_tensors); |   return DataIndexToMetaDataIndex(tidx, total_tensors); | ||||||
|     } | } | ||||||
| 
 | 
 | ||||||
| namespace mkl_op_registry { | namespace mkl_op_registry { | ||||||
|   static const char* kMklOpLabel = "MklOp"; | static const char* kMklOpLabel = "MklOp"; | ||||||
|   static const char* kMklOpLabelPattern = "label='MklOp'"; | static const char* kMklOpLabelPattern = "label='MklOp'"; | ||||||
|   // Prefix that we add to Tensorflow op name to construct Mkl op name.
 | // Prefix that we add to Tensorflow op name to construct Mkl op name.
 | ||||||
|   static const char* const kMklOpPrefix = "_Mkl"; | static const char* const kMklOpPrefix = "_Mkl"; | ||||||
| 
 | 
 | ||||||
|   // Get the name of Mkl op from original TensorFlow op
 | // Get the name of Mkl op from original TensorFlow op
 | ||||||
|   // We prefix '_Mkl' to the original op to get Mkl op.
 | // We prefix '_Mkl' to the original op to get Mkl op.
 | ||||||
|   inline string GetMklOpName(const string& name) { | inline string GetMklOpName(const string& name) { | ||||||
|     return string(kMklOpPrefix) + name; |   return string(kMklOpPrefix) + name; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Check whether opname with type T is registered as MKL-compliant.
 | ||||||
|  | //
 | ||||||
|  | // @input: name of the op
 | ||||||
|  | // @input: T datatype to be used for checking op
 | ||||||
|  | // @return: true if opname is registered as Mkl op; false otherwise
 | ||||||
|  | static inline bool IsMklOp(const std::string& op_name, DataType T) { | ||||||
|  |   string kernel = KernelsRegisteredForOp(op_name); | ||||||
|  |   bool result = | ||||||
|  |       kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); | ||||||
|  |   return result; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Check whether opname with type T is registered as MKL-compliant and
 | ||||||
|  | // is element-wise.
 | ||||||
|  | //
 | ||||||
|  | // @input: name of the op
 | ||||||
|  | // @input: T datatype to be used for checking op
 | ||||||
|  | // @return: true if opname is registered as element-wise Mkl op;
 | ||||||
|  | // false otherwise
 | ||||||
|  | static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) { | ||||||
|  |   if (!IsMklOp(op_name, T)) { | ||||||
|  |     return false; | ||||||
|   } |   } | ||||||
|  |   bool result = (0 == op_name.compare(GetMklOpName("Add")) || | ||||||
|  |                  0 == op_name.compare(GetMklOpName("Sub")) || | ||||||
|  |                  0 == op_name.compare(GetMklOpName("Mul")) || | ||||||
|  |                  0 == op_name.compare(GetMklOpName("Maximum")) || | ||||||
|  |                  0 == op_name.compare(GetMklOpName("SquaredDifference"))); | ||||||
| 
 | 
 | ||||||
|   // Check whether opname with type T is registered as MKL-compliant.
 |   return result; | ||||||
|   //
 | } | ||||||
|   // @input: name of the op
 |  | ||||||
|   // @input: T datatype to be used for checking op
 |  | ||||||
|   // @return: true if opname is registered as Mkl op; false otherwise
 |  | ||||||
|   static inline bool IsMklOp(const std::string& op_name, DataType T) { |  | ||||||
|     string kernel = KernelsRegisteredForOp(op_name); |  | ||||||
|     bool result = |  | ||||||
|         kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); |  | ||||||
|     return result; |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   // Check whether opname with type T is registered as MKL-compliant and
 |  | ||||||
|   // is element-wise.
 |  | ||||||
|   //
 |  | ||||||
|   // @input: name of the op
 |  | ||||||
|   // @input: T datatype to be used for checking op
 |  | ||||||
|   // @return: true if opname is registered as element-wise Mkl op;
 |  | ||||||
|   // false otherwise
 |  | ||||||
|   static inline bool IsMklElementWiseOp(const std::string& op_name, |  | ||||||
|     DataType T) { |  | ||||||
|     if (!IsMklOp(op_name, T)) { |  | ||||||
|       return false; |  | ||||||
|     } |  | ||||||
|     bool result = (0 == op_name.compare(GetMklOpName("Add")) || |  | ||||||
|                     0 == op_name.compare(GetMklOpName("Sub")) || |  | ||||||
|                     0 == op_name.compare(GetMklOpName("Mul")) || |  | ||||||
|                     0 == op_name.compare(GetMklOpName("Maximum")) || |  | ||||||
|                     0 == op_name.compare(GetMklOpName("SquaredDifference"))); |  | ||||||
| 
 |  | ||||||
|     return result; |  | ||||||
|   } |  | ||||||
| }  // namespace mkl_op_registry
 | }  // namespace mkl_op_registry
 | ||||||
| }  // namespace tensorflow
 | }  // namespace tensorflow
 | ||||||
| #endif  // INTEL_MKL
 | #endif  // INTEL_MKL
 | ||||||
|  | |||||||
| @ -37,8 +37,8 @@ limitations under the License. | |||||||
| #include "tensorflow/core/platform/logging.h" | #include "tensorflow/core/platform/logging.h" | ||||||
| #include "tensorflow/core/util/tensor_format.h" | #include "tensorflow/core/util/tensor_format.h" | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/graph/mkl_layout_pass.h" |  | ||||||
| #include "tensorflow/core/graph/mkl_graph_util.h" | #include "tensorflow/core/graph/mkl_graph_util.h" | ||||||
|  | #include "tensorflow/core/graph/mkl_layout_pass.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| 
 | 
 | ||||||
| @ -281,7 +281,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; |     csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; | ||||||
|     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; |     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; | ||||||
|     csinfo_.mkl_conv2d_with_bias_backprop_bias = |     csinfo_.mkl_conv2d_with_bias_backprop_bias = | ||||||
|                                    "_MklConv2DWithBiasBackpropBias"; |         "_MklConv2DWithBiasBackpropBias"; | ||||||
|     csinfo_.relu = "Relu"; |     csinfo_.relu = "Relu"; | ||||||
|     csinfo_.relu_grad = "ReluGrad"; |     csinfo_.relu_grad = "ReluGrad"; | ||||||
|     csinfo_.reshape = "Reshape"; |     csinfo_.reshape = "Reshape"; | ||||||
| @ -297,10 +297,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     // End - element-wise ops. See note above.
 |     // End - element-wise ops. See note above.
 | ||||||
| 
 | 
 | ||||||
|     // NOTE: names are alphabetically sorted.
 |     // NOTE: names are alphabetically sorted.
 | ||||||
|     rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN, |     rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), | ||||||
|                       AddNRewrite, nullptr}); |                       CopyAttrsAddN, AddNRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.add, |     rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.add), |  | ||||||
|                       CopyAttrsDataType, AlwaysRewrite, nullptr}); |                       CopyAttrsDataType, AlwaysRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.avg_pool, |     rinfo_.push_back({csinfo_.avg_pool, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool), |                       mkl_op_registry::GetMklOpName(csinfo_.avg_pool), | ||||||
| @ -337,14 +336,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     rinfo_.push_back({csinfo_.fused_batch_norm, |     rinfo_.push_back({csinfo_.fused_batch_norm, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), |                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), | ||||||
|                       CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); |                       CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.fused_batch_norm_grad, |     rinfo_.push_back( | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), |         {csinfo_.fused_batch_norm_grad, | ||||||
|                       CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); |          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), | ||||||
|  |          CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.identity, |     rinfo_.push_back({csinfo_.identity, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.identity), |                       mkl_op_registry::GetMklOpName(csinfo_.identity), | ||||||
|                       CopyAttrsIdentity, AlwaysRewrite, nullptr}); |                       CopyAttrsIdentity, AlwaysRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.lrn, |     rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.lrn), |  | ||||||
|                       CopyAttrsLRN, AlwaysRewrite, nullptr}); |                       CopyAttrsLRN, AlwaysRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.lrn_grad, |     rinfo_.push_back({csinfo_.lrn_grad, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), |                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), | ||||||
| @ -358,11 +357,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     rinfo_.push_back({csinfo_.maximum, |     rinfo_.push_back({csinfo_.maximum, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.maximum), |                       mkl_op_registry::GetMklOpName(csinfo_.maximum), | ||||||
|                       CopyAttrsDataType, AlwaysRewrite, nullptr}); |                       CopyAttrsDataType, AlwaysRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.mul, |     rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.mul), |  | ||||||
|                       CopyAttrsDataType, AlwaysRewrite, nullptr}); |                       CopyAttrsDataType, AlwaysRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.relu, |     rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.relu), |  | ||||||
|                       CopyAttrsDataType, AlwaysRewrite, nullptr}); |                       CopyAttrsDataType, AlwaysRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.relu_grad, |     rinfo_.push_back({csinfo_.relu_grad, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad), |                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad), | ||||||
| @ -373,8 +370,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     rinfo_.push_back({csinfo_.squared_difference, |     rinfo_.push_back({csinfo_.squared_difference, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.squared_difference), |                       mkl_op_registry::GetMklOpName(csinfo_.squared_difference), | ||||||
|                       CopyAttrsDataType, AlwaysRewrite, nullptr}); |                       CopyAttrsDataType, AlwaysRewrite, nullptr}); | ||||||
|     rinfo_.push_back({csinfo_.sub, |     rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.sub), |  | ||||||
|                       CopyAttrsDataType, AlwaysRewrite, nullptr}); |                       CopyAttrsDataType, AlwaysRewrite, nullptr}); | ||||||
| 
 | 
 | ||||||
|     // Add info about which ops to add workspace edge to and the slots.
 |     // Add info about which ops to add workspace edge to and the slots.
 | ||||||
| @ -388,9 +384,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, |     biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, | ||||||
|                                    IsBiasAddGradInMatMulContext}; |                                    IsBiasAddGradInMatMulContext}; | ||||||
| 
 | 
 | ||||||
|     biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad, |     biasaddgrad_conv2dwithbias_context_ = { | ||||||
|                                    csinfo_.mkl_conv2d_with_bias, |         csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, | ||||||
|                                    IsBiasAddGradInConv2DWithBiasContext}; |         IsBiasAddGradInConv2DWithBiasContext}; | ||||||
| 
 | 
 | ||||||
|     cinfo_.push_back(&biasaddgrad_matmul_context_); |     cinfo_.push_back(&biasaddgrad_matmul_context_); | ||||||
|     cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); |     cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); | ||||||
| @ -410,9 +406,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
| 
 | 
 | ||||||
|   /// Structure to specify the context information used in a node rewrite rule
 |   /// Structure to specify the context information used in a node rewrite rule
 | ||||||
|   typedef struct { |   typedef struct { | ||||||
|     string node;     // Name of the node to be rewritten
 |     string node;  // Name of the node to be rewritten
 | ||||||
|     string fwd;      // Name of the node in the forward pass that this node
 |     string fwd;   // Name of the node in the forward pass that this node
 | ||||||
|                      // corresponds to
 |                   // corresponds to
 | ||||||
|     std::function<bool(const Node*, const Node**, void* c)> context_match_fn; |     std::function<bool(const Node*, const Node**, void* c)> context_match_fn; | ||||||
|   } ContextInfo; |   } ContextInfo; | ||||||
| 
 | 
 | ||||||
| @ -615,14 +611,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     std::vector<int32> ksize, strides; |     std::vector<int32> ksize, strides; | ||||||
|     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); |     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); | ||||||
|     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); |     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); | ||||||
|     CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), |     CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); | ||||||
|              true); |  | ||||||
|     CHECK_EQ(FormatFromString(data_format_str, &data_format), true); |     CHECK_EQ(FormatFromString(data_format_str, &data_format), true); | ||||||
| 
 | 
 | ||||||
|     // Condition that specifies non-batch-wise and non-depth-wise pooling.
 |     // Condition that specifies non-batch-wise and non-depth-wise pooling.
 | ||||||
|     if (GetTensorDim(ksize,   data_format, 'N') == 1 && |     if (GetTensorDim(ksize, data_format, 'N') == 1 && | ||||||
|         GetTensorDim(strides, data_format, 'N') == 1 && |         GetTensorDim(strides, data_format, 'N') == 1 && | ||||||
|         GetTensorDim(ksize,   data_format, 'C') == 1 && |         GetTensorDim(ksize, data_format, 'C') == 1 && | ||||||
|         GetTensorDim(strides, data_format, 'C') == 1) { |         GetTensorDim(strides, data_format, 'C') == 1) { | ||||||
|       return true; |       return true; | ||||||
|     } |     } | ||||||
| @ -785,8 +780,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|       for (const Edge* fe : first_inp_of_filter->out_edges()) { |       for (const Edge* fe : first_inp_of_filter->out_edges()) { | ||||||
|         if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && |         if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && | ||||||
|             fe->dst_input() == 0) { |             fe->dst_input() == 0) { | ||||||
|           VLOG(1) << "MklLayoutRewritePass: found " |           VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString() | ||||||
|                   << fe->dst()->DebugString() |  | ||||||
|                   << " as the forward node for matching context, backward" |                   << " as the forward node for matching context, backward" | ||||||
|                   << " node is: " << n->DebugString(); |                   << " node is: " << n->DebugString(); | ||||||
|           *fwd_node = fe->dst(); |           *fwd_node = fe->dst(); | ||||||
| @ -803,13 +797,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|   //
 |   //
 | ||||||
|   // @return - true (if BiasAddGrad is associated with MatMul);
 |   // @return - true (if BiasAddGrad is associated with MatMul);
 | ||||||
|   //           false otherwise.
 |   //           false otherwise.
 | ||||||
|   static bool IsBiasAddGradInMatMulContext(const Node* n, |   static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node, | ||||||
|                                            const Node** fwd_node, |  | ||||||
|                                            void* ci) { |                                            void* ci) { | ||||||
|     return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci)); |     return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|   // Rewrite rule that uses context-information for matching,
 |   // Rewrite rule that uses context-information for matching,
 | ||||||
|   // used in scenario 2.
 |   // used in scenario 2.
 | ||||||
|   //
 |   //
 | ||||||
| @ -880,10 +872,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|   // @output output_nodes - the list of new nodes creating Mkl tensors
 |   // @output output_nodes - the list of new nodes creating Mkl tensors
 | ||||||
|   //
 |   //
 | ||||||
|   // @return None
 |   // @return None
 | ||||||
|   void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g, |   void GetNodesProducingMklTensorList( | ||||||
|     Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |       std::unique_ptr<Graph>* g, Node* orig_node, | ||||||
|     int* input_idx, int list_length, |       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, | ||||||
|     std::vector<NodeBuilder::NodeOut>* output_nodes); |       int* input_idx, int list_length, | ||||||
|  |       std::vector<NodeBuilder::NodeOut>* output_nodes); | ||||||
| 
 | 
 | ||||||
|   // Get a node that will feed an Mkl tensor to the new
 |   // Get a node that will feed an Mkl tensor to the new
 | ||||||
|   // node that we are constructing. The output node could be (1) 'n'
 |   // node that we are constructing. The output node could be (1) 'n'
 | ||||||
| @ -900,7 +893,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|   //                                will feed the tensor
 |   //                                will feed the tensor
 | ||||||
|   // @return None
 |   // @return None
 | ||||||
|   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, |   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, | ||||||
|     Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); |                                  Node* n, int n_output_slot, Node** mkl_node, | ||||||
|  |                                  int* mkl_node_output_slot); | ||||||
| 
 | 
 | ||||||
|   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
 |   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
 | ||||||
|   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
 |   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
 | ||||||
| @ -970,9 +964,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
| 
 | 
 | ||||||
| MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; | MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; | ||||||
| MklLayoutRewritePass::ContextInfo | MklLayoutRewritePass::ContextInfo | ||||||
|   MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; |     MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; | ||||||
| MklLayoutRewritePass::ContextInfo | MklLayoutRewritePass::ContextInfo | ||||||
|   MklLayoutRewritePass::biasaddgrad_matmul_context_; |     MklLayoutRewritePass::biasaddgrad_matmul_context_; | ||||||
| std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_; | std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_; | ||||||
| 
 | 
 | ||||||
| // We register Mkl rewrite pass for phase 1 in post partitioning group.
 | // We register Mkl rewrite pass for phase 1 in post partitioning group.
 | ||||||
| @ -1041,13 +1035,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, | |||||||
|   TensorShape dummy_shape({8}); |   TensorShape dummy_shape({8}); | ||||||
|   dummy_shape.AsProto(proto.mutable_tensor_shape()); |   dummy_shape.AsProto(proto.mutable_tensor_shape()); | ||||||
|   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") |   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") | ||||||
|                .Attr("value", proto) |                   .Attr("value", proto) | ||||||
|                .Attr("dtype", dt) |                   .Attr("dtype", dt) | ||||||
|                .Device(orig_node->def().device())  // We place this node on
 |                   .Device(orig_node->def().device())  // We place this node on
 | ||||||
|                                                    // the same device as the
 |                                                       // the same device as the
 | ||||||
|                                                    // device of the original
 |                                                       // device of the original
 | ||||||
|                                                    // node.
 |                                                       // node.
 | ||||||
|                .Finalize(&**g, out)); |                   .Finalize(&**g, out)); | ||||||
| 
 | 
 | ||||||
|   // If number of inputs to the original node is > 0, then we add
 |   // If number of inputs to the original node is > 0, then we add
 | ||||||
|   // control dependency between 1st input (index 0) of the original node and
 |   // control dependency between 1st input (index 0) of the original node and
 | ||||||
| @ -1060,8 +1054,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, | |||||||
|   // the same frame.
 |   // the same frame.
 | ||||||
|   if (orig_node->num_inputs() > 0) { |   if (orig_node->num_inputs() > 0) { | ||||||
|     Node* orig_input0 = nullptr; |     Node* orig_input0 = nullptr; | ||||||
|     TF_CHECK_OK(orig_node->input_node(0, |     TF_CHECK_OK( | ||||||
|                                       const_cast<const Node**>(&orig_input0))); |         orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); | ||||||
|     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); |     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -1069,11 +1063,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void MklLayoutRewritePass::GetNodesProducingMklTensorList( | void MklLayoutRewritePass::GetNodesProducingMklTensorList( | ||||||
|     std::unique_ptr<Graph>* g, |     std::unique_ptr<Graph>* g, Node* orig_node, | ||||||
|     Node* orig_node, |     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, | ||||||
|     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { | ||||||
|     int* input_idx, int list_length, |  | ||||||
|     std::vector<NodeBuilder::NodeOut>* output_nodes) { |  | ||||||
|   CHECK_LT(*input_idx, inputs.size()); |   CHECK_LT(*input_idx, inputs.size()); | ||||||
|   CHECK_GT(list_length, 0); |   CHECK_GT(list_length, 0); | ||||||
|   CHECK_NOTNULL(output_nodes); |   CHECK_NOTNULL(output_nodes); | ||||||
| @ -1090,8 +1082,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( | |||||||
|     int mkl_node_output_slot = 0; |     int mkl_node_output_slot = 0; | ||||||
|     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, |     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, | ||||||
|                               &mkl_node_output_slot); |                               &mkl_node_output_slot); | ||||||
|     output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, |     output_nodes->push_back( | ||||||
|                                                 mkl_node_output_slot)); |         NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); | ||||||
|     (*input_idx)++; |     (*input_idx)++; | ||||||
|     list_length--; |     list_length--; | ||||||
|   } |   } | ||||||
| @ -1101,9 +1093,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( | |||||||
| // node that we are constructing. An input node could be (1) 'n'
 | // node that we are constructing. An input node could be (1) 'n'
 | ||||||
| // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
 | // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
 | ||||||
| // if 'n' is not an Mkl layer.
 | // if 'n' is not an Mkl layer.
 | ||||||
| void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, | void MklLayoutRewritePass::GetNodeProducingMklTensor( | ||||||
|     Node* orig_node, Node* n, |     std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot, | ||||||
|     int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { |     Node** mkl_node, int* mkl_node_output_slot) { | ||||||
|   CHECK_NOTNULL(n); |   CHECK_NOTNULL(n); | ||||||
|   CHECK_NOTNULL(mkl_node); |   CHECK_NOTNULL(mkl_node); | ||||||
|   CHECK_NOTNULL(mkl_node_output_slot); |   CHECK_NOTNULL(mkl_node_output_slot); | ||||||
| @ -1234,8 +1226,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs( | |||||||
|     if (ArgIsList(arg)) { |     if (ArgIsList(arg)) { | ||||||
|       std::vector<NodeBuilder::NodeOut> new_node_inputs; |       std::vector<NodeBuilder::NodeOut> new_node_inputs; | ||||||
|       int N = GetTensorListLength(arg, old_node); |       int N = GetTensorListLength(arg, old_node); | ||||||
|       GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, |       GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, | ||||||
|                                      N, &new_node_inputs); |                                      &new_node_inputs); | ||||||
|       nb->Input(new_node_inputs); |       nb->Input(new_node_inputs); | ||||||
|       nn_slot_idx++; |       nn_slot_idx++; | ||||||
|     } else { |     } else { | ||||||
| @ -1336,13 +1328,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( | |||||||
|   TensorShape dummy_shape({1}); |   TensorShape dummy_shape({1}); | ||||||
|   dummy_shape.AsProto(proto.mutable_tensor_shape()); |   dummy_shape.AsProto(proto.mutable_tensor_shape()); | ||||||
|   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") |   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") | ||||||
|                 .Attr("value", proto) |                   .Attr("value", proto) | ||||||
|                 .Attr("dtype", dt) |                   .Attr("dtype", dt) | ||||||
|                 .Device(orig_node->def().device())  // We place this node on
 |                   .Device(orig_node->def().device())  // We place this node on
 | ||||||
|                                                     // same the device as the
 |                                                       // same the device as the
 | ||||||
|                                                     // device of the original
 |                                                       // device of the original
 | ||||||
|                                                     // node.
 |                                                       // node.
 | ||||||
|                 .Finalize(&**g, out)); |                   .Finalize(&**g, out)); | ||||||
| 
 | 
 | ||||||
|   // If number of inputs to the original node is > 0, then we add
 |   // If number of inputs to the original node is > 0, then we add
 | ||||||
|   // control dependency between 1st input (index 0) of the original node and
 |   // control dependency between 1st input (index 0) of the original node and
 | ||||||
| @ -1355,8 +1347,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( | |||||||
|   // the same frame.
 |   // the same frame.
 | ||||||
|   if (orig_node->num_inputs() > 0) { |   if (orig_node->num_inputs() > 0) { | ||||||
|     Node* orig_input0 = nullptr; |     Node* orig_input0 = nullptr; | ||||||
|     TF_CHECK_OK(orig_node->input_node(0, |     TF_CHECK_OK( | ||||||
|                                       const_cast<const Node**>(&orig_input0))); |         orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); | ||||||
|     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); |     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -1374,7 +1366,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( | |||||||
|   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); | ||||||
|   for (auto ws : wsinfo_) { |   for (auto ws : wsinfo_) { | ||||||
|     if (orig_node->type_string() == ws.fwd_op && |     if (orig_node->type_string() == ws.fwd_op && | ||||||
|         mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { |         mkl_op_registry::IsMklOp( | ||||||
|  |             mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { | ||||||
|       // If this op is a fwd op, then we need to check if there is an
 |       // If this op is a fwd op, then we need to check if there is an
 | ||||||
|       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
 |       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
 | ||||||
|       // an edge, then we just add an attribute on this node for setting
 |       // an edge, then we just add an attribute on this node for setting
 | ||||||
| @ -1400,8 +1393,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( | |||||||
|         nb->Attr("workspace_enabled", false); |         nb->Attr("workspace_enabled", false); | ||||||
|       } |       } | ||||||
|     } else if (orig_node->type_string() == ws.bwd_op && |     } else if (orig_node->type_string() == ws.bwd_op && | ||||||
|                mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), |                mkl_op_registry::IsMklOp( | ||||||
|                                         T)) { |                    mkl_op_registry::GetMklOpName(orig_node->type_string()), | ||||||
|  |                    T)) { | ||||||
|       // If this op is a bwd op, then we need to add workspace edge and
 |       // If this op is a bwd op, then we need to add workspace edge and
 | ||||||
|       // it's Mkl tensor edge between its corresponding fwd op and this
 |       // it's Mkl tensor edge between its corresponding fwd op and this
 | ||||||
|       // op. Corresponding fwd op is specified in 'fwd_op' field of
 |       // op. Corresponding fwd op is specified in 'fwd_op' field of
 | ||||||
| @ -1416,7 +1410,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( | |||||||
|         if (e->src_output() == ws.fwd_slot && |         if (e->src_output() == ws.fwd_slot && | ||||||
|             // We would have rewritten the forward op, so we need to use
 |             // We would have rewritten the forward op, so we need to use
 | ||||||
|             // GetMklOpName call to get its Mkl name.
 |             // GetMklOpName call to get its Mkl name.
 | ||||||
|             e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) && |             e->src()->type_string() == | ||||||
|  |                 mkl_op_registry::GetMklOpName(ws.fwd_op) && | ||||||
|             e->dst_input() == ws.bwd_slot) { |             e->dst_input() == ws.bwd_slot) { | ||||||
|           nb->Attr("workspace_enabled", true); |           nb->Attr("workspace_enabled", true); | ||||||
|           CHECK_NOTNULL(ws_tensors); |           CHECK_NOTNULL(ws_tensors); | ||||||
| @ -1593,7 +1588,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, | void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, | ||||||
|                                            NodeBuilder* nb) { |                                             NodeBuilder* nb) { | ||||||
|   DataType T; |   DataType T; | ||||||
|   DataType Tshape; |   DataType Tshape; | ||||||
| 
 | 
 | ||||||
| @ -1869,8 +1864,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, | |||||||
|       if (e->IsControlEdge()) { |       if (e->IsControlEdge()) { | ||||||
|         CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); |         CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); | ||||||
|       } else { |       } else { | ||||||
|         CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(), |         CHECK_NOTNULL( | ||||||
|                                   e->dst_input())); |             (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input())); | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -1941,9 +1936,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, | |||||||
|       // and leave BiasAddGrad as it is. But we check for this condition
 |       // and leave BiasAddGrad as it is. But we check for this condition
 | ||||||
|       // when we check for node rewrite rule. So we should not even come
 |       // when we check for node rewrite rule. So we should not even come
 | ||||||
|       // here for MatMul. So we will fail now.
 |       // here for MatMul. So we will fail now.
 | ||||||
|         return Status( |       return Status( | ||||||
|             error::Code::INVALID_ARGUMENT, |           error::Code::INVALID_ARGUMENT, | ||||||
|             "No rewrite is required for BiasAddGrad for MatMul context."); |           "No rewrite is required for BiasAddGrad for MatMul context."); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -2012,9 +2007,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, | |||||||
|     if (e->IsControlEdge()) { |     if (e->IsControlEdge()) { | ||||||
|       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); |       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); | ||||||
|     } else { |     } else { | ||||||
|       CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), |       CHECK_NOTNULL((*g)->AddEdge( | ||||||
|                             e->src()->num_outputs()), |           new_node, | ||||||
|                     e->dst(), e->dst_input())); |           GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), | ||||||
|  |           e->dst(), e->dst_input())); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -2070,7 +2066,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { | |||||||
| 
 | 
 | ||||||
|   // BiasAddGrad is not an Mkl layer, so we make an exception for it.
 |   // BiasAddGrad is not an Mkl layer, so we make an exception for it.
 | ||||||
|   if (n->type_string() != csinfo_.bias_add_grad) { |   if (n->type_string() != csinfo_.bias_add_grad) { | ||||||
|     if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) { |     if (!mkl_op_registry::IsMklOp( | ||||||
|  |             mkl_op_registry::GetMklOpName(n->type_string()), T)) { | ||||||
|       return nullptr; |       return nullptr; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| @ -2186,8 +2183,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) { | |||||||
|   return MklLayoutRewritePass().RunPass(g); |   return MklLayoutRewritePass().RunPass(g); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Status MklLayoutRewritePass::Run( | Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { | ||||||
|   const GraphOptimizationPassOptions& options) { |  | ||||||
|   if (options.graph == nullptr && options.partition_graphs == nullptr) { |   if (options.graph == nullptr && options.partition_graphs == nullptr) { | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
| @ -2215,7 +2211,7 @@ Status MklLayoutRewritePass::Run( | |||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #else  // INTEL_MKL_DNN
 | #else   // INTEL_MKL_DNN
 | ||||||
| 
 | 
 | ||||||
| // This pass implements rewriting of graph to support following scenarios:
 | // This pass implements rewriting of graph to support following scenarios:
 | ||||||
| // (A) Merging nodes in the graph
 | // (A) Merging nodes in the graph
 | ||||||
| @ -2421,7 +2417,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; |     csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; | ||||||
|     csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; |     csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; | ||||||
|     csinfo_.conv2d_grad_filter_with_bias = |     csinfo_.conv2d_grad_filter_with_bias = | ||||||
|                               "__MklDummyConv2DBackpropFilterWithBias"; |         "__MklDummyConv2DBackpropFilterWithBias"; | ||||||
|     csinfo_.fused_batch_norm = "FusedBatchNorm"; |     csinfo_.fused_batch_norm = "FusedBatchNorm"; | ||||||
|     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; |     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; | ||||||
|     csinfo_.identity = "Identity"; |     csinfo_.identity = "Identity"; | ||||||
| @ -2435,11 +2431,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; |     csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; | ||||||
|     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; |     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; | ||||||
|     csinfo_.mkl_conv2d_grad_filter_with_bias = |     csinfo_.mkl_conv2d_grad_filter_with_bias = | ||||||
|                                    "_MklConv2DBackpropFilterWithBias"; |         "_MklConv2DBackpropFilterWithBias"; | ||||||
|     csinfo_.relu = "Relu"; |     csinfo_.relu = "Relu"; | ||||||
|     csinfo_.relu_grad = "ReluGrad"; |     csinfo_.relu_grad = "ReluGrad"; | ||||||
|     csinfo_.tanh       = "Tanh"; |     csinfo_.tanh = "Tanh"; | ||||||
|     csinfo_.tanh_grad  = "TanhGrad"; |     csinfo_.tanh_grad = "TanhGrad"; | ||||||
|     csinfo_.reshape = "Reshape"; |     csinfo_.reshape = "Reshape"; | ||||||
|     csinfo_.softmax = "Softmax"; |     csinfo_.softmax = "Softmax"; | ||||||
|     csinfo_.split = "Split"; |     csinfo_.split = "Split"; | ||||||
| @ -2474,29 +2470,28 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     rinfo_.push_back({csinfo_.conv2d, |     rinfo_.push_back({csinfo_.conv2d, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.conv2d), |                       mkl_op_registry::GetMklOpName(csinfo_.conv2d), | ||||||
|                       CopyAttrsConv2D, AlwaysRewrite}); |                       CopyAttrsConv2D, AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.conv2d_with_bias, |     rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias, | ||||||
|                       csinfo_.mkl_conv2d_with_bias, |  | ||||||
|                       CopyAttrsConv2D, AlwaysRewrite}); |                       CopyAttrsConv2D, AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.conv2d_grad_filter, |     rinfo_.push_back({csinfo_.conv2d_grad_filter, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), |                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), | ||||||
|                       CopyAttrsConv2D, AlwaysRewrite}); |                       CopyAttrsConv2D, AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias, |     rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias, | ||||||
|                       csinfo_.mkl_conv2d_grad_filter_with_bias, |                       csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D, | ||||||
|                       CopyAttrsConv2D, AlwaysRewrite}); |                       AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.conv2d_grad_input, |     rinfo_.push_back({csinfo_.conv2d_grad_input, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), |                       mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), | ||||||
|                       CopyAttrsConv2D, AlwaysRewrite}); |                       CopyAttrsConv2D, AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.fused_batch_norm, |     rinfo_.push_back({csinfo_.fused_batch_norm, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), |                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), | ||||||
|                       CopyAttrsFusedBatchNorm, AlwaysRewrite}); |                       CopyAttrsFusedBatchNorm, AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.fused_batch_norm_grad, |     rinfo_.push_back( | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), |         {csinfo_.fused_batch_norm_grad, | ||||||
|                       CopyAttrsFusedBatchNorm, AlwaysRewrite}); |          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), | ||||||
|  |          CopyAttrsFusedBatchNorm, AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.identity, |     rinfo_.push_back({csinfo_.identity, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.identity), |                       mkl_op_registry::GetMklOpName(csinfo_.identity), | ||||||
|                       CopyAttrsDataType, AlwaysRewrite}); |                       CopyAttrsDataType, AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.lrn, |     rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.lrn), |  | ||||||
|                       CopyAttrsLRN, AlwaysRewrite}); |                       CopyAttrsLRN, AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.lrn_grad, |     rinfo_.push_back({csinfo_.lrn_grad, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), |                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), | ||||||
| @ -2515,8 +2510,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.mul), |                       mkl_op_registry::GetMklOpName(csinfo_.mul), | ||||||
|                       CopyAttrsDataType, AlwaysRewrite}); |                       CopyAttrsDataType, AlwaysRewrite}); | ||||||
|     */ |     */ | ||||||
|     rinfo_.push_back({csinfo_.relu, |     rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.relu), |  | ||||||
|                       CopyAttrsDataType, AlwaysRewrite}); |                       CopyAttrsDataType, AlwaysRewrite}); | ||||||
|     rinfo_.push_back({csinfo_.relu_grad, |     rinfo_.push_back({csinfo_.relu_grad, | ||||||
|                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad), |                       mkl_op_registry::GetMklOpName(csinfo_.relu_grad), | ||||||
| @ -2550,8 +2544,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
| 
 | 
 | ||||||
|     // Add a rule for merging nodes
 |     // Add a rule for merging nodes
 | ||||||
|     minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, |     minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, | ||||||
|                       csinfo_.conv2d_with_bias, |                       csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd}); | ||||||
|                       GetConv2DOrBiasAdd}); |  | ||||||
| 
 | 
 | ||||||
|     minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad, |     minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad, | ||||||
|                       csinfo_.conv2d_grad_filter_with_bias, |                       csinfo_.conv2d_grad_filter_with_bias, | ||||||
| @ -2846,9 +2839,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
| 
 | 
 | ||||||
|   // Default rewrite rule to be used in scenario 1 for rewrite.
 |   // Default rewrite rule to be used in scenario 1 for rewrite.
 | ||||||
|   // @return - true (since we want to always rewrite)
 |   // @return - true (since we want to always rewrite)
 | ||||||
|   static bool AlwaysRewrite(const Node* n) { |   static bool AlwaysRewrite(const Node* n) { return true; } | ||||||
|     return true; |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   // Check if we are performing pooling on depth or batch. If it is, then we
 |   // Check if we are performing pooling on depth or batch. If it is, then we
 | ||||||
|   // do not rewrite MaxPool node to Mkl version.
 |   // do not rewrite MaxPool node to Mkl version.
 | ||||||
| @ -2862,14 +2853,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|     std::vector<int32> ksize, strides; |     std::vector<int32> ksize, strides; | ||||||
|     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); |     CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); | ||||||
|     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); |     CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); | ||||||
|     CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), |     CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); | ||||||
|              true); |  | ||||||
|     CHECK_EQ(FormatFromString(data_format_str, &data_format), true); |     CHECK_EQ(FormatFromString(data_format_str, &data_format), true); | ||||||
| 
 | 
 | ||||||
|     // Condition that specifies non-batch-wise and non-depth-wise pooling.
 |     // Condition that specifies non-batch-wise and non-depth-wise pooling.
 | ||||||
|     if (GetTensorDim(ksize,   data_format, 'N') == 1 && |     if (GetTensorDim(ksize, data_format, 'N') == 1 && | ||||||
|         GetTensorDim(strides, data_format, 'N') == 1 && |         GetTensorDim(strides, data_format, 'N') == 1 && | ||||||
|         GetTensorDim(ksize,   data_format, 'C') == 1 && |         GetTensorDim(ksize, data_format, 'C') == 1 && | ||||||
|         GetTensorDim(strides, data_format, 'C') == 1) { |         GetTensorDim(strides, data_format, 'C') == 1) { | ||||||
|       return true; |       return true; | ||||||
|     } |     } | ||||||
| @ -2941,10 +2931,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|   // @output output_nodes - the list of new nodes creating Mkl tensors
 |   // @output output_nodes - the list of new nodes creating Mkl tensors
 | ||||||
|   //
 |   //
 | ||||||
|   // @return None
 |   // @return None
 | ||||||
|   void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g, |   void GetNodesProducingMklTensorList( | ||||||
|     Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |       std::unique_ptr<Graph>* g, Node* orig_node, | ||||||
|     int* input_idx, int list_length, |       const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, | ||||||
|     std::vector<NodeBuilder::NodeOut>* output_nodes); |       int* input_idx, int list_length, | ||||||
|  |       std::vector<NodeBuilder::NodeOut>* output_nodes); | ||||||
| 
 | 
 | ||||||
|   // Get a node that will feed an Mkl tensor to the new
 |   // Get a node that will feed an Mkl tensor to the new
 | ||||||
|   // node that we are constructing. The output node could be (1) 'n'
 |   // node that we are constructing. The output node could be (1) 'n'
 | ||||||
| @ -2961,7 +2952,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { | |||||||
|   //                                will feed the tensor
 |   //                                will feed the tensor
 | ||||||
|   // @return None
 |   // @return None
 | ||||||
|   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, |   void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, | ||||||
|     Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); |                                  Node* n, int n_output_slot, Node** mkl_node, | ||||||
|  |                                  int* mkl_node_output_slot); | ||||||
| 
 | 
 | ||||||
|   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
 |   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
 | ||||||
|   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
 |   // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
 | ||||||
| @ -3096,13 +3088,13 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, | |||||||
|   TensorShape dummy_shape({8}); |   TensorShape dummy_shape({8}); | ||||||
|   dummy_shape.AsProto(proto.mutable_tensor_shape()); |   dummy_shape.AsProto(proto.mutable_tensor_shape()); | ||||||
|   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") |   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") | ||||||
|                .Attr("value", proto) |                   .Attr("value", proto) | ||||||
|                .Attr("dtype", dt) |                   .Attr("dtype", dt) | ||||||
|                .Device(orig_node->def().device())  // We place this node on
 |                   .Device(orig_node->def().device())  // We place this node on
 | ||||||
|                                                    // the same device as the
 |                                                       // the same device as the
 | ||||||
|                                                    // device of the original
 |                                                       // device of the original
 | ||||||
|                                                    // node.
 |                                                       // node.
 | ||||||
|                .Finalize(&**g, out)); |                   .Finalize(&**g, out)); | ||||||
| 
 | 
 | ||||||
|   // If number of inputs to the original node is > 0, then we add
 |   // If number of inputs to the original node is > 0, then we add
 | ||||||
|   // control dependency between 1st input (index 0) of the original node and
 |   // control dependency between 1st input (index 0) of the original node and
 | ||||||
| @ -3115,8 +3107,8 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, | |||||||
|   // the same frame.
 |   // the same frame.
 | ||||||
|   if (orig_node->num_inputs() > 0) { |   if (orig_node->num_inputs() > 0) { | ||||||
|     Node* orig_input0 = nullptr; |     Node* orig_input0 = nullptr; | ||||||
|     TF_CHECK_OK(orig_node->input_node(0, |     TF_CHECK_OK( | ||||||
|                                       const_cast<const Node**>(&orig_input0))); |         orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); | ||||||
|     // Allow duplicate while adding control edge as it would fail (return
 |     // Allow duplicate while adding control edge as it would fail (return
 | ||||||
|     // NULL) if we try to add duplicate edge.
 |     // NULL) if we try to add duplicate edge.
 | ||||||
|     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); |     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); | ||||||
| @ -3126,11 +3118,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void MklLayoutRewritePass::GetNodesProducingMklTensorList( | void MklLayoutRewritePass::GetNodesProducingMklTensorList( | ||||||
|     std::unique_ptr<Graph>* g, |     std::unique_ptr<Graph>* g, Node* orig_node, | ||||||
|     Node* orig_node, |     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, | ||||||
|     const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |     int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { | ||||||
|     int* input_idx, int list_length, |  | ||||||
|     std::vector<NodeBuilder::NodeOut>* output_nodes) { |  | ||||||
|   CHECK_LT(*input_idx, inputs.size()); |   CHECK_LT(*input_idx, inputs.size()); | ||||||
|   CHECK_GT(list_length, 0); |   CHECK_GT(list_length, 0); | ||||||
|   CHECK_NOTNULL(output_nodes); |   CHECK_NOTNULL(output_nodes); | ||||||
| @ -3147,8 +3137,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( | |||||||
|     int mkl_node_output_slot = 0; |     int mkl_node_output_slot = 0; | ||||||
|     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, |     GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, | ||||||
|                               &mkl_node_output_slot); |                               &mkl_node_output_slot); | ||||||
|     output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, |     output_nodes->push_back( | ||||||
|                                                 mkl_node_output_slot)); |         NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); | ||||||
|     (*input_idx)++; |     (*input_idx)++; | ||||||
|     list_length--; |     list_length--; | ||||||
|   } |   } | ||||||
| @ -3158,9 +3148,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( | |||||||
| // node that we are constructing. An input node could be (1) 'n'
 | // node that we are constructing. An input node could be (1) 'n'
 | ||||||
| // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
 | // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
 | ||||||
| // if 'n' is not an Mkl layer.
 | // if 'n' is not an Mkl layer.
 | ||||||
| void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, | void MklLayoutRewritePass::GetNodeProducingMklTensor( | ||||||
|     Node* orig_node, Node* n, |     std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot, | ||||||
|     int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { |     Node** mkl_node, int* mkl_node_output_slot) { | ||||||
|   CHECK_NOTNULL(n); |   CHECK_NOTNULL(n); | ||||||
|   CHECK_NOTNULL(mkl_node); |   CHECK_NOTNULL(mkl_node); | ||||||
|   CHECK_NOTNULL(mkl_node_output_slot); |   CHECK_NOTNULL(mkl_node_output_slot); | ||||||
| @ -3292,8 +3282,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs( | |||||||
|     if (ArgIsList(arg)) { |     if (ArgIsList(arg)) { | ||||||
|       std::vector<NodeBuilder::NodeOut> new_node_inputs; |       std::vector<NodeBuilder::NodeOut> new_node_inputs; | ||||||
|       int N = GetTensorListLength(arg, old_node); |       int N = GetTensorListLength(arg, old_node); | ||||||
|       GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, |       GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, | ||||||
|                                      N, &new_node_inputs); |                                      &new_node_inputs); | ||||||
|       nb->Input(new_node_inputs); |       nb->Input(new_node_inputs); | ||||||
|       nn_slot_idx++; |       nn_slot_idx++; | ||||||
|     } else { |     } else { | ||||||
| @ -3394,13 +3384,13 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( | |||||||
|   TensorShape dummy_shape({1}); |   TensorShape dummy_shape({1}); | ||||||
|   dummy_shape.AsProto(proto.mutable_tensor_shape()); |   dummy_shape.AsProto(proto.mutable_tensor_shape()); | ||||||
|   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") |   TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") | ||||||
|                 .Attr("value", proto) |                   .Attr("value", proto) | ||||||
|                 .Attr("dtype", dt) |                   .Attr("dtype", dt) | ||||||
|                 .Device(orig_node->def().device())  // We place this node on
 |                   .Device(orig_node->def().device())  // We place this node on
 | ||||||
|                                                     // same the device as the
 |                                                       // same the device as the
 | ||||||
|                                                     // device of the original
 |                                                       // device of the original
 | ||||||
|                                                     // node.
 |                                                       // node.
 | ||||||
|                 .Finalize(&**g, out)); |                   .Finalize(&**g, out)); | ||||||
| 
 | 
 | ||||||
|   // If number of inputs to the original node is > 0, then we add
 |   // If number of inputs to the original node is > 0, then we add
 | ||||||
|   // control dependency between 1st input (index 0) of the original node and
 |   // control dependency between 1st input (index 0) of the original node and
 | ||||||
| @ -3413,8 +3403,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( | |||||||
|   // the same frame.
 |   // the same frame.
 | ||||||
|   if (orig_node->num_inputs() > 0) { |   if (orig_node->num_inputs() > 0) { | ||||||
|     Node* orig_input0 = nullptr; |     Node* orig_input0 = nullptr; | ||||||
|     TF_CHECK_OK(orig_node->input_node(0, |     TF_CHECK_OK( | ||||||
|                                       const_cast<const Node**>(&orig_input0))); |         orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); | ||||||
|     // Allow duplicate while adding control edge as it would fail (return
 |     // Allow duplicate while adding control edge as it would fail (return
 | ||||||
|     // NULL) if we try to add duplicate edge.
 |     // NULL) if we try to add duplicate edge.
 | ||||||
|     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); |     CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); | ||||||
| @ -3434,8 +3424,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( | |||||||
|   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |   TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); | ||||||
|   for (auto ws : wsinfo_) { |   for (auto ws : wsinfo_) { | ||||||
|     if (orig_node->type_string() == ws.fwd_op && |     if (orig_node->type_string() == ws.fwd_op && | ||||||
|         mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( |         mkl_op_registry::IsMklOp( | ||||||
|           orig_node->type_string()), T)) { |             mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { | ||||||
|       // If this op is a fwd op, then we need to check if there is an
 |       // If this op is a fwd op, then we need to check if there is an
 | ||||||
|       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
 |       // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
 | ||||||
|       // an edge, then we just add an attribute on this node for setting
 |       // an edge, then we just add an attribute on this node for setting
 | ||||||
| @ -3461,8 +3451,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( | |||||||
|         nb->Attr("workspace_enabled", false); |         nb->Attr("workspace_enabled", false); | ||||||
|       } |       } | ||||||
|     } else if (orig_node->type_string() == ws.bwd_op && |     } else if (orig_node->type_string() == ws.bwd_op && | ||||||
|                mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( |                mkl_op_registry::IsMklOp( | ||||||
|                                           orig_node->type_string()), T)) { |                    mkl_op_registry::GetMklOpName(orig_node->type_string()), | ||||||
|  |                    T)) { | ||||||
|       // If this op is a bwd op, then we need to add workspace edge and
 |       // If this op is a bwd op, then we need to add workspace edge and
 | ||||||
|       // it's Mkl tensor edge between its corresponding fwd op and this
 |       // it's Mkl tensor edge between its corresponding fwd op and this
 | ||||||
|       // op. Corresponding fwd op is specified in 'fwd_op' field of
 |       // op. Corresponding fwd op is specified in 'fwd_op' field of
 | ||||||
| @ -3477,8 +3468,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( | |||||||
|         if (e->src_output() == ws.fwd_slot && |         if (e->src_output() == ws.fwd_slot && | ||||||
|             // We would have rewritten the forward op, so we need to use
 |             // We would have rewritten the forward op, so we need to use
 | ||||||
|             // GetMklOpName call to get its Mkl name.
 |             // GetMklOpName call to get its Mkl name.
 | ||||||
|             e->src()->type_string() == mkl_op_registry::GetMklOpName( |             e->src()->type_string() == | ||||||
|                                                           ws.fwd_op) && |                 mkl_op_registry::GetMklOpName(ws.fwd_op) && | ||||||
|             e->dst_input() == ws.bwd_slot) { |             e->dst_input() == ws.bwd_slot) { | ||||||
|           nb->Attr("workspace_enabled", true); |           nb->Attr("workspace_enabled", true); | ||||||
|           CHECK_NOTNULL(ws_tensors); |           CHECK_NOTNULL(ws_tensors); | ||||||
| @ -3645,7 +3636,7 @@ void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, | void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, | ||||||
|                                            NodeBuilder* nb) { |                                             NodeBuilder* nb) { | ||||||
|   DataType T; |   DataType T; | ||||||
|   DataType Tshape; |   DataType Tshape; | ||||||
| 
 | 
 | ||||||
| @ -3776,8 +3767,9 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, | |||||||
|                                                     Node* m, Node* n) { |                                                     Node* m, Node* n) { | ||||||
|   CHECK_EQ(((m->type_string() == csinfo_.bias_add && |   CHECK_EQ(((m->type_string() == csinfo_.bias_add && | ||||||
|              n->type_string() == csinfo_.conv2d)) || |              n->type_string() == csinfo_.conv2d)) || | ||||||
|            ((n->type_string() == csinfo_.bias_add && |                ((n->type_string() == csinfo_.bias_add && | ||||||
|              m->type_string() == csinfo_.conv2d)), true); |                  m->type_string() == csinfo_.conv2d)), | ||||||
|  |            true); | ||||||
| 
 | 
 | ||||||
|   // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
 |   // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
 | ||||||
|   // BiasAdd is successor node, and Conv2D predecessor node.
 |   // BiasAdd is successor node, and Conv2D predecessor node.
 | ||||||
| @ -3796,8 +3788,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, | |||||||
|   TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); |   TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); | ||||||
|   TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); |   TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); | ||||||
|   TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); |   TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); | ||||||
|   TF_CHECK_OK( |   TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu)); | ||||||
|       GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu)); |  | ||||||
|   // We check to ensure that data formats of both succ and pred are same.
 |   // We check to ensure that data formats of both succ and pred are same.
 | ||||||
|   // We expect them to be same, so we can enforce this as assert.
 |   // We expect them to be same, so we can enforce this as assert.
 | ||||||
|   // But assert can be too strict, so we enforce this as a check.
 |   // But assert can be too strict, so we enforce this as a check.
 | ||||||
| @ -3900,8 +3891,8 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, | |||||||
|       // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
 |       // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
 | ||||||
|       // output (at slot 0).
 |       // output (at slot 0).
 | ||||||
|       const int kConv2DWithBiasOutputSlot = 0; |       const int kConv2DWithBiasOutputSlot = 0; | ||||||
|       CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, |       CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(), | ||||||
|                                     e->dst(), e->dst_input())); |                                   e->dst_input())); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -3924,8 +3915,9 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( | |||||||
|     std::unique_ptr<Graph>* g, Node* m, Node* n) { |     std::unique_ptr<Graph>* g, Node* m, Node* n) { | ||||||
|   CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad && |   CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad && | ||||||
|              n->type_string() == csinfo_.conv2d_grad_filter)) || |              n->type_string() == csinfo_.conv2d_grad_filter)) || | ||||||
|            ((n->type_string() == csinfo_.bias_add_grad && |                ((n->type_string() == csinfo_.bias_add_grad && | ||||||
|              m->type_string() == csinfo_.conv2d_grad_filter)), true); |                  m->type_string() == csinfo_.conv2d_grad_filter)), | ||||||
|  |            true); | ||||||
| 
 | 
 | ||||||
|   // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
 |   // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
 | ||||||
|   Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n; |   Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n; | ||||||
| @ -4132,9 +4124,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, | |||||||
|       // NULL) if we try to add duplicate edge.
 |       // NULL) if we try to add duplicate edge.
 | ||||||
|       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); |       CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); | ||||||
|     } else { |     } else { | ||||||
|       CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), |       CHECK_NOTNULL((*g)->AddEdge( | ||||||
|                             e->src()->num_outputs()), |           new_node, | ||||||
|                     e->dst(), e->dst_input())); |           GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), | ||||||
|  |           e->dst(), e->dst_input())); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -4166,9 +4159,9 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { | |||||||
|   // names.
 |   // names.
 | ||||||
|   if (n->type_string() != csinfo_.conv2d_with_bias && |   if (n->type_string() != csinfo_.conv2d_with_bias && | ||||||
|       n->type_string() != csinfo_.conv2d_grad_filter_with_bias && |       n->type_string() != csinfo_.conv2d_grad_filter_with_bias && | ||||||
|       !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName( |       !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), | ||||||
|                                         n->type_string()), T)) { |                                 T)) { | ||||||
|       return nullptr; |     return nullptr; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   // For elementwise node, we reuse the Eigen implementation and pass the MKL
 |   // For elementwise node, we reuse the Eigen implementation and pass the MKL
 | ||||||
| @ -4184,29 +4177,30 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { | |||||||
|   // eigen code to reduce cross-library dependency.
 |   // eigen code to reduce cross-library dependency.
 | ||||||
|   VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string(); |   VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string(); | ||||||
|   if (mkl_op_registry::IsMklElementWiseOp( |   if (mkl_op_registry::IsMklElementWiseOp( | ||||||
|         mkl_op_registry::GetMklOpName(n->type_string()), T) || |           mkl_op_registry::GetMklOpName(n->type_string()), T) || | ||||||
|       n->type_string().find("Identity") != string::npos) { |       n->type_string().find("Identity") != string::npos) { | ||||||
|     VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string(); |     VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string(); | ||||||
|     bool incoming_mkl_edge = false; |     bool incoming_mkl_edge = false; | ||||||
|     int num_parent = 0; |     int num_parent = 0; | ||||||
|     for (auto parent : n->in_edges()) { |     for (auto parent : n->in_edges()) { | ||||||
|       if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) { |       if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) { | ||||||
|         VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is MKL op: " |         VLOG(1) << "ELEMENTWISE: parent " << num_parent++ | ||||||
|                 << parent->src()->type_string(); |                 << " is MKL op: " << parent->src()->type_string(); | ||||||
|         incoming_mkl_edge = true; |         incoming_mkl_edge = true; | ||||||
|         break; |         break; | ||||||
|       } else { |       } else { | ||||||
|         VLOG(1) << "ELEMENTWISE: parent " << num_parent++ << " is NON-MKL op: " |         VLOG(1) << "ELEMENTWISE: parent " << num_parent++ | ||||||
|                 << parent->src()->type_string(); |                 << " is NON-MKL op: " << parent->src()->type_string(); | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|     if (incoming_mkl_edge == false) { |     if (incoming_mkl_edge == false) { | ||||||
|       VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which has no MKL " |       VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which " | ||||||
|  |                  "has no MKL " | ||||||
|                  "parents."; |                  "parents."; | ||||||
|       return nullptr; |       return nullptr; | ||||||
|     } else { |     } else { | ||||||
|       VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() << |       VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() | ||||||
|         " which has MKL parents"; |               << " which has MKL parents"; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -4214,8 +4208,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { | |||||||
|   // for this op, then we rewrite it to Mkl op.
 |   // for this op, then we rewrite it to Mkl op.
 | ||||||
|   // Find matching RewriteInfo and then check that rewrite rule applies.
 |   // Find matching RewriteInfo and then check that rewrite rule applies.
 | ||||||
|   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { |   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { | ||||||
|     if (n->type_string().compare(ri->name) == 0 && |     if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { | ||||||
|         ri->rewrite_rule(n)) { |  | ||||||
|       return &*ri; |       return &*ri; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| @ -4297,8 +4290,7 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) { | |||||||
|   return MklLayoutRewritePass().RunPass(g); |   return MklLayoutRewritePass().RunPass(g); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Status MklLayoutRewritePass::Run( | Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { | ||||||
|   const GraphOptimizationPassOptions& options) { |  | ||||||
|   if (options.graph == nullptr && options.partition_graphs == nullptr) { |   if (options.graph == nullptr && options.partition_graphs == nullptr) { | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -125,8 +125,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); | |||||||
| REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); | REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); | ||||||
| REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); | REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); | ||||||
| REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); | REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); | ||||||
| REGISTER_OP("_MklInput2").Output("o: uint8") | REGISTER_OP("_MklInput2") | ||||||
|                         .Output("o1: uint8").SetIsStateful(); |     .Output("o: uint8") | ||||||
|  |     .Output("o1: uint8") | ||||||
|  |     .SetIsStateful(); | ||||||
| 
 | 
 | ||||||
| /////////////////////////////////////////////////////////////////////
 | /////////////////////////////////////////////////////////////////////
 | ||||||
| //  Unit tests related to node merge optiimization
 | //  Unit tests related to node merge optiimization
 | ||||||
| @ -498,7 +500,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) { | |||||||
|             "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); |             "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
 | // BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
 | ||||||
| TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) { | TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) { | ||||||
|   InitGraph( |   InitGraph( | ||||||
| @ -874,11 +875,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { | |||||||
|       " input: ['A', 'B:0', 'B:1']}" |       " input: ['A', 'B:0', 'B:1']}" | ||||||
|       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['C', 'D'] }"); |       " input: ['C', 'D'] }"); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ( | ||||||
|             "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" |       DoMklLayoutOptimizationPass(), | ||||||
|             "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" |       "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" | ||||||
|             "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" |       "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" | ||||||
|             "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); |       "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" | ||||||
|  |       "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Concat with 2 Mkl layers feeding it
 | // Concat with 2 Mkl layers feeding it
 | ||||||
| @ -1273,7 +1275,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { | |||||||
|       "node { name: 'H' op: 'Input'}" |       "node { name: 'H' op: 'Input'}" | ||||||
|       "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['H', 'G'] }"); |       " input: ['H', 'G'] }"); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ( | ||||||
|  |       DoMklLayoutOptimizationPass(), | ||||||
|       "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" |       "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" | ||||||
|       "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" |       "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" | ||||||
|       "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" |       "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" | ||||||
| @ -1640,7 +1643,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { | |||||||
|       " attr { key: 'padding'          value { s: 'SAME' } }" |       " attr { key: 'padding'          value { s: 'SAME' } }" | ||||||
|       " input: ['A', 'B']}" |       " input: ['A', 'B']}" | ||||||
|       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['B', 'C'] }", kGPUDevice); |       " input: ['B', 'C'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); |             "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); | ||||||
| } | } | ||||||
| @ -1666,7 +1670,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { | |||||||
|       "node { name: 'F' op: 'BiasAddGrad'" |       "node { name: 'F' op: 'BiasAddGrad'" | ||||||
|       " attr { key: 'T'                value { type: DT_FLOAT } }" |       " attr { key: 'T'                value { type: DT_FLOAT } }" | ||||||
|       " attr { key: 'data_format'      value { s: 'NCHW' } }" |       " attr { key: 'data_format'      value { s: 'NCHW' } }" | ||||||
|       " input: ['E'] }", kGPUDevice); |       " input: ['E'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" |             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" | ||||||
|             "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" |             "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" | ||||||
| @ -1687,7 +1692,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { | |||||||
|       " attr { key: 'padding'          value { s: 'SAME' } }" |       " attr { key: 'padding'          value { s: 'SAME' } }" | ||||||
|       " input: ['A', 'B', 'C']}" |       " input: ['A', 'B', 'C']}" | ||||||
|       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'D'] }", kGPUDevice); |       " input: ['A', 'D'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" |             "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" | ||||||
|             "A->D;A->E;B->D:1;C->D:2;D->E:1"); |             "A->D;A->E;B->D:1;C->D:2;D->E:1"); | ||||||
| @ -1700,7 +1706,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { | |||||||
|       " attr { key: 'T'                value { type: DT_FLOAT } }" |       " attr { key: 'T'                value { type: DT_FLOAT } }" | ||||||
|       " input: ['A'] }" |       " input: ['A'] }" | ||||||
|       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'B'] }", kGPUDevice); |       " input: ['A', 'B'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); |             "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); | ||||||
| } | } | ||||||
| @ -1713,7 +1720,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { | |||||||
|       " attr { key: 'T'                value { type: DT_FLOAT } }" |       " attr { key: 'T'                value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'B'] }" |       " input: ['A', 'B'] }" | ||||||
|       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'C'] }", kGPUDevice); |       " input: ['A', 'C'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); |             "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); | ||||||
| } | } | ||||||
| @ -1729,7 +1737,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { | |||||||
|       " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }" |       " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }" | ||||||
|       " input: ['A'] }" |       " input: ['A'] }" | ||||||
|       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'B'] }", kGPUDevice); |       " input: ['A', 'B'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); |             "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); | ||||||
| } | } | ||||||
| @ -1745,7 +1754,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { | |||||||
|       " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }" |       " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }" | ||||||
|       " input: ['A'] }" |       " input: ['A'] }" | ||||||
|       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'B'] }", kGPUDevice); |       " input: ['A', 'B'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); |             "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); | ||||||
| } | } | ||||||
| @ -1766,7 +1776,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { | |||||||
|       " attr { key: 'N'                value { i: 2 } }" |       " attr { key: 'N'                value { i: 2 } }" | ||||||
|       " input: ['A', 'B:0', 'B:1']}" |       " input: ['A', 'B:0', 'B:1']}" | ||||||
|       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['C', 'D'] }", kGPUDevice); |       " input: ['C', 'D'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" |             "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" | ||||||
|             "B->D:1;B:1->D:2;C->E;D->E:1"); |             "B->D:1;B:1->D:2;C->E;D->E:1"); | ||||||
| @ -1788,7 +1799,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { | |||||||
|       " attr { key: 'N'                value { i: 2 } }" |       " attr { key: 'N'                value { i: 2 } }" | ||||||
|       " input: ['B:0', 'B:1', 'A']}" |       " input: ['B:0', 'B:1', 'A']}" | ||||||
|       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['C', 'D'] }", kGPUDevice); |       " input: ['C', 'D'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" |             "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" | ||||||
|             "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); |             "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); | ||||||
| @ -1808,7 +1820,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { | |||||||
|       " attr { key: 'is_training'  value { b: true } }" |       " attr { key: 'is_training'  value { b: true } }" | ||||||
|       " input: ['A', 'B', 'C', 'D', 'E'] }" |       " input: ['A', 'B', 'C', 'D', 'E'] }" | ||||||
|       "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'F'] }", kGPUDevice); |       " input: ['A', 'F'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(Input);D(Input);E(Input);" |             "A(Input);B(Input);C(Input);D(Input);E(Input);" | ||||||
|             "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" |             "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" | ||||||
| @ -1837,7 +1850,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { | |||||||
|       "node { name: 'Y' op: 'Input'}" |       "node { name: 'Y' op: 'Input'}" | ||||||
|       "node { name: 'Z' op: 'Zeta'" |       "node { name: 'Z' op: 'Zeta'" | ||||||
|       " attr {key: 'T'                 value { type: DT_FLOAT } }" |       " attr {key: 'T'                 value { type: DT_FLOAT } }" | ||||||
|       " input: ['E', 'Y']}", kGPUDevice); |       " input: ['E', 'Y']}", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" |             "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" | ||||||
|             "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" |             "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" | ||||||
| @ -1972,8 +1986,10 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); | |||||||
| REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); | REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); | ||||||
| REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); | REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); | ||||||
| REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); | REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); | ||||||
| REGISTER_OP("_MklInput2").Output("o: uint8") | REGISTER_OP("_MklInput2") | ||||||
|                         .Output("o1: uint8").SetIsStateful(); |     .Output("o: uint8") | ||||||
|  |     .Output("o1: uint8") | ||||||
|  |     .SetIsStateful(); | ||||||
| 
 | 
 | ||||||
| /////////////////////////////////////////////////////////////////////
 | /////////////////////////////////////////////////////////////////////
 | ||||||
| //  Unit tests related to node merge optiimization
 | //  Unit tests related to node merge optiimization
 | ||||||
| @ -2492,11 +2508,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { | |||||||
|       " input: ['A', 'B:0', 'B:1']}" |       " input: ['A', 'B:0', 'B:1']}" | ||||||
|       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['C', 'D'] }"); |       " input: ['C', 'D'] }"); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ( | ||||||
|             "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" |       DoMklLayoutOptimizationPass(), | ||||||
|             "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" |       "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" | ||||||
|             "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" |       "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" | ||||||
|             "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); |       "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" | ||||||
|  |       "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Concat with 2 Mkl layers feeding it
 | // Concat with 2 Mkl layers feeding it
 | ||||||
| @ -2891,7 +2908,8 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { | |||||||
|       "node { name: 'H' op: 'Input'}" |       "node { name: 'H' op: 'Input'}" | ||||||
|       "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['H', 'G'] }"); |       " input: ['H', 'G'] }"); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ( | ||||||
|  |       DoMklLayoutOptimizationPass(), | ||||||
|       "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" |       "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" | ||||||
|       "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" |       "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" | ||||||
|       "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" |       "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" | ||||||
| @ -3258,7 +3276,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { | |||||||
|       " attr { key: 'padding'          value { s: 'SAME' } }" |       " attr { key: 'padding'          value { s: 'SAME' } }" | ||||||
|       " input: ['A', 'B']}" |       " input: ['A', 'B']}" | ||||||
|       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['B', 'C'] }", kGPUDevice); |       " input: ['B', 'C'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); |             "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); | ||||||
| } | } | ||||||
| @ -3284,7 +3303,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { | |||||||
|       "node { name: 'F' op: 'BiasAddGrad'" |       "node { name: 'F' op: 'BiasAddGrad'" | ||||||
|       " attr { key: 'T'                value { type: DT_FLOAT } }" |       " attr { key: 'T'                value { type: DT_FLOAT } }" | ||||||
|       " attr { key: 'data_format'      value { s: 'NCHW' } }" |       " attr { key: 'data_format'      value { s: 'NCHW' } }" | ||||||
|       " input: ['E'] }", kGPUDevice); |       " input: ['E'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" |             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" | ||||||
|             "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" |             "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" | ||||||
| @ -3305,7 +3325,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { | |||||||
|       " attr { key: 'padding'          value { s: 'SAME' } }" |       " attr { key: 'padding'          value { s: 'SAME' } }" | ||||||
|       " input: ['A', 'B', 'C']}" |       " input: ['A', 'B', 'C']}" | ||||||
|       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'D'] }", kGPUDevice); |       " input: ['A', 'D'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" |             "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" | ||||||
|             "A->D;A->E;B->D:1;C->D:2;D->E:1"); |             "A->D;A->E;B->D:1;C->D:2;D->E:1"); | ||||||
| @ -3318,7 +3339,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { | |||||||
|       " attr { key: 'T'                value { type: DT_FLOAT } }" |       " attr { key: 'T'                value { type: DT_FLOAT } }" | ||||||
|       " input: ['A'] }" |       " input: ['A'] }" | ||||||
|       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'B'] }", kGPUDevice); |       " input: ['A', 'B'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); |             "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); | ||||||
| } | } | ||||||
| @ -3331,7 +3353,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { | |||||||
|       " attr { key: 'T'                value { type: DT_FLOAT } }" |       " attr { key: 'T'                value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'B'] }" |       " input: ['A', 'B'] }" | ||||||
|       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'C'] }", kGPUDevice); |       " input: ['A', 'C'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); |             "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); | ||||||
| } | } | ||||||
| @ -3347,7 +3370,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { | |||||||
|       " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }" |       " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }" | ||||||
|       " input: ['A'] }" |       " input: ['A'] }" | ||||||
|       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'B'] }", kGPUDevice); |       " input: ['A', 'B'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); |             "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); | ||||||
| } | } | ||||||
| @ -3363,7 +3387,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { | |||||||
|       " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }" |       " attr { key: 'strides'      value { list: {i: 1, i:1, i:1, i:1} } }" | ||||||
|       " input: ['A'] }" |       " input: ['A'] }" | ||||||
|       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'B'] }", kGPUDevice); |       " input: ['A', 'B'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); |             "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); | ||||||
| } | } | ||||||
| @ -3384,7 +3409,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { | |||||||
|       " attr { key: 'N'                value { i: 2 } }" |       " attr { key: 'N'                value { i: 2 } }" | ||||||
|       " input: ['A', 'B:0', 'B:1']}" |       " input: ['A', 'B:0', 'B:1']}" | ||||||
|       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['C', 'D'] }", kGPUDevice); |       " input: ['C', 'D'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" |             "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" | ||||||
|             "B->D:1;B:1->D:2;C->E;D->E:1"); |             "B->D:1;B:1->D:2;C->E;D->E:1"); | ||||||
| @ -3406,7 +3432,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { | |||||||
|       " attr { key: 'N'                value { i: 2 } }" |       " attr { key: 'N'                value { i: 2 } }" | ||||||
|       " input: ['B:0', 'B:1', 'A']}" |       " input: ['B:0', 'B:1', 'A']}" | ||||||
|       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['C', 'D'] }", kGPUDevice); |       " input: ['C', 'D'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" |             "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" | ||||||
|             "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); |             "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); | ||||||
| @ -3426,7 +3453,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { | |||||||
|       " attr { key: 'is_training'  value { b: true } }" |       " attr { key: 'is_training'  value { b: true } }" | ||||||
|       " input: ['A', 'B', 'C', 'D', 'E'] }" |       " input: ['A', 'B', 'C', 'D', 'E'] }" | ||||||
|       "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" |       "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" | ||||||
|       " input: ['A', 'F'] }", kGPUDevice); |       " input: ['A', 'F'] }", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(Input);D(Input);E(Input);" |             "A(Input);B(Input);C(Input);D(Input);E(Input);" | ||||||
|             "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" |             "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" | ||||||
| @ -3455,7 +3483,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { | |||||||
|       "node { name: 'Y' op: 'Input'}" |       "node { name: 'Y' op: 'Input'}" | ||||||
|       "node { name: 'Z' op: 'Zeta'" |       "node { name: 'Z' op: 'Zeta'" | ||||||
|       " attr {key: 'T'                 value { type: DT_FLOAT } }" |       " attr {key: 'T'                 value { type: DT_FLOAT } }" | ||||||
|       " input: ['E', 'Y']}", kGPUDevice); |       " input: ['E', 'Y']}", | ||||||
|  |       kGPUDevice); | ||||||
|   EXPECT_EQ(DoMklLayoutOptimizationPass(), |   EXPECT_EQ(DoMklLayoutOptimizationPass(), | ||||||
|             "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" |             "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" | ||||||
|             "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" |             "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" | ||||||
|  | |||||||
| @ -33,8 +33,8 @@ limitations under the License. | |||||||
| #include "tensorflow/core/lib/hash/hash.h" | #include "tensorflow/core/lib/hash/hash.h" | ||||||
| #include "tensorflow/core/platform/logging.h" | #include "tensorflow/core/platform/logging.h" | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/graph/mkl_tfconversion_pass.h" |  | ||||||
| #include "tensorflow/core/graph/mkl_graph_util.h" | #include "tensorflow/core/graph/mkl_graph_util.h" | ||||||
|  | #include "tensorflow/core/graph/mkl_tfconversion_pass.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| 
 | 
 | ||||||
| @ -152,12 +152,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge( | |||||||
|   string data_format; |   string data_format; | ||||||
| 
 | 
 | ||||||
|   TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype)); |   TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype)); | ||||||
|   bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) == |   bool dst_dtype_found = | ||||||
|                           Status::OK(); |       GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK(); | ||||||
|   // We compare source and destination datatypes only when both are found.
 |   // We compare source and destination datatypes only when both are found.
 | ||||||
|   if (dst_dtype_found && (src_datatype != dst_datatype)) { |   if (dst_dtype_found && (src_datatype != dst_datatype)) { | ||||||
|     string err_msg = "T attribute of " + src->name() + " and " + |     string err_msg = "T attribute of " + src->name() + " and " + dst->name() + | ||||||
|                       dst->name() + " do not match. Will not insert" + |                      " do not match. Will not insert" + | ||||||
|                      " MklToTf node in such case."; |                      " MklToTf node in such case."; | ||||||
|     return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str()); |     return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str()); | ||||||
|   } |   } | ||||||
| @ -325,12 +325,12 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) { | |||||||
|     // may not be Mkl node.
 |     // may not be Mkl node.
 | ||||||
|     DataType src_datatype; |     DataType src_datatype; | ||||||
|     DataType dst_datatype; |     DataType dst_datatype; | ||||||
|     bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) == |     bool src_is_mkl_op = | ||||||
|                             Status::OK() && |         (GetNodeAttr(src->def(), "T", &src_datatype) == Status::OK() && | ||||||
|                           IsMklSupportedOp(src->type_string(), src_datatype)); |          IsMklSupportedOp(src->type_string(), src_datatype)); | ||||||
|     bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) == |     bool dst_is_mkl_op = | ||||||
|                             Status::OK() && |         (GetNodeAttr(dst->def(), "T", &dst_datatype) == Status::OK() && | ||||||
|                           IsMklSupportedOp(dst->type_string(), dst_datatype)); |          IsMklSupportedOp(dst->type_string(), dst_datatype)); | ||||||
| 
 | 
 | ||||||
|     // Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
 |     // Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
 | ||||||
|     if (src_is_mkl_op && !dst_is_mkl_op) { |     if (src_is_mkl_op && !dst_is_mkl_op) { | ||||||
|  | |||||||
| @ -40,7 +40,7 @@ REGISTER_KERNEL_BUILDER( | |||||||
| #ifdef TENSORFLOW_USE_SYCL | #ifdef TENSORFLOW_USE_SYCL | ||||||
| REGISTER_KERNEL_BUILDER( | REGISTER_KERNEL_BUILDER( | ||||||
|     Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp); |     Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp); | ||||||
| #endif // TENSORFLOW_USE_SYCL
 | #endif  // TENSORFLOW_USE_SYCL
 | ||||||
| 
 | 
 | ||||||
| // Register the HostConst Op
 | // Register the HostConst Op
 | ||||||
| // Returns a constant tensor on the host.  Useful for writing C++ tests
 | // Returns a constant tensor on the host.  Useful for writing C++ tests
 | ||||||
|  | |||||||
| @ -114,7 +114,7 @@ class PeriodicFunction { | |||||||
|   void RunLoop(int64 start) LOCKS_EXCLUDED(mutex_); |   void RunLoop(int64 start) LOCKS_EXCLUDED(mutex_); | ||||||
| 
 | 
 | ||||||
|   const std::function<void()> function_;  // Actual client function
 |   const std::function<void()> function_;  // Actual client function
 | ||||||
|   const int64 interval_micros_;    // Interval between calls.
 |   const int64 interval_micros_;           // Interval between calls.
 | ||||||
|   const Options options_; |   const Options options_; | ||||||
| 
 | 
 | ||||||
|   // Protects state below.
 |   // Protects state below.
 | ||||||
|  | |||||||
| @ -55,15 +55,14 @@ Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) { | |||||||
| // use the clock to be destroyed.
 | // use the clock to be destroyed.
 | ||||||
| std::unique_ptr<Thread> CreateFakeClockAdvancerThread( | std::unique_ptr<Thread> CreateFakeClockAdvancerThread( | ||||||
|     test_util::FakeClockEnv* env, Notification* start, Notification* stop) { |     test_util::FakeClockEnv* env, Notification* start, Notification* stop) { | ||||||
|   return std::unique_ptr<Thread>( |   return std::unique_ptr<Thread>(Env::Default()->StartThread( | ||||||
|       Env::Default()->StartThread({}, "FakeClockAdvancerThread", |       {}, "FakeClockAdvancerThread", [env, start, stop] { | ||||||
|                                   [env, start, stop] { |         start->WaitForNotification(); | ||||||
|                                     start->WaitForNotification(); |         while (!stop->HasBeenNotified()) { | ||||||
|                                     while (!stop->HasBeenNotified()) { |           env->AdvanceByMicroseconds(10); | ||||||
|                                       env->AdvanceByMicroseconds(10); |           Env::Default()->SleepForMicroseconds(10); | ||||||
|                                       Env::Default()->SleepForMicroseconds(10); |         } | ||||||
|                                     } |       })); | ||||||
|                                   })); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST(SharedBatchSchedulerTest, Basic) { | TEST(SharedBatchSchedulerTest, Basic) { | ||||||
| @ -258,7 +257,7 @@ TEST(SharedBatchSchedulerTest, ObeysTimeout) { | |||||||
| TEST(SharedBatchSchedulerTest, ObeysTimeoutWithRealClock) { | TEST(SharedBatchSchedulerTest, ObeysTimeoutWithRealClock) { | ||||||
|   Notification first_batch_processed, second_batch_processed; |   Notification first_batch_processed, second_batch_processed; | ||||||
|   auto callback = [&first_batch_processed, &second_batch_processed]( |   auto callback = [&first_batch_processed, &second_batch_processed]( | ||||||
|       std::unique_ptr<Batch<FakeTask>> batch) { |                       std::unique_ptr<Batch<FakeTask>> batch) { | ||||||
|     ASSERT_TRUE(batch->IsClosed()); |     ASSERT_TRUE(batch->IsClosed()); | ||||||
|     if (batch->size() == 1) { |     if (batch->size() == 1) { | ||||||
|       first_batch_processed.Notify(); |       first_batch_processed.Notify(); | ||||||
| @ -301,7 +300,7 @@ TEST(SharedBatchSchedulerTest, | |||||||
|   { |   { | ||||||
|     Notification first_batch_processed, second_batch_processed; |     Notification first_batch_processed, second_batch_processed; | ||||||
|     auto callback = [&first_batch_processed, &second_batch_processed]( |     auto callback = [&first_batch_processed, &second_batch_processed]( | ||||||
|         std::unique_ptr<Batch<FakeTask>> batch) { |                         std::unique_ptr<Batch<FakeTask>> batch) { | ||||||
|       ASSERT_TRUE(batch->IsClosed()); |       ASSERT_TRUE(batch->IsClosed()); | ||||||
|       if (batch->size() == 1) { |       if (batch->size() == 1) { | ||||||
|         first_batch_processed.Notify(); |         first_batch_processed.Notify(); | ||||||
| @ -349,7 +348,7 @@ TEST(SharedBatchSchedulerTest, Fairness) { | |||||||
|     auto queue_0_callback = [&queue_0_first_batch_scheduled, |     auto queue_0_callback = [&queue_0_first_batch_scheduled, | ||||||
|                              &queue_0_first_batch_proceed, |                              &queue_0_first_batch_proceed, | ||||||
|                              &queue_0_second_batch_scheduled]( |                              &queue_0_second_batch_scheduled]( | ||||||
|         std::unique_ptr<Batch<FakeTask>> batch) { |                                 std::unique_ptr<Batch<FakeTask>> batch) { | ||||||
|       if (!queue_0_first_batch_scheduled.HasBeenNotified()) { |       if (!queue_0_first_batch_scheduled.HasBeenNotified()) { | ||||||
|         queue_0_first_batch_scheduled.Notify(); |         queue_0_first_batch_scheduled.Notify(); | ||||||
|         queue_0_first_batch_proceed.WaitForNotification(); |         queue_0_first_batch_proceed.WaitForNotification(); | ||||||
| @ -467,7 +466,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) { | |||||||
| TEST(SharedBatchSchedulerTest, OneFullQueueDoesntBlockOtherQueues) { | TEST(SharedBatchSchedulerTest, OneFullQueueDoesntBlockOtherQueues) { | ||||||
|   Notification queue_0_processing, queue_0_proceed; |   Notification queue_0_processing, queue_0_proceed; | ||||||
|   auto queue_0_callback = [&queue_0_processing, &queue_0_proceed]( |   auto queue_0_callback = [&queue_0_processing, &queue_0_proceed]( | ||||||
|       std::unique_ptr<Batch<FakeTask>> batch) { |                               std::unique_ptr<Batch<FakeTask>> batch) { | ||||||
|     if (!queue_0_processing.HasBeenNotified()) { |     if (!queue_0_processing.HasBeenNotified()) { | ||||||
|       queue_0_processing.Notify(); |       queue_0_processing.Notify(); | ||||||
|       queue_0_proceed.WaitForNotification(); |       queue_0_proceed.WaitForNotification(); | ||||||
|  | |||||||
| @ -92,7 +92,6 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|    private: |    private: | ||||||
| 
 |  | ||||||
|     class Iterator : public DatasetIterator<Dataset> { |     class Iterator : public DatasetIterator<Dataset> { | ||||||
|      public: |      public: | ||||||
|       explicit Iterator(const Params& params) |       explicit Iterator(const Params& params) | ||||||
|  | |||||||
| @ -22,7 +22,6 @@ limitations under the License. | |||||||
| #include "tensorflow/core/lib/random/random.h" | #include "tensorflow/core/lib/random/random.h" | ||||||
| #include "tensorflow/core/platform/notification.h" | #include "tensorflow/core/platform/notification.h" | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| 
 | 
 | ||||||
| /* static */ | /* static */ | ||||||
| @ -185,8 +184,7 @@ Status CapturedFunction::MaybeInstantiate( | |||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Status CapturedFunction::Run(IteratorContext* ctx, | Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args, | ||||||
|                              std::vector<Tensor>&& args, |  | ||||||
|                              std::vector<Tensor>* rets) { |                              std::vector<Tensor>* rets) { | ||||||
|   FunctionLibraryRuntime::Handle handle; |   FunctionLibraryRuntime::Handle handle; | ||||||
|   TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle)); |   TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle)); | ||||||
|  | |||||||
| @ -128,8 +128,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { | |||||||
|         while (i_ < dataset()->count_) { |         while (i_ < dataset()->count_) { | ||||||
|           // Fetch and throw away Tensors.
 |           // Fetch and throw away Tensors.
 | ||||||
|           std::vector<Tensor> dummy_out_tensors; |           std::vector<Tensor> dummy_out_tensors; | ||||||
|           TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &dummy_out_tensors, |           TF_RETURN_IF_ERROR( | ||||||
|                                                   end_of_sequence)); |               input_impl_->GetNext(ctx, &dummy_out_tensors, end_of_sequence)); | ||||||
|           if (*end_of_sequence) { |           if (*end_of_sequence) { | ||||||
|             // We reached the end before the count was reached.
 |             // We reached the end before the count was reached.
 | ||||||
|             input_impl_.reset(); |             input_impl_.reset(); | ||||||
| @ -140,8 +140,8 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // Return GetNext() on the underlying iterator.
 |         // Return GetNext() on the underlying iterator.
 | ||||||
|         TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, |         TF_RETURN_IF_ERROR( | ||||||
|                                                 end_of_sequence)); |             input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); | ||||||
|         if (*end_of_sequence) { |         if (*end_of_sequence) { | ||||||
|           input_impl_.reset(); |           input_impl_.reset(); | ||||||
|         } |         } | ||||||
| @ -184,8 +184,7 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { | |||||||
|   }; |   }; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), | REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp); | ||||||
|                         SkipDatasetOp); |  | ||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -13,8 +13,8 @@ See the License for the specific language governing permissions and | |||||||
| limitations under the License. | limitations under the License. | ||||||
| ==============================================================================*/ | ==============================================================================*/ | ||||||
| 
 | 
 | ||||||
| #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" |  | ||||||
| #include "tensorflow/cc/ops/standard_ops.h" | #include "tensorflow/cc/ops/standard_ops.h" | ||||||
|  | #include "tensorflow/core/kernels/fuzzing/fuzz_session.h" | ||||||
| 
 | 
 | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace fuzzing { | namespace fuzzing { | ||||||
|  | |||||||
| @ -46,7 +46,7 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data, | |||||||
|       GetTopNFloatResults(data, labels, element_count); |       GetTopNFloatResults(data, labels, element_count); | ||||||
|   LOG(INFO) << "=== Dump ranking ==="; |   LOG(INFO) << "=== Dump ranking ==="; | ||||||
|   for (int i = 0; i < top_n; ++i) { |   for (int i = 0; i < top_n; ++i) { | ||||||
|     const std::tuple<float, int, string> &entry = queue.top(); |     const std::tuple<float, int, string>& entry = queue.top(); | ||||||
|     LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry) |     LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry) | ||||||
|               << ", " << std::get<0>(entry); |               << ", " << std::get<0>(entry); | ||||||
|     queue.pop(); |     queue.pop(); | ||||||
|  | |||||||
| @ -181,8 +181,8 @@ class GraphTransferer { | |||||||
|   void AppendNodeInputParams(const int id, const Node& node, |   void AppendNodeInputParams(const int id, const Node& node, | ||||||
|                              const std::vector<int>& extra_inputs); |                              const std::vector<int>& extra_inputs); | ||||||
| 
 | 
 | ||||||
|   void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, |   void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id, | ||||||
|                               const int id, const Node& node); |                               const Node& node); | ||||||
| 
 | 
 | ||||||
|   static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray( |   static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray( | ||||||
|       const shape_inference::ShapeHandle& shape_handle, |       const shape_inference::ShapeHandle& shape_handle, | ||||||
|  | |||||||
| @ -42,8 +42,7 @@ constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f; | |||||||
| 
 | 
 | ||||||
| class GraphTransfererTest : public ::testing::Test { | class GraphTransfererTest : public ::testing::Test { | ||||||
|  protected: |  protected: | ||||||
|   void SetUp() final { |   void SetUp() final {} | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   GraphTransferer gt_; |   GraphTransferer gt_; | ||||||
| }; | }; | ||||||
| @ -61,7 +60,7 @@ class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions { | |||||||
|       } |       } | ||||||
|     } |     } | ||||||
|     return -1; |     return -1; | ||||||
| } |   } | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   const std::vector<string> op_types_{"INPUT",   "OUTPUT",  "Conv2D", |   const std::vector<string> op_types_{"INPUT",   "OUTPUT",  "Conv2D", | ||||||
|  | |||||||
| @ -420,7 +420,7 @@ TEST(GraphTransferer, | |||||||
|       false,  // is_text_proto
 |       false,  // is_text_proto
 | ||||||
|       false,  // shape_inference_for_unknown_shape
 |       false,  // shape_inference_for_unknown_shape
 | ||||||
|       true    // dry_run_for_unknown_shape
 |       true    // dry_run_for_unknown_shape
 | ||||||
|       ); |   ); | ||||||
|   ASSERT_TRUE(status.ok()) << status; |   ASSERT_TRUE(status.ok()) << status; | ||||||
|   prof.Stop(); |   prof.Stop(); | ||||||
|   prof.DumpStatistics("LoadGraphFromProtoFile"); |   prof.DumpStatistics("LoadGraphFromProtoFile"); | ||||||
| @ -487,7 +487,7 @@ TEST(GraphTransferer, | |||||||
|       false,  // is_text_proto
 |       false,  // is_text_proto
 | ||||||
|       true,   // shape_inference_for_unknown_shape
 |       true,   // shape_inference_for_unknown_shape
 | ||||||
|       false   // dry_run_for_unknown_shape
 |       false   // dry_run_for_unknown_shape
 | ||||||
|       ); |   ); | ||||||
|   ASSERT_TRUE(status.ok()) << status; |   ASSERT_TRUE(status.ok()) << status; | ||||||
|   prof.Stop(); |   prof.Stop(); | ||||||
|   prof.DumpStatistics("LoadGraphFromProtoFile"); |   prof.DumpStatistics("LoadGraphFromProtoFile"); | ||||||
| @ -556,7 +556,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { | |||||||
|       false,  // is_text_proto
 |       false,  // is_text_proto
 | ||||||
|       false,  // shape_inference_for_unknown_shape
 |       false,  // shape_inference_for_unknown_shape
 | ||||||
|       true    // dry_run_for_unknown_shape
 |       true    // dry_run_for_unknown_shape
 | ||||||
|       ); |   ); | ||||||
|   const GraphTransferInfo& gfi0 = gt0.GetGraphTransferInfo(); |   const GraphTransferInfo& gfi0 = gt0.GetGraphTransferInfo(); | ||||||
| 
 | 
 | ||||||
|   ASSERT_TRUE(status.ok()); |   ASSERT_TRUE(status.ok()); | ||||||
| @ -576,7 +576,7 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { | |||||||
|       false,  // is_text_proto
 |       false,  // is_text_proto
 | ||||||
|       true,   // shape_inference_for_unknown_shape
 |       true,   // shape_inference_for_unknown_shape
 | ||||||
|       false   // dry_run_for_unknown_shape
 |       false   // dry_run_for_unknown_shape
 | ||||||
|       ); |   ); | ||||||
|   const GraphTransferInfo& gfi1 = gt1.GetGraphTransferInfo(); |   const GraphTransferInfo& gfi1 = gt1.GetGraphTransferInfo(); | ||||||
| 
 | 
 | ||||||
|   ASSERT_TRUE(status.ok()); |   ASSERT_TRUE(status.ok()); | ||||||
|  | |||||||
| @ -71,10 +71,10 @@ class NeonDepthwiseConv2dNativeOp : public BinaryOp<float> { | |||||||
|                                         filter.shape().DebugString())); |                                         filter.shape().DebugString())); | ||||||
| 
 | 
 | ||||||
|     const int32 in_depth = input.dim_size(3); |     const int32 in_depth = input.dim_size(3); | ||||||
|     OP_REQUIRES( |     OP_REQUIRES(context, in_depth == filter.dim_size(2), | ||||||
|         context, in_depth == filter.dim_size(2), |                 errors::InvalidArgument( | ||||||
|         errors::InvalidArgument("input and filter must have the same depth: ", |                     "input and filter must have the same depth: ", in_depth, | ||||||
|                                 in_depth, " vs ", filter.dim_size(2))); |                     " vs ", filter.dim_size(2))); | ||||||
|     const int32 batch = input.dim_size(0); |     const int32 batch = input.dim_size(0); | ||||||
|     const int32 input_rows = input.dim_size(1); |     const int32 input_rows = input.dim_size(1); | ||||||
|     const int32 input_cols = input.dim_size(2); |     const int32 input_cols = input.dim_size(2); | ||||||
|  | |||||||
| @ -131,7 +131,7 @@ inline tensorflow::string* TfCheckOpHelper(::tensorflow::Status v, | |||||||
|   while (auto _result = ::tensorflow::TfCheckOpHelper(val, #val)) \ |   while (auto _result = ::tensorflow::TfCheckOpHelper(val, #val)) \ | ||||||
|   LOG(level) << *(_result) |   LOG(level) << *(_result) | ||||||
| 
 | 
 | ||||||
| #define TF_CHECK_OK(val)  TF_DO_CHECK_OK(val, FATAL) | #define TF_CHECK_OK(val) TF_DO_CHECK_OK(val, FATAL) | ||||||
| #define TF_QCHECK_OK(val) TF_DO_CHECK_OK(val, QFATAL) | #define TF_QCHECK_OK(val) TF_DO_CHECK_OK(val, QFATAL) | ||||||
| 
 | 
 | ||||||
| // DEBUG only version of TF_CHECK_OK.  Compiler still parses 'val' even in opt
 | // DEBUG only version of TF_CHECK_OK.  Compiler still parses 'val' even in opt
 | ||||||
|  | |||||||
| @ -66,7 +66,9 @@ struct EigenEnvironment { | |||||||
|     } |     } | ||||||
|     return Task{ |     return Task{ | ||||||
|         std::unique_ptr<TaskImpl>(new TaskImpl{ |         std::unique_ptr<TaskImpl>(new TaskImpl{ | ||||||
|             std::move(f), Context(ContextKind::kThread), id, |             std::move(f), | ||||||
|  |             Context(ContextKind::kThread), | ||||||
|  |             id, | ||||||
|         }), |         }), | ||||||
|     }; |     }; | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -97,8 +97,8 @@ TEST(ThreadPool, ParallelForWithWorkerId) { | |||||||
|     } |     } | ||||||
|     pool.ParallelForWithWorkerId( |     pool.ParallelForWithWorkerId( | ||||||
|         kWorkItems, kHugeCost, |         kWorkItems, kHugeCost, | ||||||
|         [&threads_running, &work, num_threads]( |         [&threads_running, &work, num_threads](int64 begin, int64 end, | ||||||
|             int64 begin, int64 end, int64 id) { |                                                int64 id) { | ||||||
|           // Store true for the current thread, and assert that another thread
 |           // Store true for the current thread, and assert that another thread
 | ||||||
|           // is not running with the same id.
 |           // is not running with the same id.
 | ||||||
|           ASSERT_LE(0, id); |           ASSERT_LE(0, id); | ||||||
|  | |||||||
| @ -18,12 +18,12 @@ limitations under the License. | |||||||
| #include <mutex> | #include <mutex> | ||||||
| 
 | 
 | ||||||
| #include "sqlite3.h" | #include "sqlite3.h" | ||||||
|  | #include "tensorflow/core/lib/core/refcount.h" | ||||||
| #include "tensorflow/core/lib/core/status.h" | #include "tensorflow/core/lib/core/status.h" | ||||||
| #include "tensorflow/core/lib/core/stringpiece.h" | #include "tensorflow/core/lib/core/stringpiece.h" | ||||||
| #include "tensorflow/core/platform/macros.h" | #include "tensorflow/core/platform/macros.h" | ||||||
| #include "tensorflow/core/platform/thread_annotations.h" | #include "tensorflow/core/platform/thread_annotations.h" | ||||||
| #include "tensorflow/core/platform/types.h" | #include "tensorflow/core/platform/types.h" | ||||||
| #include "tensorflow/core/lib/core/refcount.h" |  | ||||||
| 
 | 
 | ||||||
| /// TensorFlow SQLite Veneer
 | /// TensorFlow SQLite Veneer
 | ||||||
| ///
 | ///
 | ||||||
| @ -121,10 +121,7 @@ class LOCKABLE Sqlite : public core::RefCounted { | |||||||
| 
 | 
 | ||||||
|   Sqlite(sqlite3* db, sqlite3_stmt* begin, sqlite3_stmt* commit, |   Sqlite(sqlite3* db, sqlite3_stmt* begin, sqlite3_stmt* commit, | ||||||
|          sqlite3_stmt* rollback) noexcept |          sqlite3_stmt* rollback) noexcept | ||||||
|       : db_(db), |       : db_(db), begin_(begin), commit_(commit), rollback_(rollback) {} | ||||||
|         begin_(begin), |  | ||||||
|         commit_(commit), |  | ||||||
|         rollback_(rollback) {} |  | ||||||
| 
 | 
 | ||||||
|   sqlite3* const db_; |   sqlite3* const db_; | ||||||
|   sqlite3_stmt* const begin_; |   sqlite3_stmt* const begin_; | ||||||
| @ -233,7 +230,8 @@ class SqliteStatement { | |||||||
|   /// freed until this statement is Reset() or finalized.
 |   /// freed until this statement is Reset() or finalized.
 | ||||||
|   void BindText(int parameter, const StringPiece& text) { |   void BindText(int parameter, const StringPiece& text) { | ||||||
|     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), |     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), | ||||||
|                                SQLITE_TRANSIENT, SQLITE_UTF8), parameter); |                                SQLITE_TRANSIENT, SQLITE_UTF8), | ||||||
|  |            parameter); | ||||||
|     size_ += text.size(); |     size_ += text.size(); | ||||||
|   } |   } | ||||||
|   void BindText(const char* parameter, const StringPiece& text) { |   void BindText(const char* parameter, const StringPiece& text) { | ||||||
| @ -241,7 +239,8 @@ class SqliteStatement { | |||||||
|   } |   } | ||||||
|   void BindTextUnsafe(int parameter, const StringPiece& text) { |   void BindTextUnsafe(int parameter, const StringPiece& text) { | ||||||
|     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), |     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), | ||||||
|                                SQLITE_STATIC, SQLITE_UTF8), parameter); |                                SQLITE_STATIC, SQLITE_UTF8), | ||||||
|  |            parameter); | ||||||
|     size_ += text.size(); |     size_ += text.size(); | ||||||
|   } |   } | ||||||
|   void BindTextUnsafe(const char* parameter, const StringPiece& text) { |   void BindTextUnsafe(const char* parameter, const StringPiece& text) { | ||||||
| @ -254,7 +253,8 @@ class SqliteStatement { | |||||||
|   /// freed until this statement is Reset() or finalized.
 |   /// freed until this statement is Reset() or finalized.
 | ||||||
|   void BindBlob(int parameter, const StringPiece& blob) { |   void BindBlob(int parameter, const StringPiece& blob) { | ||||||
|     Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), |     Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), | ||||||
|                                SQLITE_TRANSIENT), parameter); |                                SQLITE_TRANSIENT), | ||||||
|  |            parameter); | ||||||
|     size_ += blob.size(); |     size_ += blob.size(); | ||||||
|   } |   } | ||||||
|   void BindBlob(const char* parameter, const StringPiece& blob) { |   void BindBlob(const char* parameter, const StringPiece& blob) { | ||||||
| @ -262,7 +262,8 @@ class SqliteStatement { | |||||||
|   } |   } | ||||||
|   void BindBlobUnsafe(int parameter, const StringPiece& blob) { |   void BindBlobUnsafe(int parameter, const StringPiece& blob) { | ||||||
|     Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), |     Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), | ||||||
|                                SQLITE_STATIC), parameter); |                                SQLITE_STATIC), | ||||||
|  |            parameter); | ||||||
|     size_ += blob.size(); |     size_ += blob.size(); | ||||||
|   } |   } | ||||||
|   void BindBlobUnsafe(const char* parameter, const StringPiece& text) { |   void BindBlobUnsafe(const char* parameter, const StringPiece& text) { | ||||||
| @ -320,9 +321,7 @@ class SqliteStatement { | |||||||
| 
 | 
 | ||||||
|   /// \brief Move constructor, after which <other> is reset to empty.
 |   /// \brief Move constructor, after which <other> is reset to empty.
 | ||||||
|   SqliteStatement(SqliteStatement&& other) noexcept |   SqliteStatement(SqliteStatement&& other) noexcept | ||||||
|       : db_(other.db_), |       : db_(other.db_), stmt_(other.stmt_), bind_error_(other.bind_error_) { | ||||||
|         stmt_(other.stmt_), |  | ||||||
|         bind_error_(other.bind_error_) { |  | ||||||
|     other.db_ = nullptr; |     other.db_ = nullptr; | ||||||
|     other.stmt_ = nullptr; |     other.stmt_ = nullptr; | ||||||
|     other.bind_error_ = SQLITE_OK; |     other.bind_error_ = SQLITE_OK; | ||||||
|  | |||||||
| @ -33,9 +33,7 @@ class SqliteTest : public ::testing::Test { | |||||||
|     db_->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie(); |     db_->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   void TearDown() override { |   void TearDown() override { db_->Unref(); } | ||||||
|     db_->Unref(); |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   Sqlite* db_; |   Sqlite* db_; | ||||||
|   bool is_done_; |   bool is_done_; | ||||||
| @ -213,7 +211,7 @@ TEST_F(SqliteTest, BindFailed) { | |||||||
|   Status s = stmt.StepOnce(); |   Status s = stmt.StepOnce(); | ||||||
|   EXPECT_NE(string::npos, |   EXPECT_NE(string::npos, | ||||||
|             s.error_message().find("INSERT INTO T (a) VALUES (123)")) |             s.error_message().find("INSERT INTO T (a) VALUES (123)")) | ||||||
|             << s.error_message(); |       << s.error_message(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(SqliteTest, SnappyExtension) { | TEST_F(SqliteTest, SnappyExtension) { | ||||||
| @ -226,7 +224,7 @@ TEST_F(SqliteTest, SnappyBinaryCompatibility) { | |||||||
|   EXPECT_EQ( |   EXPECT_EQ( | ||||||
|       "today is the end of the republic", |       "today is the end of the republic", | ||||||
|       db_->PrepareOrDie("SELECT UNSNAP(X'03207C746F6461792069732074686520656E64" |       db_->PrepareOrDie("SELECT UNSNAP(X'03207C746F6461792069732074686520656E64" | ||||||
|                             "206F66207468652072657075626C6963')") |                         "206F66207468652072657075626C6963')") | ||||||
|           .StepOnceOrDie() |           .StepOnceOrDie() | ||||||
|           .ColumnString(0)); |           .ColumnString(0)); | ||||||
| } | } | ||||||
|  | |||||||
| @ -55,22 +55,21 @@ namespace gtl { | |||||||
| template <typename F> | template <typename F> | ||||||
| class Cleanup { | class Cleanup { | ||||||
|  public: |  public: | ||||||
|   Cleanup() |   Cleanup() : released_(true), f_() {} | ||||||
|       : released_(true), f_() {} |  | ||||||
| 
 | 
 | ||||||
|   template <typename G> |   template <typename G> | ||||||
|   explicit Cleanup(G&& f)  // NOLINT
 |   explicit Cleanup(G&& f)          // NOLINT
 | ||||||
|       : f_(std::forward<G>(f)) {}  // NOLINT(build/c++11)
 |       : f_(std::forward<G>(f)) {}  // NOLINT(build/c++11)
 | ||||||
| 
 | 
 | ||||||
|   Cleanup(Cleanup&& src)  // NOLINT
 |   Cleanup(Cleanup&& src)  // NOLINT
 | ||||||
|       : released_(src.is_released()), f_(src.release()) { } |       : released_(src.is_released()), f_(src.release()) {} | ||||||
| 
 | 
 | ||||||
|   // Implicitly move-constructible from any compatible Cleanup<G>.
 |   // Implicitly move-constructible from any compatible Cleanup<G>.
 | ||||||
|   // The source will be released as if src.release() were called.
 |   // The source will be released as if src.release() were called.
 | ||||||
|   // A moved-from Cleanup can be safely destroyed or reassigned.
 |   // A moved-from Cleanup can be safely destroyed or reassigned.
 | ||||||
|   template <typename G> |   template <typename G> | ||||||
|   Cleanup(Cleanup<G>&& src)  // NOLINT
 |   Cleanup(Cleanup<G>&& src)  // NOLINT
 | ||||||
|       : released_(src.is_released()), f_(src.release()) { } |       : released_(src.is_released()), f_(src.release()) {} | ||||||
| 
 | 
 | ||||||
|   // Assignment to a Cleanup object behaves like destroying it
 |   // Assignment to a Cleanup object behaves like destroying it
 | ||||||
|   // and making a new one in its place, analogous to unique_ptr
 |   // and making a new one in its place, analogous to unique_ptr
 | ||||||
| @ -102,8 +101,8 @@ class Cleanup { | |||||||
|   F f_; |   F f_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| template <int&... ExplicitParameterBarrier, | template <int&... ExplicitParameterBarrier, typename F, | ||||||
|           typename F, typename DecayF = typename std::decay<F>::type> |           typename DecayF = typename std::decay<F>::type> | ||||||
| TF_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) { | TF_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) { | ||||||
|   return Cleanup<DecayF>(std::forward<F>(f)); |   return Cleanup<DecayF>(std::forward<F>(f)); | ||||||
| } | } | ||||||
|  | |||||||
| @ -65,15 +65,14 @@ TEST(CleanupTest, Release) { | |||||||
| TEST(FinallyTest, TypeErasedWithoutFactory) { | TEST(FinallyTest, TypeErasedWithoutFactory) { | ||||||
|   string s = "active"; |   string s = "active"; | ||||||
|   { |   { | ||||||
|     AnyCleanup s_cleaner([&s]{ s.append(" clean"); }); |     AnyCleanup s_cleaner([&s] { s.append(" clean"); }); | ||||||
|     EXPECT_EQ("active", s); |     EXPECT_EQ("active", s); | ||||||
|   } |   } | ||||||
|   EXPECT_EQ("active clean", s); |   EXPECT_EQ("active clean", s); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| struct Appender { | struct Appender { | ||||||
|   Appender(string* s, const string& msg) |   Appender(string* s, const string& msg) : s_(s), msg_(msg) {} | ||||||
|       : s_(s), msg_(msg) {} |  | ||||||
|   void operator()() const { s_->append(msg_); } |   void operator()() const { s_->append(msg_); } | ||||||
|   string* s_; |   string* s_; | ||||||
|   string msg_; |   string msg_; | ||||||
| @ -163,7 +162,12 @@ class CleanupReferenceTest : public ::testing::Test { | |||||||
|     int* i; |     int* i; | ||||||
|     F(int* cp, int* i) : cp(cp), i(i) {} |     F(int* cp, int* i) : cp(cp), i(i) {} | ||||||
|     F(const F& o) : cp(o.cp), i(o.i) { ++*cp; } |     F(const F& o) : cp(o.cp), i(o.i) { ++*cp; } | ||||||
|     F& operator=(const F& o) { cp = o.cp; i = o.i; ++*cp; return *this; } |     F& operator=(const F& o) { | ||||||
|  |       cp = o.cp; | ||||||
|  |       i = o.i; | ||||||
|  |       ++*cp; | ||||||
|  |       return *this; | ||||||
|  |     } | ||||||
|     F(F&&) = default; |     F(F&&) = default; | ||||||
|     F& operator=(F&&) = default; |     F& operator=(F&&) = default; | ||||||
|     void operator()() const { ++*i; } |     void operator()() const { ++*i; } | ||||||
| @ -279,7 +283,7 @@ BENCHMARK(BM_AnyCleanup); | |||||||
| 
 | 
 | ||||||
| void BM_AnyCleanupNoFactory(int iters) { | void BM_AnyCleanupNoFactory(int iters) { | ||||||
|   while (iters--) { |   while (iters--) { | ||||||
|     AnyCleanup fin([]{Incr();}); |     AnyCleanup fin([] { Incr(); }); | ||||||
|   } |   } | ||||||
| } | } | ||||||
| BENCHMARK(BM_AnyCleanupNoFactory); | BENCHMARK(BM_AnyCleanupNoFactory); | ||||||
|  | |||||||
| @ -31,12 +31,12 @@ limitations under the License. | |||||||
| #ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ | #ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ | ||||||
| #define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ | #define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_ | ||||||
| 
 | 
 | ||||||
| #include <cstddef> |  | ||||||
| #include <stddef.h> | #include <stddef.h> | ||||||
| #include <stdlib.h> | #include <stdlib.h> | ||||||
| #include <string.h> | #include <string.h> | ||||||
| #include <sys/types.h> | #include <sys/types.h> | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
|  | #include <cstddef> | ||||||
| #include <iterator> | #include <iterator> | ||||||
| #include <memory> | #include <memory> | ||||||
| #include <type_traits> | #include <type_traits> | ||||||
| @ -407,7 +407,7 @@ class InlinedVector { | |||||||
|   }; |   }; | ||||||
|   // 2) Construct a T with args at not-yet-initialized memory pointed by dst.
 |   // 2) Construct a T with args at not-yet-initialized memory pointed by dst.
 | ||||||
|   struct Construct { |   struct Construct { | ||||||
|     template<class... Args> |     template <class... Args> | ||||||
|     void operator()(T* dst, Args&&... args) const { |     void operator()(T* dst, Args&&... args) const { | ||||||
|       new (dst) T(std::forward<Args>(args)...); |       new (dst) T(std::forward<Args>(args)...); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -255,13 +255,13 @@ class IntType { | |||||||
|     value_ op arg_value;                             \ |     value_ op arg_value;                             \ | ||||||
|     return *this;                                    \ |     return *this;                                    \ | ||||||
|   } |   } | ||||||
|   INT_TYPE_ASSIGNMENT_OP(+= ); |   INT_TYPE_ASSIGNMENT_OP(+=); | ||||||
|   INT_TYPE_ASSIGNMENT_OP(-= ); |   INT_TYPE_ASSIGNMENT_OP(-=); | ||||||
|   INT_TYPE_ASSIGNMENT_OP(*= ); |   INT_TYPE_ASSIGNMENT_OP(*=); | ||||||
|   INT_TYPE_ASSIGNMENT_OP(/= ); |   INT_TYPE_ASSIGNMENT_OP(/=); | ||||||
|   INT_TYPE_ASSIGNMENT_OP(<<= );  // NOLINT
 |   INT_TYPE_ASSIGNMENT_OP(<<=);  // NOLINT
 | ||||||
|   INT_TYPE_ASSIGNMENT_OP(>>= );  // NOLINT
 |   INT_TYPE_ASSIGNMENT_OP(>>=);  // NOLINT
 | ||||||
|   INT_TYPE_ASSIGNMENT_OP(%= ); |   INT_TYPE_ASSIGNMENT_OP(%=); | ||||||
| #undef INT_TYPE_ASSIGNMENT_OP | #undef INT_TYPE_ASSIGNMENT_OP | ||||||
| 
 | 
 | ||||||
|   ThisType& operator=(ValueType arg_value) { |   ThisType& operator=(ValueType arg_value) { | ||||||
| @ -314,10 +314,10 @@ std::ostream& operator<<(std::ostream& os,  // NOLINT | |||||||
| INT_TYPE_ARITHMETIC_OP(+); | INT_TYPE_ARITHMETIC_OP(+); | ||||||
| INT_TYPE_ARITHMETIC_OP(-); | INT_TYPE_ARITHMETIC_OP(-); | ||||||
| INT_TYPE_ARITHMETIC_OP(*); | INT_TYPE_ARITHMETIC_OP(*); | ||||||
| INT_TYPE_ARITHMETIC_OP(/ ); | INT_TYPE_ARITHMETIC_OP(/); | ||||||
| INT_TYPE_ARITHMETIC_OP(<< );  // NOLINT
 | INT_TYPE_ARITHMETIC_OP(<<);  // NOLINT
 | ||||||
| INT_TYPE_ARITHMETIC_OP(>> );  // NOLINT
 | INT_TYPE_ARITHMETIC_OP(>>);  // NOLINT
 | ||||||
| INT_TYPE_ARITHMETIC_OP(% ); | INT_TYPE_ARITHMETIC_OP(%); | ||||||
| #undef INT_TYPE_ARITHMETIC_OP | #undef INT_TYPE_ARITHMETIC_OP | ||||||
| 
 | 
 | ||||||
| // -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------
 | // -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------
 | ||||||
| @ -345,12 +345,12 @@ INT_TYPE_ARITHMETIC_OP(% ); | |||||||
|       IntType<IntTypeName, ValueType> id) {                      \ |       IntType<IntTypeName, ValueType> id) {                      \ | ||||||
|     return val op id.value();                                    \ |     return val op id.value();                                    \ | ||||||
|   } |   } | ||||||
| INT_TYPE_COMPARISON_OP(== );  // NOLINT
 | INT_TYPE_COMPARISON_OP(==);  // NOLINT
 | ||||||
| INT_TYPE_COMPARISON_OP(!= );  // NOLINT
 | INT_TYPE_COMPARISON_OP(!=);  // NOLINT
 | ||||||
| INT_TYPE_COMPARISON_OP(< );   // NOLINT
 | INT_TYPE_COMPARISON_OP(<);   // NOLINT
 | ||||||
| INT_TYPE_COMPARISON_OP(<= );  // NOLINT
 | INT_TYPE_COMPARISON_OP(<=);  // NOLINT
 | ||||||
| INT_TYPE_COMPARISON_OP(> );   // NOLINT
 | INT_TYPE_COMPARISON_OP(>);   // NOLINT
 | ||||||
| INT_TYPE_COMPARISON_OP(>= );  // NOLINT
 | INT_TYPE_COMPARISON_OP(>=);  // NOLINT
 | ||||||
| #undef INT_TYPE_COMPARISON_OP | #undef INT_TYPE_COMPARISON_OP | ||||||
| 
 | 
 | ||||||
| }  // namespace gtl
 | }  // namespace gtl
 | ||||||
|  | |||||||
| @ -42,7 +42,8 @@ class IntTypeTest : public ::testing::Test { | |||||||
| 
 | 
 | ||||||
| // All tests below will be executed on all supported IntTypes.
 | // All tests below will be executed on all supported IntTypes.
 | ||||||
| typedef ::testing::Types<Int8_IT, UInt8_IT, Int16_IT, UInt16_IT, Int32_IT, | typedef ::testing::Types<Int8_IT, UInt8_IT, Int16_IT, UInt16_IT, Int32_IT, | ||||||
|                          Int64_IT, UInt64_IT, Long_IT> SupportedIntTypes; |                          Int64_IT, UInt64_IT, Long_IT> | ||||||
|  |     SupportedIntTypes; | ||||||
| 
 | 
 | ||||||
| TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes); | TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes); | ||||||
| 
 | 
 | ||||||
| @ -232,7 +233,8 @@ TYPED_TEST(IntTypeTest, TestOperators) { | |||||||
| 
 | 
 | ||||||
| TYPED_TEST(IntTypeTest, TestHashFunctor) { | TYPED_TEST(IntTypeTest, TestHashFunctor) { | ||||||
|   std::unordered_map<typename TestFixture::T, char, |   std::unordered_map<typename TestFixture::T, char, | ||||||
|                      typename TestFixture::T::Hasher> map; |                      typename TestFixture::T::Hasher> | ||||||
|  |       map; | ||||||
|   typename TestFixture::T a(0); |   typename TestFixture::T a(0); | ||||||
|   map[a] = 'c'; |   map[a] = 'c'; | ||||||
|   EXPECT_EQ('c', map[a]); |   EXPECT_EQ('c', map[a]); | ||||||
|  | |||||||
| @ -593,12 +593,12 @@ class optional : private internal_optional::optional_data<T>, | |||||||
|     assert(this->engaged_); |     assert(this->engaged_); | ||||||
|     return this->pointer(); |     return this->pointer(); | ||||||
|   } |   } | ||||||
|   constexpr const T& operator*() const & { return reference(); } |   constexpr const T& operator*() const& { return reference(); } | ||||||
|   T& operator*() & { |   T& operator*() & { | ||||||
|     assert(this->engaged_); |     assert(this->engaged_); | ||||||
|     return reference(); |     return reference(); | ||||||
|   } |   } | ||||||
|   constexpr const T&& operator*() const && { return std::move(reference()); } |   constexpr const T&& operator*() const&& { return std::move(reference()); } | ||||||
|   T&& operator*() && { |   T&& operator*() && { | ||||||
|     assert(this->engaged_); |     assert(this->engaged_); | ||||||
|     return std::move(reference()); |     return std::move(reference()); | ||||||
| @ -621,7 +621,7 @@ class optional : private internal_optional::optional_data<T>, | |||||||
|   // Use `opt.value()` to get a reference to underlying value.  The constness
 |   // Use `opt.value()` to get a reference to underlying value.  The constness
 | ||||||
|   // and lvalue/rvalue-ness of `opt` is preserved to the view of the T
 |   // and lvalue/rvalue-ness of `opt` is preserved to the view of the T
 | ||||||
|   // subobject.
 |   // subobject.
 | ||||||
|   const T& value() const & { |   const T& value() const& { | ||||||
|     CHECK(*this) << "Bad optional access"; |     CHECK(*this) << "Bad optional access"; | ||||||
|     return reference(); |     return reference(); | ||||||
|   } |   } | ||||||
| @ -633,7 +633,7 @@ class optional : private internal_optional::optional_data<T>, | |||||||
|     CHECK(*this) << "Bad optional access"; |     CHECK(*this) << "Bad optional access"; | ||||||
|     return std::move(reference()); |     return std::move(reference()); | ||||||
|   } |   } | ||||||
|   const T&& value() const && {  // NOLINT(build/c++11)
 |   const T&& value() const&& {  // NOLINT(build/c++11)
 | ||||||
|     CHECK(*this) << "Bad optional access"; |     CHECK(*this) << "Bad optional access"; | ||||||
|     return std::move(reference()); |     return std::move(reference()); | ||||||
|   } |   } | ||||||
| @ -641,7 +641,7 @@ class optional : private internal_optional::optional_data<T>, | |||||||
|   // Use `opt.value_or(val)` to get either the value of T or the given default
 |   // Use `opt.value_or(val)` to get either the value of T or the given default
 | ||||||
|   // `val` in the empty case.
 |   // `val` in the empty case.
 | ||||||
|   template <class U> |   template <class U> | ||||||
|   constexpr T value_or(U&& v) const & { |   constexpr T value_or(U&& v) const& { | ||||||
|     return static_cast<bool>(*this) ? **this |     return static_cast<bool>(*this) ? **this | ||||||
|                                     : static_cast<T>(std::forward<U>(v)); |                                     : static_cast<T>(std::forward<U>(v)); | ||||||
|   } |   } | ||||||
| @ -656,8 +656,8 @@ class optional : private internal_optional::optional_data<T>, | |||||||
|   constexpr const T& reference() const { return *this->pointer(); } |   constexpr const T& reference() const { return *this->pointer(); } | ||||||
|   T& reference() { return *(this->pointer()); } |   T& reference() { return *(this->pointer()); } | ||||||
| 
 | 
 | ||||||
|   // T constraint checks.  You can't have an optional of nullopt_t, in_place_t or
 |   // T constraint checks.  You can't have an optional of nullopt_t, in_place_t
 | ||||||
|   // a reference.
 |   // or a reference.
 | ||||||
|   static_assert( |   static_assert( | ||||||
|       !std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value, |       !std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value, | ||||||
|       "optional<nullopt_t> is not allowed."); |       "optional<nullopt_t> is not allowed."); | ||||||
|  | |||||||
| @ -24,17 +24,29 @@ limitations under the License. | |||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| using tensorflow::gtl::optional; |  | ||||||
| using tensorflow::gtl::nullopt; |  | ||||||
| using tensorflow::gtl::nullopt_t; |  | ||||||
| using tensorflow::gtl::in_place; | using tensorflow::gtl::in_place; | ||||||
| using tensorflow::gtl::in_place_t; | using tensorflow::gtl::in_place_t; | ||||||
| using tensorflow::gtl::make_optional; | using tensorflow::gtl::make_optional; | ||||||
|  | using tensorflow::gtl::nullopt; | ||||||
|  | using tensorflow::gtl::nullopt_t; | ||||||
|  | using tensorflow::gtl::optional; | ||||||
| 
 | 
 | ||||||
| template <typename T> string TypeQuals(T&) { return "&"; } | template <typename T> | ||||||
| template <typename T> string TypeQuals(T&&) { return "&&"; } | string TypeQuals(T&) { | ||||||
| template <typename T> string TypeQuals(const T&) { return "c&"; } |   return "&"; | ||||||
| template <typename T> string TypeQuals(const T&&) { return "c&&"; } | } | ||||||
|  | template <typename T> | ||||||
|  | string TypeQuals(T&&) { | ||||||
|  |   return "&&"; | ||||||
|  | } | ||||||
|  | template <typename T> | ||||||
|  | string TypeQuals(const T&) { | ||||||
|  |   return "c&"; | ||||||
|  | } | ||||||
|  | template <typename T> | ||||||
|  | string TypeQuals(const T&&) { | ||||||
|  |   return "c&&"; | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| struct StructorListener { | struct StructorListener { | ||||||
|   int construct0 = 0; |   int construct0 = 0; | ||||||
|  | |||||||
| @ -28,10 +28,10 @@ limitations under the License. | |||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
|  | using tensorflow::string; | ||||||
| using tensorflow::gtl::TopN; | using tensorflow::gtl::TopN; | ||||||
| using tensorflow::random::PhiloxRandom; | using tensorflow::random::PhiloxRandom; | ||||||
| using tensorflow::random::SimplePhilox; | using tensorflow::random::SimplePhilox; | ||||||
| using tensorflow::string; |  | ||||||
| 
 | 
 | ||||||
| // Move the contents from an owned raw pointer, returning by value.
 | // Move the contents from an owned raw pointer, returning by value.
 | ||||||
| // Objects are easier to manage by value.
 | // Objects are easier to manage by value.
 | ||||||
|  | |||||||
| @ -22,6 +22,6 @@ namespace compression { | |||||||
| const char kNone[] = ""; | const char kNone[] = ""; | ||||||
| const char kGzip[] = "GZIP"; | const char kGzip[] = "GZIP"; | ||||||
| 
 | 
 | ||||||
| } | }  // namespace compression
 | ||||||
| } | }  // namespace io
 | ||||||
| } | }  // namespace tensorflow
 | ||||||
|  | |||||||
| @ -23,8 +23,8 @@ namespace compression { | |||||||
| extern const char kNone[]; | extern const char kNone[]; | ||||||
| extern const char kGzip[]; | extern const char kGzip[]; | ||||||
| 
 | 
 | ||||||
| } | }  // namespace compression
 | ||||||
| } | }  // namespace io
 | ||||||
| } | }  // namespace tensorflow
 | ||||||
| 
 | 
 | ||||||
| #endif  // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_
 | #endif  // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_
 | ||||||
|  | |||||||
| @ -207,7 +207,7 @@ Status RecordReader::SkipNBytes(uint64 offset) { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   return Status::OK(); |   return Status::OK(); | ||||||
| } | }  // namespace io
 | ||||||
| 
 | 
 | ||||||
| SequentialRecordReader::SequentialRecordReader( | SequentialRecordReader::SequentialRecordReader( | ||||||
|     RandomAccessFile* file, const RecordReaderOptions& options) |     RandomAccessFile* file, const RecordReaderOptions& options) | ||||||
|  | |||||||
| @ -218,8 +218,8 @@ TEST_F(RecordioTest, RandomRead) { | |||||||
| 
 | 
 | ||||||
| // Tests of all the error paths in log_reader.cc follow:
 | // Tests of all the error paths in log_reader.cc follow:
 | ||||||
| static void AssertHasSubstr(StringPiece s, StringPiece expected) { | static void AssertHasSubstr(StringPiece s, StringPiece expected) { | ||||||
|   EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain " |   EXPECT_TRUE(StringPiece(s).contains(expected)) | ||||||
|                                                  << expected; |       << s << " does not contain " << expected; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(RecordioTest, ReadError) { | TEST_F(RecordioTest, ReadError) { | ||||||
|  | |||||||
| @ -197,8 +197,8 @@ bool CommonInitDecode(StringPiece png_string, int desired_channels, | |||||||
|                       int desired_channel_bits, DecodeContext* context) { |                       int desired_channel_bits, DecodeContext* context) { | ||||||
|   CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) |   CHECK(desired_channel_bits == 8 || desired_channel_bits == 16) | ||||||
|       << "desired_channel_bits = " << desired_channel_bits; |       << "desired_channel_bits = " << desired_channel_bits; | ||||||
|   CHECK(0 <= desired_channels && desired_channels <= 4) << "desired_channels = " |   CHECK(0 <= desired_channels && desired_channels <= 4) | ||||||
|                                                         << desired_channels; |       << "desired_channels = " << desired_channels; | ||||||
|   context->error_condition = false; |   context->error_condition = false; | ||||||
|   context->channels = desired_channels; |   context->channels = desired_channels; | ||||||
|   context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context, |   context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context, | ||||||
|  | |||||||
| @ -35,8 +35,8 @@ void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p, | |||||||
|                  int64 size) { |                  int64 size) { | ||||||
|   const int granularity = Distribution::kResultElementCount; |   const int granularity = Distribution::kResultElementCount; | ||||||
| 
 | 
 | ||||||
|   CHECK(size % granularity == 0) << " size: " << size |   CHECK(size % granularity == 0) | ||||||
|                                  << " granularity: " << granularity; |       << " size: " << size << " granularity: " << granularity; | ||||||
| 
 | 
 | ||||||
|   Distribution dist; |   Distribution dist; | ||||||
|   for (int i = 0; i < size; i += granularity) { |   for (int i = 0; i < size; i += granularity) { | ||||||
|  | |||||||
| @ -17,8 +17,8 @@ limitations under the License. | |||||||
| #define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ | #define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ | ||||||
| 
 | 
 | ||||||
| #define _USE_MATH_DEFINES | #define _USE_MATH_DEFINES | ||||||
| #include <cmath> |  | ||||||
| #include <math.h> | #include <math.h> | ||||||
|  | #include <cmath> | ||||||
| #undef _USE_MATH_DEFINES | #undef _USE_MATH_DEFINES | ||||||
| 
 | 
 | ||||||
| #include <string.h> | #include <string.h> | ||||||
| @ -27,7 +27,6 @@ limitations under the License. | |||||||
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" | ||||||
| #include "tensorflow/core/lib/random/philox_random.h" | #include "tensorflow/core/lib/random/philox_random.h" | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| namespace random { | namespace random { | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -45,8 +45,8 @@ void FillRandomsWithSingles(PhiloxRandom gen, | |||||||
|                             int64 size) { |                             int64 size) { | ||||||
|   int granularity = Distribution::kResultElementCount; |   int granularity = Distribution::kResultElementCount; | ||||||
| 
 | 
 | ||||||
|   CHECK(size % granularity == 0) << " size: " << size |   CHECK(size % granularity == 0) | ||||||
|                                  << " granularity: " << granularity; |       << " size: " << size << " granularity: " << granularity; | ||||||
| 
 | 
 | ||||||
|   SingleSampleAdapter<PhiloxRandom> single_samples(&gen); |   SingleSampleAdapter<PhiloxRandom> single_samples(&gen); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -472,7 +472,8 @@ void OrderedCode::WriteSignedNumIncreasing(string* dest, int64 val) { | |||||||
|   // buf = val in network byte order, sign extended to 10 bytes
 |   // buf = val in network byte order, sign extended to 10 bytes
 | ||||||
|   const char sign_byte = val < 0 ? '\xff' : '\0'; |   const char sign_byte = val < 0 ? '\xff' : '\0'; | ||||||
|   char buf[10] = { |   char buf[10] = { | ||||||
|       sign_byte, sign_byte, |       sign_byte, | ||||||
|  |       sign_byte, | ||||||
|   }; |   }; | ||||||
|   StoreBigEndian64(buf + 2, val); |   StoreBigEndian64(buf + 2, val); | ||||||
|   static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch"); |   static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch"); | ||||||
|  | |||||||
| @ -126,7 +126,7 @@ class AlphaNum { | |||||||
|       : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {} |       : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {} | ||||||
| 
 | 
 | ||||||
|   AlphaNum(const Eigen::half &f);  // NOLINT(runtime/explicit)
 |   AlphaNum(const Eigen::half &f);  // NOLINT(runtime/explicit)
 | ||||||
|   AlphaNum(Hex hex);  // NOLINT(runtime/explicit)
 |   AlphaNum(Hex hex);               // NOLINT(runtime/explicit)
 | ||||||
| 
 | 
 | ||||||
|   AlphaNum(const char *c_str) : piece_(c_str) {}   // NOLINT(runtime/explicit)
 |   AlphaNum(const char *c_str) : piece_(c_str) {}   // NOLINT(runtime/explicit)
 | ||||||
|   AlphaNum(const StringPiece &pc) : piece_(pc) {}  // NOLINT(runtime/explicit)
 |   AlphaNum(const StringPiece &pc) : piece_(pc) {}  // NOLINT(runtime/explicit)
 | ||||||
|  | |||||||
| @ -25,8 +25,9 @@ namespace tensorflow { | |||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| TEST(BackwardsCompatibilityTest, IsCompatible) { | TEST(BackwardsCompatibilityTest, IsCompatible) { | ||||||
|   OpCompatibilityLib compatibility( |   OpCompatibilityLib compatibility("tensorflow/core/ops", | ||||||
|       "tensorflow/core/ops", strings::StrCat("v", TF_MAJOR_VERSION), nullptr); |                                    strings::StrCat("v", TF_MAJOR_VERSION), | ||||||
|  |                                    nullptr); | ||||||
| 
 | 
 | ||||||
|   Env* env = Env::Default(); |   Env* env = Env::Default(); | ||||||
|   int changed_ops = 0; |   int changed_ops = 0; | ||||||
|  | |||||||
| @ -18,9 +18,9 @@ limitations under the License. | |||||||
| #include <arpa/inet.h> | #include <arpa/inet.h> | ||||||
| #include <netdb.h> | #include <netdb.h> | ||||||
| #else | #else | ||||||
|  | #include <Windows.h> | ||||||
| #include <winsock2.h> | #include <winsock2.h> | ||||||
| #include <ws2tcpip.h> | #include <ws2tcpip.h> | ||||||
| #include <Windows.h> |  | ||||||
| #endif | #endif | ||||||
| #include <sys/types.h> | #include <sys/types.h> | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -38,8 +38,7 @@ class FakeHttpRequest : public CurlHttpRequest { | |||||||
|  public: |  public: | ||||||
|   /// Return the response for the given request.
 |   /// Return the response for the given request.
 | ||||||
|   FakeHttpRequest(const string& request, const string& response) |   FakeHttpRequest(const string& request, const string& response) | ||||||
|       : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) { |       : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) {} | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   /// Return the response with headers for the given request.
 |   /// Return the response with headers for the given request.
 | ||||||
|   FakeHttpRequest(const string& request, const string& response, |   FakeHttpRequest(const string& request, const string& response, | ||||||
|  | |||||||
| @ -160,12 +160,12 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { | |||||||
|   ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key)); |   ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key)); | ||||||
|   ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(), |   ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(), | ||||||
|                                       header_dot_claim.size())); |                                       header_dot_claim.size())); | ||||||
|   ASSERT_EQ( |   ASSERT_EQ(1, | ||||||
|       1, |             EVP_DigestVerifyFinal( | ||||||
|       EVP_DigestVerifyFinal( |                 md_ctx, | ||||||
|           md_ctx, const_cast<unsigned char*>( |                 const_cast<unsigned char*>( | ||||||
|                       reinterpret_cast<const unsigned char*>(signature.data())), |                     reinterpret_cast<const unsigned char*>(signature.data())), | ||||||
|           signature.size())); |                 signature.size())); | ||||||
|   EVP_MD_CTX_cleanup(md_ctx); |   EVP_MD_CTX_cleanup(md_ctx); | ||||||
| 
 | 
 | ||||||
|   // Free all the crypto-related resources.
 |   // Free all the crypto-related resources.
 | ||||||
|  | |||||||
| @ -25,7 +25,6 @@ namespace tensorflow { | |||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| class RetryingRandomAccessFile : public RandomAccessFile { | class RetryingRandomAccessFile : public RandomAccessFile { | ||||||
|  public: |  public: | ||||||
|   RetryingRandomAccessFile(std::unique_ptr<RandomAccessFile> base_file, |   RetryingRandomAccessFile(std::unique_ptr<RandomAccessFile> base_file, | ||||||
|  | |||||||
| @ -27,8 +27,7 @@ TEST(CudaLibdevicePathTest, LibdevicePath) { | |||||||
|   VLOG(2) << "Libdevice root = " << LibdeviceRoot(); |   VLOG(2) << "Libdevice root = " << LibdeviceRoot(); | ||||||
|   std::vector<string> libdevice_files; |   std::vector<string> libdevice_files; | ||||||
|   TF_EXPECT_OK(Env::Default()->GetMatchingPaths( |   TF_EXPECT_OK(Env::Default()->GetMatchingPaths( | ||||||
|       io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"), |       io::JoinPath(LibdeviceRoot(), "libdevice.*.bc"), &libdevice_files)); | ||||||
|       &libdevice_files)); |  | ||||||
|   EXPECT_LT(0, libdevice_files.size()); |   EXPECT_LT(0, libdevice_files.size()); | ||||||
| } | } | ||||||
| #endif | #endif | ||||||
|  | |||||||
| @ -579,8 +579,10 @@ Status DeviceTracerImpl::Collect(StepStatsCollector *collector) { | |||||||
|   // TODO(pbar) Handle device IDs and prefix properly.
 |   // TODO(pbar) Handle device IDs and prefix properly.
 | ||||||
|   const string prefix = ""; |   const string prefix = ""; | ||||||
|   const int id = 0; |   const int id = 0; | ||||||
|   const string stream_device = strings::StrCat(prefix, "/device:GPU:", id, "/stream:"); |   const string stream_device = | ||||||
|   const string memcpy_device = strings::StrCat(prefix, "/device:GPU:", id, "/memcpy"); |       strings::StrCat(prefix, "/device:GPU:", id, "/stream:"); | ||||||
|  |   const string memcpy_device = | ||||||
|  |       strings::StrCat(prefix, "/device:GPU:", id, "/memcpy"); | ||||||
| 
 | 
 | ||||||
|   mutex_lock l2(trace_mu_); |   mutex_lock l2(trace_mu_); | ||||||
|   for (const auto &rec : kernel_records_) { |   for (const auto &rec : kernel_records_) { | ||||||
|  | |||||||
| @ -83,15 +83,14 @@ void LogMessage::GenerateLogMessage() { | |||||||
|   const size_t time_buffer_size = 30; |   const size_t time_buffer_size = 30; | ||||||
|   char time_buffer[time_buffer_size]; |   char time_buffer[time_buffer_size]; | ||||||
|   strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", |   strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", | ||||||
| 	   localtime(&now_seconds)); |            localtime(&now_seconds)); | ||||||
| 
 | 
 | ||||||
|   // TODO(jeff,sanjay): Replace this with something that logs through the env.
 |   // TODO(jeff,sanjay): Replace this with something that logs through the env.
 | ||||||
|   fprintf(stderr, "%s.%06d: %c %s:%d] %s\n", time_buffer, micros_remainder, |   fprintf(stderr, "%s.%06d: %c %s:%d] %s\n", time_buffer, micros_remainder, | ||||||
| 	  "IWEF"[severity_], fname_, line_, str().c_str()); |           "IWEF"[severity_], fname_, line_, str().c_str()); | ||||||
| } | } | ||||||
| #endif | #endif | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| // Parse log level (int64) from environment variable (char*)
 | // Parse log level (int64) from environment variable (char*)
 | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user