diff --git a/tensorflow/BUILD b/tensorflow/BUILD index ed715b87578..79aec061c21 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -238,7 +238,6 @@ filegroup( "//tensorflow/compiler/jit/kernels:all_files", "//tensorflow/compiler/jit/legacy_flags:all_files", "//tensorflow/compiler/jit/ops:all_files", - "//tensorflow/compiler/plugin/executor:all_files", "//tensorflow/compiler/tests:all_files", "//tensorflow/compiler/tf2xla:all_files", "//tensorflow/compiler/tf2xla/cc:all_files", @@ -252,6 +251,7 @@ filegroup( "//tensorflow/compiler/xla/service/cpu:all_files", "//tensorflow/compiler/xla/service/gpu:all_files", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files", + "//tensorflow/compiler/xla/service/interpreter:all_files", "//tensorflow/compiler/xla/service/llvm_ir:all_files", "//tensorflow/compiler/xla/tests:all_files", "//tensorflow/compiler/xla/tools:all_files", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 02e7ca64e5f..b38f59912eb 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -17,7 +17,6 @@ package_group( package( default_visibility = [ ":internal", - "//tensorflow/compiler/plugin/executor:__pkg__", ], ) @@ -33,7 +32,6 @@ cc_library( deps = [ ":xla_cpu_device", ":xla_cpu_jit", - "//tensorflow/compiler/plugin", ] + if_cuda_is_configured([ ":xla_gpu_device", ":xla_gpu_jit", @@ -99,6 +97,17 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_interpreter_device", + srcs = ["xla_interpreter_device.cc"], + deps = [ + ":xla_device", + "//tensorflow/compiler/jit/kernels:xla_device_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + ], + alwayslink = True, +) + # Internal targets below this point. cc_library( diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 354c0fabfc7..5faafe17b29 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -2,7 +2,6 @@ licenses(["notice"]) # Apache 2.0 package( default_visibility = [ - "//tensorflow/compiler/plugin/executor:__pkg__", "//tensorflow/compiler/tf2xla:internal", ], ) diff --git a/tensorflow/compiler/plugin/executor/device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc similarity index 55% rename from tensorflow/compiler/plugin/executor/device.cc rename to tensorflow/compiler/jit/xla_interpreter_device.cc index d902f9df6a5..6fc8533c0d4 100644 --- a/tensorflow/compiler/plugin/executor/device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter. + #include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -20,46 +22,47 @@ limitations under the License. 
namespace tensorflow {

-const char* const DEVICE_XLA_EXEC = "XLA_EXEC";
-const char* const DEVICE_EXEC_XLA_JIT = "XLA_EXEC_JIT";
+const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER";
+const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT";

 constexpr std::array<DataType, 5> kExecAllTypes = {
     {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}};

-class XlaExaDeviceFactory : public DeviceFactory {
+class XlaInterpreterDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options,
                        const string& name_prefix,
                        std::vector<Device*>* devices) override;
 };

-Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options,
-                                          const string& name_prefix,
-                                          std::vector<Device*>* devices) {
-  static XlaDeviceOpRegistrations* registrations =
-      RegisterXlaDeviceKernels(DEVICE_XLA_EXEC, DEVICE_EXEC_XLA_JIT);
+Status XlaInterpreterDeviceFactory::CreateDevices(
+    const SessionOptions& options, const string& name_prefix,
+    std::vector<Device*>* devices) {
+  static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
+      DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
   (void)registrations;
   std::unique_ptr<XlaDevice> device;
-  TF_RETURN_IF_ERROR(XlaDevice::Create("Executor", DEVICE_XLA_EXEC, 0,
-                                       DEVICE_EXEC_XLA_JIT, options,
+  TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0,
+                                       DEVICE_INTERPRETER_XLA_JIT, options,
                                        name_prefix, &device));
   devices->push_back(device.release());
   return Status::OK();
 }

-// Set priority to be below the default priority (50), so that Executor is not
-// selected as a high priority device over other default devices.
-// See constructor comments for Registrar in
+// Set priority to be below the default priority (50), so that Interpreter is
+// not selected as a high priority device over other default devices. See
+// constructor comments for Registrar in
 // tensorflow/core/common_runtime/device_factory.h for a list of priority for
 // devices.
-REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 40);
+REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_INTERPRETER,
+                              XlaInterpreterDeviceFactory, 40);

 // Kernel registrations
-
 static bool OpFilter(KernelDef* kdef) { return true; }

-REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_EXEC, XlaDeviceLaunchOp, kExecAllTypes);
-REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_EXEC, kExecAllTypes);
-REGISTER_XLA_BACKEND(DEVICE_EXEC_XLA_JIT, kExecAllTypes, OpFilter);
+REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaDeviceLaunchOp,
+                           kExecAllTypes);
+REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
+REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);

 }  // namespace tensorflow
diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD
deleted file mode 100644
index 8c2e9a7c818..00000000000
--- a/tensorflow/compiler/plugin/BUILD
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================== - -"""Configuration file for an XLA plugin. -- please don't check in changes to this file -- to prevent changes appearing in git status, use: - git update-index --assume-unchanged tensorflow/compiler/plugin/BUILD - -To add additional devices to the XLA subsystem, add targets to the -dependency list in the 'plugin' target. For instance: - - deps = ["//tensorflow/compiler/plugin/example:plugin_lib"], -""" - -licenses(["notice"]) - -package( - default_visibility = ["//visibility:public"], -) - -cc_library( - name = "plugin", - deps = [ - "//tensorflow/compiler/plugin/executor:plugin_lib", - ], -) diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD deleted file mode 100644 index ffecd68d921..00000000000 --- a/tensorflow/compiler/plugin/executor/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -licenses(["restricted"]) - -package(default_visibility = ["//visibility:public"]) - -cc_library( - name = "plugin_lib", - srcs = glob([ - "*.cc", - ]), - hdrs = glob([ - "*.h", - ]), - deps = [ - "//tensorflow/compiler/jit:xla_device", - "//tensorflow/compiler/jit:xla_jit_headers_lib", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:xla_headers_lib", - "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:layout_assignment", - "//third_party/eigen3", - "@local_config_cuda//cuda:cuda_headers", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc deleted file mode 100644 index 32d6a0c04df..00000000000 --- a/tensorflow/compiler/plugin/executor/transfer_manager.cc +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#include "tensorflow/compiler/plugin/executor/transfer_manager.h"
-#include "tensorflow/compiler/plugin/executor/platform_id.h"
-
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-
-#include
-#include
-#include
-
-namespace sep = ::perftools::gputools::executorplugin;
-
-namespace xla {
-namespace executorplugin {
-
-ExecutorTransferManager::ExecutorTransferManager() {}
-
-se::Platform::Id ExecutorTransferManager::PlatformId() const {
-  return se::executorplugin::kExecutorPlatformId;
-}
-
-Status ExecutorTransferManager::TransferLiteralFromDevice(
-    se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
-    const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
-  TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape));
-
-  // Tuples are a special case and contain one or more shapes inside of them to
-  // an arbitrary nesting depth.
-  if (device_shape.element_type() == TUPLE) {
-    *literal->mutable_shape() = literal_shape;
-    TF_ASSIGN_OR_RETURN(
-        std::vector<se::DeviceMemoryBase> element_buffers,
-        ShallowCopyTupleFromDevice(executor, source, device_shape));
-    TF_RET_CHECK(element_buffers.size() ==
-                 ShapeUtil::TupleElementCount(device_shape));
-    for (int64 i = 0; i < element_buffers.size(); ++i) {
-      const Shape& element_device_shape = device_shape.tuple_shapes(i);
-      const Shape& element_literal_shape = literal_shape.tuple_shapes(i);
-      Literal* element_literal = literal->add_tuple_literals();
-      // Recursively call TransferFromDevice to copy over the data in the
-      // element array.
-      TF_RETURN_IF_ERROR(TransferLiteralFromDevice(
-          executor, element_buffers[i], element_device_shape,
-          element_literal_shape, element_literal));
-    }
-    return Status::OK();
-  }
-
-  *literal->mutable_shape() = device_shape;
-  literal->Reserve(ShapeUtil::ElementsIn(device_shape));
-  TF_RETURN_IF_ERROR(TransferBufferFromDevice(
-      executor, source, ShapeUtil::ByteSizeOf(device_shape),
-      literal->MutableInternalData()));
-  if (!ShapeUtil::Equal(literal_shape, device_shape)) {
-    *literal = std::move(*literal->Relayout(literal_shape.layout()));
-  }
-  TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape()));
-  return Status::OK();
-}
-
-StatusOr<std::vector<se::DeviceMemoryBase>>
-ExecutorTransferManager::ShallowCopyTupleFromDevice(
-    se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
-    const Shape& shape) {
-  TF_RET_CHECK(ShapeUtil::IsTuple(shape));
-
-  std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
-                                      nullptr);
-  int64 tuple_size = ShapeUtil::ByteSizeOf(shape, sizeof(void*));
-  auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
-                                                    element_pointers.data());
-  if (!copy_status.ok()) {
-    return AddStatus(
-        Status(static_cast<tensorflow::error::Code>(copy_status.code()),
-               copy_status.error_message()),
-        "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
-  }
-
-  // Create a DeviceMemoryBase from each void* pointer.
-  std::vector<se::DeviceMemoryBase> destination;
-  for (int i = 0; i < element_pointers.size(); ++i) {
-    if (element_pointers[i] == nullptr &&
-        !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
-      return FailedPrecondition("tuple contains nullptr at element %d", i);
-    }
-    int64 buffer_size =
-        ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), sizeof(void*));
-    destination.emplace_back(element_pointers[i], buffer_size);
-  }
-  return std::move(destination);
-}
-
-Status ExecutorTransferManager::TransferLiteralToDevice(
-    se::StreamExecutor* executor, const Literal& literal,
-    se::DeviceMemoryBase* destination) {
-  const Shape& shape = literal.shape();
-
-  if (ShapeUtil::IsTuple(literal.shape())) {
-    std::vector<void*> tuple_elements_on_device;
-    for (const Literal& tuple_element : literal.tuple_literals()) {
-      se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>(
-          GetByteSizeRequirement(tuple_element.shape()));
-      TF_RETURN_IF_ERROR(
-          TransferLiteralToDevice(executor, tuple_element, &allocation));
-      tuple_elements_on_device.push_back(allocation.opaque());
-    }
-    return TransferBufferToDevice(
-        executor, tuple_elements_on_device.size() * sizeof(void*),
-        tuple_elements_on_device.data(), destination);
-  }
-
-  return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
-                                literal.InternalData(),
-                                destination);
-}
-
-Status ExecutorTransferManager::TransferLiteralToInfeed(
-    se::StreamExecutor* executor, const Literal& literal) {
-  const Shape& shape = literal.shape();
-  VLOG(1) << "transferring literal shape to infeed: "
-          << ShapeUtil::HumanString(shape);
-
-  return Status::OK();
-}
-
-Status ExecutorTransferManager::TransferBufferToInfeed(
-    se::StreamExecutor* executor, int64 size, const void* source) {
-  return Unimplemented("Transfer to Infeed");
-}
-
-Status ExecutorTransferManager::TransferLiteralFromOutfeed(
-    perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
-    Literal* literal) {
-  const Shape& shape = literal->shape();
-  VLOG(1) << "transferring literal shape from outfeed: "
-          << ShapeUtil::HumanString(shape);
-
-  return Status::OK();
-}
-
-Status ExecutorTransferManager::ResetDevices(
-    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
-        executors) {
-  return Unimplemented("Device reset not supported");
-}
-
-int64 ExecutorTransferManager::GetByteSizeRequirement(const Shape& shape) {
-  return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
-}
-
-}  // namespace executorplugin
-}  // namespace xla
-
-static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() {
-  return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>();
-}
-
-static bool InitModule() {
-  xla::TransferManager::RegisterTransferManager(sep::kExecutorPlatformId,
-                                                &CreateExecutorTransferManager);
-  return true;
-}
-static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.h b/tensorflow/compiler/plugin/executor/transfer_manager.h
deleted file mode 100644
index 7a42e5a2d75..00000000000
--- a/tensorflow/compiler/plugin/executor/transfer_manager.h
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
-#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
-
-#include "tensorflow/compiler/xla/service/transfer_manager.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/platform/types.h"
-
-#include
-
-namespace se = ::perftools::gputools;
-
-namespace xla {
-namespace executorplugin {
-
-class ExecutorTransferManager : public TransferManager {
- public:
-  ExecutorTransferManager();
-
-  ~ExecutorTransferManager() override {}
-
-  se::Platform::Id PlatformId() const override;
-
-  StatusOr<std::vector<se::DeviceMemoryBase>> ShallowCopyTupleFromDevice(
-      se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
-      const Shape& shape) override;
-
-  Status TransferLiteralFromDevice(se::StreamExecutor* executor,
-                                   const se::DeviceMemoryBase& source,
-                                   const Shape& device_shape,
-                                   const Shape& literal_shape,
-                                   Literal* literal) override;
-
-  Status TransferLiteralToDevice(se::StreamExecutor* executor,
-                                 const Literal& literal,
-                                 se::DeviceMemoryBase* destination) override;
-
-  Status TransferLiteralToInfeed(se::StreamExecutor* executor,
-                                 const Literal& literal) override;
-
-  Status TransferBufferToInfeed(se::StreamExecutor* executor,
-                                int64 size, const void* source) override;
-
-  Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
-                                    const Shape& literal_shape,
-                                    Literal* literal) override;
-
-  Status ResetDevices(
-      tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
-
-  int64 GetByteSizeRequirement(const Shape& shape) override;
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorTransferManager);
-};
-
-}  // namespace executorplugin
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f7190c963b6..5cbf5e2fb75 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -529,6 +529,17 @@ cc_library(
     ],
 )

+cc_library(
+    name = "interpreter_plugin",
+    deps = [
+        ":interpreter_transfer_manager",
+        ":service",
+        "//tensorflow/compiler/xla/service/interpreter:compiler",
+        "//tensorflow/compiler/xla/service/interpreter:platform",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
 cc_library(
     name = "shaped_buffer",
     srcs = ["shaped_buffer.cc"],
@@ -1152,6 +1163,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service/interpreter:platform_id",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
     ],
@@ -1200,6 +1212,27 @@ cc_library(
     alwayslink = True,  # Contains per-platform transfer manager registration
 )

+cc_library(
+    name = "interpreter_transfer_manager",
+    srcs = ["interpreter_transfer_manager.cc"],
+    hdrs = ["interpreter_transfer_manager.h"],
+    deps = [
+        ":generic_transfer_manager",
+        ":transfer_manager",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        
"//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/interpreter:platform_id", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = True, # Contains per-platform transfer manager registration +) + cc_test( name = "transfer_manager_test", srcs = ["transfer_manager_test.cc"], diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 5cae034d08b..432df46eadd 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -36,19 +37,16 @@ namespace xla { GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id) : platform_id_(platform_id) { - // We currently only support kHostPlatformId for CPU and kCudaPlatformId for - // GPU. Before supporting other platforms, we need to test this transfer - // manager on them. + // We currently only support kHostPlatformId for CPU, kCudaPlatformId for + // GPU and kInterpreterPlatformId for Interpreter. Before supporting other + // platforms, we need to test this transfer manager on them. CHECK(platform_id_ == se::host::kHostPlatformId || + platform_id_ == se::interpreter::kInterpreterPlatformId || platform_id_ == se::cuda::kCudaPlatformId); } se::Platform::Id GenericTransferManager::PlatformId() const { - if (platform_id_ == se::cuda::kCudaPlatformId || - platform_id_ == se::host::kHostPlatformId) { - return platform_id_; - } - CHECK(false) << "GenericTransferManager::platform_id_ is invalid"; + return platform_id_; } Status GenericTransferManager::TransferLiteralFromDevice( diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 48c061d28e5..993312fef9d 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -75,7 +75,7 @@ class GenericTransferManager : public TransferManager { private: // The platform this transfer manager targets. 
- perftools::gputools::Platform::Id platform_id_; + const perftools::gputools::Platform::Id platform_id_; TF_DISALLOW_COPY_AND_ASSIGN(GenericTransferManager); }; diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD new file mode 100644 index 00000000000..8e59f496f50 --- /dev/null +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -0,0 +1,113 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":executable", + ":platform_id", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_constant_folding", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", + "//tensorflow/compiler/xla/service:hlo_cse", + "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", + "//tensorflow/compiler/xla/service:inliner", + "//tensorflow/compiler/xla/service:layout_assignment", + "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor", + ], + alwayslink = True, # Contains compiler registration +) + +cc_library( + name = "platform_id", + srcs = ["platform_id.cc"], + hdrs = ["platform_id.h"], + deps = [ + "//tensorflow/core:stream_executor_headers_lib", + "@nsync//:nsync_headers", + "@protobuf_archive//:protobuf_headers", + "@protobuf_archive//:protoc_lib", + ], +) + +cc_library( + name = "executable", + srcs = ["executable.cc"], + hdrs = ["executable.h"], + deps = [ + ":executor", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", + "//tensorflow/compiler/xla/service:hlo_evaluator", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "platform", + srcs = ["platform.cc"], + hdrs = ["platform.h"], + deps = [ + ":executor", + ":platform_id", + "//tensorflow/core:stream_executor_headers_lib", + ], + alwayslink = True, # Registers itself with the MultiPlatformManager. 
+)
+
+cc_library(
+    name = "executor",
+    srcs = ["executor.cc"],
+    hdrs = ["executor.h"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_headers_lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+)
diff --git a/tensorflow/compiler/xla/service/interpreter/README.md b/tensorflow/compiler/xla/service/interpreter/README.md
new file mode 100644
index 00000000000..4c19a1b916d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/interpreter/README.md
@@ -0,0 +1,19 @@
+# XLA Interpreter Backend
+
+The XLA Interpreter backend operates at the HLO level by ingesting an HloModule
+and evaluating the HLO graph directly with the HloEvaluator, without lowering it
+further (for example, to LLVM IR) before execution, as other backends (such as
+CPU and GPU) do.
+
+Its key components are:
+
+* [`InterpreterCompiler`]: despite the inherited "compiler" naming, all
+  `InterpreterCompiler` really does is the following:
+  1. Runs certain HLO optimization passes on the given HLO graph.
+  2. Generates an `InterpreterExecutable` from the optimized HLO graph.
+  3. Registers itself in the global compiler factory registry.
+* [`InterpreterExecutable`]: responsible for running the input HLO graph
+  through the `HloEvaluator`, allocating the output buffer, and finally copying
+  the evaluated `Literal` result over.
+* [`HloEvaluator`]: traverses an HLO graph and evaluates each node in DFS
+  ordering along the way.
diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
similarity index 60%
rename from tensorflow/compiler/plugin/executor/compiler.cc
rename to tensorflow/compiler/xla/service/interpreter/compiler.cc
index 77193f06c4b..c8d02834f43 100644
--- a/tensorflow/compiler/plugin/executor/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include
-#include
+#include "tensorflow/compiler/xla/service/interpreter/compiler.h"

-#include "tensorflow/compiler/plugin/executor/compiler.h"
-#include "tensorflow/compiler/plugin/executor/executable.h"
+#include
+#include
+
+#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@@ -28,26 +29,27 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/stream_executor/lib/initialize.h" -#include "tensorflow/stream_executor/lib/strcat.h" +#include "tensorflow/core/platform/types.h" namespace xla { -namespace executorplugin { +namespace interpreter { namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::executorplugin; +namespace sep = ::perftools::gputools::interpreter; /* - * Run optimization passes on the module. The graph is transformed by - * each pass in the optimization pipeline. The service subdirectory + * Run optimization passes on the module. The graph is transformed by + * each pass in the optimization pipeline. The service subdirectory * contains useful optimization passes. */ -Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) { - HloPassPipeline pipeline("Executor"); +Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { + HloPassPipeline pipeline("Interpreter"); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(false); @@ -65,9 +67,8 @@ Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) { return pipeline.Run(hlo_module).status(); } -StatusOr> ExecutorCompiler::Compile( - std::unique_ptr hlo_module, - se::StreamExecutor* stream_exec) { +StatusOr> InterpreterCompiler::Compile( + std::unique_ptr hlo_module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); VLOG(1) << "Generate graph " << hlo_module->name(); @@ -75,53 +76,54 @@ StatusOr> ExecutorCompiler::Compile( TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); // Typically you would visit the HLO graph, building up a compiled equivalent - // In this case we are using an Hlo evaluator at execution time, so we don't + // In this case we are using an HloEvaluator at execution time, so we don't // need to compile anything - // Create executable from only the Hlo module - std::unique_ptr executable; - executable.reset(new ExecutorExecutable(std::move(hlo_module))); + // Create executable from only the Hlo module. 
+  std::unique_ptr<Executable> executable =
+      xla::MakeUnique<InterpreterExecutable>(std::move(hlo_module));

   return std::move(executable);
 }

-StatusOr<std::vector<std::unique_ptr<Executable>>> ExecutorCompiler::Compile(
-    std::vector<std::unique_ptr<HloModule>> hlo_modules,
-    std::vector<se::StreamExecutor*> stream_execs) {
-
+StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
+    std::vector<std::unique_ptr<HloModule>> /*hlo_modules*/,
+    std::vector<se::StreamExecutor*> /*stream_execs*/) {
   return tensorflow::errors::Unimplemented(
-      "Compilation of multiple HLO modules is not supported on Executor.");
+      "Compilation of multiple HLO modules is not supported on Interpreter.");
 }

 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-ExecutorCompiler::CompileAheadOfTime(
+InterpreterCompiler::CompileAheadOfTime(
     std::vector<std::unique_ptr<HloModule>> hlo_modules,
     const AotCompilationOptions& aot_options) {
-
   return tensorflow::errors::InvalidArgument(
-      "AOT compilation not supported on Executor");
+      "AOT compilation not supported on Interpreter");
 }

-se::Platform::Id ExecutorCompiler::PlatformId() const {
-  return sep::kExecutorPlatformId;
+se::Platform::Id InterpreterCompiler::PlatformId() const {
+  return sep::kInterpreterPlatformId;
 }

-HloCostAnalysis::ShapeSizeFunction
-ExecutorCompiler::ShapeSizeBytesFunction() const {
-  return ExecutorExecutable::ShapeSizeBytes;
+HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
+    const {
+  return InterpreterExecutable::ShapeSizeBytes;
 }

 static std::unique_ptr<ComputationPlacer> CreateComputationPlacer() {
   return xla::MakeUnique<ComputationPlacer>();
 }

-REGISTER_MODULE_INITIALIZER(executor_compiler, {
-  xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() {
-    return xla::MakeUnique<xla::executorplugin::ExecutorCompiler>();
+static bool InitModule() {
+  xla::Compiler::RegisterCompilerFactory(sep::kInterpreterPlatformId, []() {
+    return xla::MakeUnique<xla::interpreter::InterpreterCompiler>();
   });
-  xla::ComputationPlacer::RegisterComputationPlacer(sep::kExecutorPlatformId,
+  xla::ComputationPlacer::RegisterComputationPlacer(sep::kInterpreterPlatformId,
                                                     &CreateComputationPlacer);
-});
+  return true;
+}

-}  // namespace executorplugin
+static bool module_initialized = InitModule();
+
+}  // namespace interpreter
 }  // namespace xla
diff --git a/tensorflow/compiler/plugin/executor/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h
similarity index 53%
rename from tensorflow/compiler/plugin/executor/compiler.h
rename to tensorflow/compiler/xla/service/interpreter/compiler.h
index d318eefc49f..13db38ab60a 100644
--- a/tensorflow/compiler/plugin/executor/compiler.h
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.h
@@ -13,38 +13,47 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_
-#define TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_COMPILER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_COMPILER_H_

 #include
+#include

 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
-
-#include "tensorflow/compiler/plugin/executor/platform_id.h"
+#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/stream_executor/stream_executor.h"

 namespace xla {
-namespace executorplugin {
+namespace interpreter {

-class ExecutorCompiler : public Compiler {
+// Despite the inherited "compiler" naming, InterpreterCompiler does not
+// perform any lowering as other backends do. It operates at HLO-level
+// and is responsible for generating an InterpreterExecutable.
+// Refer to interpreter/README.md for more.
+class InterpreterCompiler : public Compiler {
  public:
-  ExecutorCompiler() {}
-  ~ExecutorCompiler() override {}
+  InterpreterCompiler() {}
+  ~InterpreterCompiler() override {}

   StatusOr<std::unique_ptr<Executable>> Compile(
-      std::unique_ptr<HloModule> hlo_module,
+      std::unique_ptr<HloModule> hlo_modules,
       perftools::gputools::StreamExecutor* stream_exec) override;

   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
-      std::vector<std::unique_ptr<HloModule>> hlo_module,
+      std::vector<std::unique_ptr<HloModule>> hlo_modules,
       std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;

   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(
-      std::vector<std::unique_ptr<HloModule>> module,
-      const AotCompilationOptions& options) override;
+  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> hlo_modules,
+                     const AotCompilationOptions& aot_options) override;

   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
@@ -53,10 +62,10 @@ class InterpreterCompiler : public Compiler {
  private:
   Status RunHloOptimization(HloModule* hlo_module);

-  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorCompiler);
+  TF_DISALLOW_COPY_AND_ASSIGN(InterpreterCompiler);
 };

-}  // namespace executorplugin
+}  // namespace interpreter
 }  // namespace xla

-#endif  // TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_COMPILER_H_
diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
similarity index 70%
rename from tensorflow/compiler/plugin/executor/executable.cc
rename to tensorflow/compiler/xla/service/interpreter/executable.cc
index e866dfef059..989fc4e0313 100644
--- a/tensorflow/compiler/plugin/executor/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -13,25 +13,41 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/compiler/plugin/executor/executable.h"
-#include "tensorflow/compiler/plugin/executor/executor.h"
+#include "tensorflow/compiler/xla/service/interpreter/executable.h"
+
+#include
+#include
+#include
+#include
+#include
+
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/interpreter/executor.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"

 namespace xla {
-namespace executorplugin {
+namespace interpreter {

 namespace se = ::perftools::gputools;
-namespace sep = ::perftools::gputools::executorplugin;
+namespace sep = ::perftools::gputools::interpreter;

-ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module)
+InterpreterExecutable::InterpreterExecutable(
+    std::unique_ptr<HloModule> hlo_module)
     : Executable(std::move(hlo_module)) {}

-ExecutorExecutable::~ExecutorExecutable() {}
+InterpreterExecutable::~InterpreterExecutable() {}

 static se::DeviceMemoryBase AllocateSingleOutput(
-    sep::ExecutorExecutor* executor, const Literal& literal) {
+    sep::InterpreterExecutor* executor, const Literal& literal) {
   int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape()));
   void* buf = executor->Allocate(size);
   const void* src = literal.InternalData();
@@ -40,7 +56,7 @@ static se::DeviceMemoryBase AllocateSingleOutput(
 }

 static se::DeviceMemoryBase AllocateOutputBuffer(
-    sep::ExecutorExecutor* executor, const Literal& literal) {
+    sep::InterpreterExecutor* executor, const Literal& literal) {
   const Shape& shape = literal.shape();
   if (shape.element_type() != xla::TUPLE) {
     return AllocateSingleOutput(executor, literal);
@@ -58,7 +74,7 @@ static se::DeviceMemoryBase AllocateOutputBuffer(
   }
 }

-StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
+StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteOnStream(
     const ServiceExecutableRunOptions* run_options,
     tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
     HloExecutionProfile* hlo_execution_profile) {
@@ -82,7 +98,7 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
   // Create the arguments as a vector of XLA literals
   std::vector<std::unique_ptr<Literal>> arg_literals;
   std::vector<Literal*> arg_literals_ptrs;
-  for (int64 p = 0; p < computation->num_parameters(); p++) {
+  for (int64 p = 0; p < computation->num_parameters(); ++p) {
     // Create the input literal for the parameter
     HloInstruction* param = computation->parameter_instruction(p);
     arg_literals.emplace_back(Literal::CreateFromShape(param->shape()));
@@ -94,18 +110,18 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
                             ShapeUtil::ByteSizeOf(param->shape()));
   }

-  // Execute the graph using the evaluator
+  // Execute the graph using the HloEvaluator.
   HloEvaluator evaluator;
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> output,
                       evaluator.Evaluate(*computation, arg_literals_ptrs));

   // Copy the result into the return buffer
   perftools::gputools::StreamExecutor* executor(stream->parent());
-  sep::ExecutorExecutor* executorExecutor(
-      static_cast<sep::ExecutorExecutor*>(executor->implementation()));
+  sep::InterpreterExecutor* interpreter_executor(
+      static_cast<sep::InterpreterExecutor*>(executor->implementation()));
   se::DeviceMemoryBase ret =
-      AllocateOutputBuffer(executorExecutor, *(output.get()));
+      AllocateOutputBuffer(interpreter_executor, *(output.get()));

   uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -118,32 +134,32 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
   return ret;
 }

-StatusOr<std::unique_ptr<ShapedBuffer>> ExecutorExecutable::ExecuteOnStream(
+StatusOr<std::unique_ptr<ShapedBuffer>> InterpreterExecutable::ExecuteOnStream(
     const ServiceExecutableRunOptions* run_options,
     tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
     HloExecutionProfile* hlo_execution_profile) {
   return tensorflow::errors::Unimplemented(
-      "ExecuteOnStream is not yet supported on Executor.");
+      "ExecuteOnStream is not yet supported on Interpreter.");
 }

-StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream(
+StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
     tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
   return tensorflow::errors::Unimplemented(
-      "ExecuteAsyncOnStream is not yet supported on Executor.");
+      "ExecuteAsyncOnStream is not yet supported on Interpreter.");
 }

-/*static*/ int64 ExecutorExecutable::ShapeSizeBytes(const Shape& shape) {
+/*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
   if (ShapeUtil::IsOpaque(shape)) {
     return sizeof(void*);
   }
   return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
 }

-std::unique_ptr<HloCostAnalysis> ExecutorExecutable::CreateCostAnalysis()
+std::unique_ptr<HloCostAnalysis> InterpreterExecutable::CreateCostAnalysis()
     const {
   return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
 }

-}  // namespace executorplugin
+}  // namespace interpreter
 }  // namespace xla
diff --git a/tensorflow/compiler/plugin/executor/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
similarity index 60%
rename from tensorflow/compiler/plugin/executor/executable.h
rename to tensorflow/compiler/xla/service/interpreter/executable.h
index 8572aba43d9..2881d6697e2 100644
--- a/tensorflow/compiler/plugin/executor/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -13,29 +13,35 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_
-#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_H_

-#include
 #include
-#include
-#include
-#include

 #include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
-
-#include "tensorflow/stream_executor/lib/status.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"

 namespace xla {
-namespace executorplugin {
+namespace interpreter {

-class ExecutorExecutable : public Executable {
+// Responsible for running an HLO graph through the HloEvaluator and for
+// allocating the output buffer. Refer to interpreter/README.md for more.
+class InterpreterExecutable : public Executable {
  public:
-  ExecutorExecutable(std::unique_ptr<HloModule> hlo_module);
-  ~ExecutorExecutable() override;
+  InterpreterExecutable(std::unique_ptr<HloModule> hlo_module);
+  ~InterpreterExecutable() override;

   StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
       const ServiceExecutableRunOptions* run_options,
@@ -58,10 +64,10 @@ class ExecutorExecutable : public Executable {
   std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;

  private:
-  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorExecutable);
+  TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
 };

-}  // namespace executorplugin
+}  // namespace interpreter
 }  // namespace xla

-#endif  // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_H_
diff --git a/tensorflow/compiler/plugin/executor/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc
similarity index 54%
rename from tensorflow/compiler/plugin/executor/executor.cc
rename to tensorflow/compiler/xla/service/interpreter/executor.cc
index 908b996bc95..0bb3259ef43 100644
--- a/tensorflow/compiler/plugin/executor/executor.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executor.cc
@@ -13,117 +13,110 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/compiler/plugin/executor/executor.h"
+#include "tensorflow/compiler/xla/service/interpreter/executor.h"

-#include
-#include
+#include

-#include "tensorflow/compiler/plugin/executor/platform_id.h"
 #include "tensorflow/compiler/xla/status_macros.h"

 namespace perftools {
 namespace gputools {
-namespace executorplugin {
+namespace interpreter {

 host::HostStream *AsExecutorStream(Stream *stream) {
   DCHECK(stream != nullptr);
   return dynamic_cast<host::HostStream *>(stream->implementation());
 }

-ExecutorExecutor::ExecutorExecutor(const PluginConfig &plugin_config)
+InterpreterExecutor::InterpreterExecutor(const PluginConfig &plugin_config)
     : plugin_config_(plugin_config) {}

-ExecutorExecutor::~ExecutorExecutor() {}
+InterpreterExecutor::~InterpreterExecutor() {}

-void *ExecutorExecutor::Allocate(uint64 size) { return new char[size]; }
+void *InterpreterExecutor::Allocate(uint64 size) { return new char[size]; }

-void *ExecutorExecutor::AllocateSubBuffer(DeviceMemoryBase *parent,
-                                          uint64 offset_bytes,
-                                          uint64 size_bytes) {
+void *InterpreterExecutor::AllocateSubBuffer(DeviceMemoryBase *parent,
+                                             uint64 offset_bytes,
+                                             uint64 /*size_bytes*/) {
   return parent + offset_bytes;
 }

-void ExecutorExecutor::Deallocate(DeviceMemoryBase *mem) {
+void InterpreterExecutor::Deallocate(DeviceMemoryBase *mem) {
   if (!mem->is_sub_buffer()) {
     delete[] static_cast<char *>(mem->opaque());
   }
 }

-bool ExecutorExecutor::Memcpy(Stream *stream, void *host_dst,
-                              const DeviceMemoryBase &dev_src, uint64 size) {
+bool InterpreterExecutor::Memcpy(Stream *stream, void *host_dst,
+                                 const DeviceMemoryBase &dev_src, uint64 size) {
   AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
     port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
   });
   return true;
 }

-bool ExecutorExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
-                              const void *host_src, uint64 size) {
+bool InterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
+                                 const void *host_src, uint64 size) {
   AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
     port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
   });
   return true;
 }

-port::Status ExecutorExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst,
-                                                 const void *host_src,
-                                                 uint64 size) {
+port::Status InterpreterExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst,
+                                                    const void *host_src,
+                                                    uint64 size) {
   memcpy(dev_dst->opaque(), host_src, size);
   return port::Status::OK();
 }

-port::Status ExecutorExecutor::SynchronousMemcpy(void *host_dst,
-                                                 const DeviceMemoryBase &dev_src,
-                                                 uint64 size) {
+port::Status InterpreterExecutor::SynchronousMemcpy(
+    void *host_dst, const DeviceMemoryBase &dev_src, uint64 size) {
   memcpy(host_dst, dev_src.opaque(), size);
   return port::Status::OK();
 }

-bool ExecutorExecutor::HostCallback(Stream *stream,
-                                    std::function<void()> callback) {
+bool InterpreterExecutor::HostCallback(Stream *stream,
+                                       std::function<void()> callback) {
   AsExecutorStream(stream)->EnqueueTask(callback);
   return true;
 }

-bool ExecutorExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
+bool InterpreterExecutor::CreateStreamDependency(Stream *dependent,
+                                                 Stream *other) {
   AsExecutorStream(dependent)->EnqueueTask(
       [other]() { other->BlockHostUntilDone(); });
   AsExecutorStream(dependent)->BlockUntilDone();
   return true;
 }

-bool ExecutorExecutor::StartTimer(Stream *stream, Timer *timer) {
+bool InterpreterExecutor::StartTimer(Stream *stream, Timer *timer) {
   dynamic_cast<host::HostTimer *>(timer->implementation())->Start(stream);
   return true;
 }

-bool ExecutorExecutor::StopTimer(Stream *stream, Timer *timer) {
+bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) {
   dynamic_cast<host::HostTimer *>(timer->implementation())->Stop(stream);
   return true;
 }

-bool ExecutorExecutor::BlockHostUntilDone(Stream *stream) {
+bool InterpreterExecutor::BlockHostUntilDone(Stream *stream) {
   AsExecutorStream(stream)->BlockUntilDone();
   return true;
 }

-DeviceDescription *ExecutorExecutor::PopulateDeviceDescription() const {
+DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const {
   internal::DeviceDescriptionBuilder builder;

   builder.set_device_address_bits(64);

-  builder.set_name("Executor");
-  builder.set_device_vendor("VectorName");
-  builder.set_platform_version("1.0");
-  builder.set_driver_version("1.0");
-  builder.set_runtime_version("1.0");
-  builder.set_pci_bus_id("1");
+  builder.set_name("Interpreter");
   builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);
   builder.set_clock_rate_ghz(static_cast<float>(CLOCKS_PER_SEC) / 1e9);

   return builder.Build().release();
 }

-}  // namespace executorplugin
+}  // namespace interpreter
 }  // namespace gputools
 }  // namespace perftools
diff --git a/tensorflow/compiler/plugin/executor/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
similarity index 84%
rename from tensorflow/compiler/plugin/executor/executor.h
rename to tensorflow/compiler/xla/service/interpreter/executor.h
index 32fdb157e48..c59b2ccb150 100644
--- a/tensorflow/compiler/plugin/executor/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -13,38 +13,47 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-// Declares the ExecutorExecutor class, which is a CPU-only implementation of
+// Declares the InterpreterExecutor class, which is a CPU-only implementation of
 // the StreamExecutor interface. For now, this is used for testing and to
 // examine the performance of host-based StreamExecutor code.
-#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
-#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_

-#include "tensorflow/stream_executor/host/host_stream.h"
-#include "tensorflow/stream_executor/host/host_timer.h"
+#include
+#include

 #include "tensorflow/compiler/xla/shape_util.h"
-
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/stream_executor/blas.h"
-#include "tensorflow/stream_executor/lib/error.h"
-#include "tensorflow/stream_executor/lib/status.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/device_description.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/device_options.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/host/host_stream.h"
+#include "tensorflow/stream_executor/host/host_timer.h"
+#include "tensorflow/stream_executor/kernel.h"
+#include "tensorflow/stream_executor/kernel_spec.h"
+#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/plugin.h"
 #include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/shared_memory_config.h"
+#include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/stream_executor/stream_executor.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-
-#include
-#include
+#include "tensorflow/stream_executor/timer.h"

 namespace perftools {
 namespace gputools {
-namespace executorplugin {
+namespace interpreter {

 using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;

-class ExecutorExecutor : public internal::StreamExecutorInterface {
+class InterpreterExecutor : public internal::StreamExecutorInterface {
  public:
-  explicit ExecutorExecutor(const PluginConfig &plugin_config);
-  ~ExecutorExecutor() override;
+  explicit InterpreterExecutor(const PluginConfig &plugin_config);
+  ~InterpreterExecutor() override;

   port::Status Init(int device_ordinal, DeviceOptions device_options) override {
     return port::Status::OK();
@@ -194,9 +203,6 @@ class ExecutorExecutor : public internal::StreamExecutorInterface {
     return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
   }

-  port::StatusOr<DeviceMemoryBase> ExecuteGraph(const xla::Shape &shape,
-                                                Args args);
-
  private:
   DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);

@@ -206,8 +212,8 @@ class ExecutorExecutor : public internal::StreamExecutorInterface {
   const PluginConfig plugin_config_;
 };

-}  // namespace executorplugin
+}  // namespace interpreter
 }  // namespace gputools
 }  // namespace perftools

-#endif  // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
diff --git a/tensorflow/compiler/plugin/executor/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc
similarity index 61%
rename from tensorflow/compiler/plugin/executor/platform.cc
rename to tensorflow/compiler/xla/service/interpreter/platform.cc
index 404e1c3da34..a60e7fc59f7 100644
--- a/tensorflow/compiler/plugin/executor/platform.cc
+++ b/tensorflow/compiler/xla/service/interpreter/platform.cc
@@ -13,37 +13,39 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/compiler/plugin/executor/platform.h"
-#include "tensorflow/compiler/plugin/executor/executor.h"
-#include "tensorflow/compiler/plugin/executor/platform_id.h"
+#include "tensorflow/compiler/xla/service/interpreter/platform.h"

-#include "tensorflow/stream_executor/lib/error.h"
+#include
+
+#include "tensorflow/compiler/xla/service/interpreter/executor.h"
+#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
+#include "tensorflow/stream_executor/device_options.h"
 #include "tensorflow/stream_executor/lib/initialize.h"
 #include "tensorflow/stream_executor/lib/ptr_util.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/status_macros.h"
 #include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/platform.h"

 namespace se = ::perftools::gputools;
-namespace sep = ::perftools::gputools::executorplugin;
+namespace sep = ::perftools::gputools::interpreter;

 namespace perftools {
 namespace gputools {
-namespace executorplugin {
+namespace interpreter {

-PLATFORM_DEFINE_ID(kExecutorPlatformId);
+InterpreterPlatform::InterpreterPlatform() : name_("Interpreter") {}

-ExecutorPlatform::ExecutorPlatform() : name_("Executor") {}
+InterpreterPlatform::~InterpreterPlatform() {}

-ExecutorPlatform::~ExecutorPlatform() {}
+Platform::Id InterpreterPlatform::id() const { return kInterpreterPlatformId; }

-Platform::Id ExecutorPlatform::id() const { return kExecutorPlatformId; }
+int InterpreterPlatform::VisibleDeviceCount() const { return 1; }

-int ExecutorPlatform::VisibleDeviceCount() const { return 1; }
+const string& InterpreterPlatform::Name() const { return name_; }

-const string& ExecutorPlatform::Name() const { return name_; }
-
-port::StatusOr<StreamExecutor*> ExecutorPlatform::ExecutorForDevice(
+port::StatusOr<StreamExecutor*> InterpreterPlatform::ExecutorForDevice(
     int ordinal) {
   StreamExecutorConfig config;
   config.ordinal = ordinal;
@@ -53,7 +55,7 @@ port::StatusOr<StreamExecutor*> ExecutorPlatform::ExecutorForDevice(
 }

 port::StatusOr<StreamExecutor*>
-ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
+InterpreterPlatform::ExecutorForDeviceWithPluginConfig(
     int device_ordinal, const PluginConfig& plugin_config) {
   StreamExecutorConfig config;
   config.ordinal = device_ordinal;
@@ -62,16 +64,16 @@ ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
   return GetExecutor(config);
 }

-port::StatusOr<StreamExecutor*> ExecutorPlatform::GetExecutor(
+port::StatusOr<StreamExecutor*> InterpreterPlatform::GetExecutor(
     const StreamExecutorConfig& config) {
   return executor_cache_.GetOrCreate(
       config, [&]() { return GetUncachedExecutor(config); });
 }

 port::StatusOr<std::unique_ptr<StreamExecutor>>
-ExecutorPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
+InterpreterPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
   auto executor = port::MakeUnique<StreamExecutor>(
-      this, port::MakeUnique<ExecutorExecutor>(config.plugin_config));
+      this, port::MakeUnique<InterpreterExecutor>(config.plugin_config));
   auto init_status = executor->Init(config.ordinal, config.device_options);
   if (!init_status.ok()) {
     return port::Status{
@@ -84,27 +86,30 @@ ExecutorPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
   return std::move(executor);
 }

-void ExecutorPlatform::RegisterTraceListener(
+void InterpreterPlatform::RegisterTraceListener(
     std::unique_ptr<TraceListener> listener) {
   LOG(FATAL) << "not yet implemented: register executor trace listener";
 }

-void ExecutorPlatform::UnregisterTraceListener(TraceListener* listener) {
+void InterpreterPlatform::UnregisterTraceListener(TraceListener* listener) {
   LOG(FATAL) << "not yet implemented: unregister executor trace listener";
 }
 
-static void InitializeExecutorPlatform() {
-  std::unique_ptr<se::Platform> platform(new sep::ExecutorPlatform);
+static void InitializeInterpreterPlatform() {
+  std::unique_ptr<se::Platform> platform(new sep::InterpreterPlatform);
   SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform)));
 }
 
-}  // namespace executorplugin
+}  // namespace interpreter
 }  // namespace gputools
 }  // namespace perftools
 
-REGISTER_MODULE_INITIALIZER(executor_platform, sep::InitializeExecutorPlatform());
+REGISTER_MODULE_INITIALIZER(interpreter_platform,
+                            sep::InitializeInterpreterPlatform());
 
 DECLARE_MODULE_INITIALIZER(multi_platform_manager);
+
 // Note that module initialization sequencing is not supported in the
 // open-source project, so this will be a no-op there.
-REGISTER_MODULE_INITIALIZER_SEQUENCE(executor_platform, multi_platform_manager);
+REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform,
+                                     multi_platform_manager);
diff --git a/tensorflow/compiler/plugin/executor/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h
similarity index 63%
rename from tensorflow/compiler/plugin/executor/platform.h
rename to tensorflow/compiler/xla/service/interpreter/platform.h
index 624bcd5a4eb..c66ddb907d1 100644
--- a/tensorflow/compiler/plugin/executor/platform.h
+++ b/tensorflow/compiler/xla/service/interpreter/platform.h
@@ -12,38 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
-#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_H_
 
 #include <memory>
 #include <string>
-#include <vector>
 
 #include "tensorflow/stream_executor/executor_cache.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
-#include "tensorflow/stream_executor/multi_platform_manager.h"
-#include "tensorflow/stream_executor/platform.h"
-#include "tensorflow/stream_executor/platform/mutex.h"
-#include "tensorflow/stream_executor/platform/port.h"
-#include "tensorflow/stream_executor/platform/thread_annotations.h"
-#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+#include "tensorflow/stream_executor/plugin.h"
+#include "tensorflow/stream_executor/stream_executor.h"
 #include "tensorflow/stream_executor/trace_listener.h"
 
 namespace perftools {
 namespace gputools {
-namespace executorplugin {
+namespace interpreter {
 
-class ExecutorPlatform : public Platform {
+class InterpreterPlatform : public Platform {
  public:
-  ExecutorPlatform();
-  ~ExecutorPlatform() override;
+  InterpreterPlatform();
+  ~InterpreterPlatform() override;
 
   Platform::Id id() const override;
 
-  // Device count is less clear-cut for CPUs than accelerators. This call
-  // currently returns the number of thread units in the host, as reported by
-  // base::NumCPUs().
   int VisibleDeviceCount() const override;
 
   const string& Name() const override;
@@ -70,11 +60,11 @@ class ExecutorPlatform : public Platform {
   // Cache of created StreamExecutors.
   ExecutorCache executor_cache_;
 
-  SE_DISALLOW_COPY_AND_ASSIGN(ExecutorPlatform);
+  SE_DISALLOW_COPY_AND_ASSIGN(InterpreterPlatform);
 };
 
-}  // namespace executorplugin
+}  // namespace interpreter
 }  // namespace gputools
 }  // namespace perftools
 
-#endif  // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_H_
diff --git a/tensorflow/compiler/xla/service/interpreter/platform_id.cc b/tensorflow/compiler/xla/service/interpreter/platform_id.cc
new file mode 100644
index 00000000000..1a0373cf86e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/interpreter/platform_id.cc
@@ -0,0 +1,25 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
+
+namespace perftools {
+namespace gputools {
+namespace interpreter {
+
+PLATFORM_DEFINE_ID(kInterpreterPlatformId);
+
+}  // namespace interpreter
+}  // namespace gputools
+}  // namespace perftools
diff --git a/tensorflow/compiler/plugin/executor/platform_id.h b/tensorflow/compiler/xla/service/interpreter/platform_id.h
similarity index 72%
rename from tensorflow/compiler/plugin/executor/platform_id.h
rename to tensorflow/compiler/xla/service/interpreter/platform_id.h
index 8d2b29a3e4e..905efef1690 100644
--- a/tensorflow/compiler/plugin/executor/platform_id.h
+++ b/tensorflow/compiler/xla/service/interpreter/platform_id.h
@@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
-#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_ID_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_ID_H_
 
 #include "tensorflow/stream_executor/platform.h"
 
 namespace perftools {
 namespace gputools {
-namespace executorplugin {
+namespace interpreter {
 
-extern const Platform::Id kExecutorPlatformId;
+extern const Platform::Id kInterpreterPlatformId;
 
-}  // namespace executorplugin
+}  // namespace interpreter
 }  // namespace gputools
 }  // namespace perftools
 
-#endif  // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_ID_H_
diff --git a/tensorflow/compiler/xla/service/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter_transfer_manager.cc
new file mode 100644
index 00000000000..1864dcdf036
--- /dev/null
+++ b/tensorflow/compiler/xla/service/interpreter_transfer_manager.cc
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/interpreter_transfer_manager.h"
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+
+namespace sei = ::perftools::gputools::interpreter;
+
+namespace xla {
+
+InterpreterTransferManager::InterpreterTransferManager()
+    : GenericTransferManager(sei::kInterpreterPlatformId) {}
+
+}  // namespace xla
+
+static std::unique_ptr<xla::TransferManager>
+CreateInterpreterTransferManager() {
+  return xla::MakeUnique<xla::InterpreterTransferManager>();
+}
+
+static bool InitModule() {
+  xla::TransferManager::RegisterTransferManager(
+      sei::kInterpreterPlatformId, &CreateInterpreterTransferManager);
+  return true;
+}
+
+static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/interpreter_transfer_manager.h b/tensorflow/compiler/xla/service/interpreter_transfer_manager.h
new file mode 100644
index 00000000000..2b44f308218
--- /dev/null
+++ b/tensorflow/compiler/xla/service/interpreter_transfer_manager.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
+
+#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// An implementation of the XLA GenericTransferManager for the interpreter
+// backend.
+class InterpreterTransferManager : public GenericTransferManager {
+ public:
+  InterpreterTransferManager();
+  ~InterpreterTransferManager() override = default;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(InterpreterTransferManager);
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index da39ba3ffc3..fd3cda42da5 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -107,15 +107,10 @@ cc_binary(
 )
 
 cc_binary(
-    name = "replay_computation_hlo_evaluator",
+    name = "replay_computation_interpreter",
     deps = [
         ":replay_computation_library",
-        "//tensorflow/compiler/plugin/executor:plugin_lib",
-        # TODO: This dependency is a workaround for linking error with clang.
-        # Without it, linker complains about missing symbols from
-        # 'xla_device_launch_op'. This dependency should be propagated from
-        # plugin_lib instead, but no targets other than this break without it.
-        "//tensorflow/compiler/jit",
+        "//tensorflow/compiler/xla/service:interpreter_plugin",
     ],
 )
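Reviewer note (not part of the patch): the sketch below shows how the renamed
platform is reached once the module initializers above have run. It assumes
the interpreter plugin library is linked into the binary so that
InitializeInterpreterPlatform() fires; the wrapper function name is
illustrative only.

    #include "tensorflow/stream_executor/multi_platform_manager.h"
    #include "tensorflow/stream_executor/stream_executor.h"

    namespace se = ::perftools::gputools;

    // Looks up the platform registered under its new name, "Interpreter",
    // and asks it for the single visible device (VisibleDeviceCount() == 1).
    // GetExecutor() goes through executor_cache_, so repeated calls return
    // the same StreamExecutor instance.
    se::StreamExecutor* GetInterpreterExecutor() {
      se::Platform* platform =
          se::MultiPlatformManager::PlatformWithName("Interpreter")
              .ValueOrDie();
      return platform->ExecutorForDevice(/*ordinal=*/0).ValueOrDie();
    }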