[TF:XLA] Create Interpreter backend from the Executor backend.

- Move plugin/executor to xla/service/interpreter/
- Remove the executor's TransferManager and use GenericTransferManager instead (see the transfer-manager sketch below).
- Renamings and minor fixes.

PiperOrigin-RevId: 169160056
Kay Zhu 2017-09-18 15:49:35 -07:00 committed by TensorFlower Gardener
parent de724b1ac4
commit 7de939bb74
26 changed files with 539 additions and 576 deletions
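The new interpreter_transfer_manager.{h,cc} sources are added by this commit, but their bodies do not appear in the hunks below. As a rough illustration only: assuming the new class simply subclasses GenericTransferManager with the interpreter platform id and registers itself the same way the deleted plugin/executor/transfer_manager.cc did, it might look like the following sketch. The class name, constructor form, and registration helper here are assumptions, not the commit's actual code.

```cpp
// Hypothetical sketch of interpreter_transfer_manager.cc (not the real file).
// Assumes the 2017-era GenericTransferManager constructor taking a single
// platform id, and the registration pattern shown in the deleted
// plugin/executor/transfer_manager.cc.
#include <memory>

#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"

namespace xla {

// Reuses GenericTransferManager wholesale; only the platform id differs.
class InterpreterTransferManager : public GenericTransferManager {
 public:
  InterpreterTransferManager()
      : GenericTransferManager(
            perftools::gputools::interpreter::kInterpreterPlatformId) {}
};

}  // namespace xla

static std::unique_ptr<xla::TransferManager> CreateInterpreterTransferManager() {
  return xla::MakeUnique<xla::InterpreterTransferManager>();
}

static bool InitModule() {
  // Per-platform registration, matching the BUILD comment
  // "Contains per-platform transfer manager registration".
  xla::TransferManager::RegisterTransferManager(
      perftools::gputools::interpreter::kInterpreterPlatformId,
      &CreateInterpreterTransferManager);
  return true;
}
static bool module_initialized = InitModule();
```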


@ -238,7 +238,6 @@ filegroup(
"//tensorflow/compiler/jit/kernels:all_files", "//tensorflow/compiler/jit/kernels:all_files",
"//tensorflow/compiler/jit/legacy_flags:all_files", "//tensorflow/compiler/jit/legacy_flags:all_files",
"//tensorflow/compiler/jit/ops:all_files", "//tensorflow/compiler/jit/ops:all_files",
"//tensorflow/compiler/plugin/executor:all_files",
"//tensorflow/compiler/tests:all_files", "//tensorflow/compiler/tests:all_files",
"//tensorflow/compiler/tf2xla:all_files", "//tensorflow/compiler/tf2xla:all_files",
"//tensorflow/compiler/tf2xla/cc:all_files", "//tensorflow/compiler/tf2xla/cc:all_files",
@ -252,6 +251,7 @@ filegroup(
"//tensorflow/compiler/xla/service/cpu:all_files", "//tensorflow/compiler/xla/service/cpu:all_files",
"//tensorflow/compiler/xla/service/gpu:all_files", "//tensorflow/compiler/xla/service/gpu:all_files",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files",
"//tensorflow/compiler/xla/service/interpreter:all_files",
"//tensorflow/compiler/xla/service/llvm_ir:all_files", "//tensorflow/compiler/xla/service/llvm_ir:all_files",
"//tensorflow/compiler/xla/tests:all_files", "//tensorflow/compiler/xla/tests:all_files",
"//tensorflow/compiler/xla/tools:all_files", "//tensorflow/compiler/xla/tools:all_files",


@ -17,7 +17,6 @@ package_group(
package( package(
default_visibility = [ default_visibility = [
":internal", ":internal",
"//tensorflow/compiler/plugin/executor:__pkg__",
], ],
) )
@ -33,7 +32,6 @@ cc_library(
deps = [ deps = [
":xla_cpu_device", ":xla_cpu_device",
":xla_cpu_jit", ":xla_cpu_jit",
"//tensorflow/compiler/plugin",
] + if_cuda_is_configured([ ] + if_cuda_is_configured([
":xla_gpu_device", ":xla_gpu_device",
":xla_gpu_jit", ":xla_gpu_jit",
@ -99,6 +97,17 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "xla_interpreter_device",
srcs = ["xla_interpreter_device.cc"],
deps = [
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_device_launch_op",
"//tensorflow/compiler/tf2xla:xla_compiler",
],
alwayslink = True,
)
# Internal targets below this point. # Internal targets below this point.
cc_library( cc_library(


@ -2,7 +2,6 @@ licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = [ default_visibility = [
"//tensorflow/compiler/plugin/executor:__pkg__",
"//tensorflow/compiler/tf2xla:internal", "//tensorflow/compiler/tf2xla:internal",
], ],
) )


@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter.
#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h" #include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h"
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/jit/xla_device_ops.h"
@ -20,46 +22,47 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
const char* const DEVICE_XLA_EXEC = "XLA_EXEC"; const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER";
const char* const DEVICE_EXEC_XLA_JIT = "XLA_EXEC_JIT"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT";
constexpr std::array<DataType, 5> kExecAllTypes = { constexpr std::array<DataType, 5> kExecAllTypes = {
{DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}}; {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}};
class XlaExaDeviceFactory : public DeviceFactory { class XlaInterpreterDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override; std::vector<Device*>* devices) override;
}; };
Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options, Status XlaInterpreterDeviceFactory::CreateDevices(
const string& name_prefix, const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) { std::vector<Device*>* devices) {
static XlaDeviceOpRegistrations* registrations = static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
RegisterXlaDeviceKernels(DEVICE_XLA_EXEC, DEVICE_EXEC_XLA_JIT); DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
(void)registrations; (void)registrations;
std::unique_ptr<XlaDevice> device; std::unique_ptr<XlaDevice> device;
TF_RETURN_IF_ERROR(XlaDevice::Create("Executor", DEVICE_XLA_EXEC, 0, TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0,
DEVICE_EXEC_XLA_JIT, options, DEVICE_INTERPRETER_XLA_JIT, options,
name_prefix, &device)); name_prefix, &device));
devices->push_back(device.release()); devices->push_back(device.release());
return Status::OK(); return Status::OK();
} }
// Set priority to be below the default priority (50), so that Executor is not // Set priority to be below the default priority (50), so that Interpreter is
// selected as a high priority device over other default devices. // not selected as a high priority device over other default devices. See
// See constructor comments for Registrar in // constructor comments for Registrar in
// tensorflow/core/common_runtime/device_factory.h for a list of priority for // tensorflow/core/common_runtime/device_factory.h for a list of priority for
// devices. // devices.
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 40); REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_INTERPRETER,
XlaInterpreterDeviceFactory, 40);
// Kernel registrations // Kernel registrations
static bool OpFilter(KernelDef* kdef) { return true; } static bool OpFilter(KernelDef* kdef) { return true; }
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_EXEC, XlaDeviceLaunchOp, kExecAllTypes); REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaDeviceLaunchOp,
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_EXEC, kExecAllTypes); kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_EXEC_XLA_JIT, kExecAllTypes, OpFilter); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
} // namespace tensorflow } // namespace tensorflow


@ -1,38 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration file for an XLA plugin.
- please don't check in changes to this file
- to prevent changes appearing in git status, use:
git update-index --assume-unchanged tensorflow/compiler/plugin/BUILD
To add additional devices to the XLA subsystem, add targets to the
dependency list in the 'plugin' target. For instance:
deps = ["//tensorflow/compiler/plugin/example:plugin_lib"],
"""
licenses(["notice"])
package(
default_visibility = ["//visibility:public"],
)
cc_library(
name = "plugin",
deps = [
"//tensorflow/compiler/plugin/executor:plugin_lib",
],
)


@ -1,37 +0,0 @@
licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "plugin_lib",
srcs = glob([
"*.cc",
]),
hdrs = glob([
"*.h",
]),
deps = [
"//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/jit:xla_jit_headers_lib",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:xla_headers_lib",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:layout_assignment",
"//third_party/eigen3",
"@local_config_cuda//cuda:cuda_headers",
"@protobuf_archive//:protobuf_headers",
],
alwayslink = 1,
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
)


@ -1,186 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/plugin/executor/transfer_manager.h"
#include "tensorflow/compiler/plugin/executor/platform_id.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include <string>
#include <utility>
#include <vector>
namespace sep = ::perftools::gputools::executorplugin;
namespace xla {
namespace executorplugin {
ExecutorTransferManager::ExecutorTransferManager() {}
se::Platform::Id ExecutorTransferManager::PlatformId() const {
return se::executorplugin::kExecutorPlatformId;
}
Status ExecutorTransferManager::TransferLiteralFromDevice(
se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape));
// Tuples are a special case and contain one or more shapes inside of them to
// an arbitrary nesting depth.
if (device_shape.element_type() == TUPLE) {
*literal->mutable_shape() = literal_shape;
TF_ASSIGN_OR_RETURN(
std::vector<se::DeviceMemoryBase> element_buffers,
ShallowCopyTupleFromDevice(executor, source, device_shape));
TF_RET_CHECK(element_buffers.size() ==
ShapeUtil::TupleElementCount(device_shape));
for (int64 i = 0; i < element_buffers.size(); ++i) {
const Shape& element_device_shape = device_shape.tuple_shapes(i);
const Shape& element_literal_shape = literal_shape.tuple_shapes(i);
Literal* element_literal = literal->add_tuple_literals();
// Recursively call TransferFromDevice to copy over the data in the
// element array.
TF_RETURN_IF_ERROR(TransferLiteralFromDevice(
executor, element_buffers[i], element_device_shape,
element_literal_shape, element_literal));
}
return Status::OK();
}
*literal->mutable_shape() = device_shape;
literal->Reserve(ShapeUtil::ElementsIn(device_shape));
TF_RETURN_IF_ERROR(TransferBufferFromDevice(
executor, source, ShapeUtil::ByteSizeOf(device_shape),
literal->MutableInternalData()));
if (!ShapeUtil::Equal(literal_shape, device_shape)) {
*literal = std::move(*literal->Relayout(literal_shape.layout()));
}
TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape()));
return Status::OK();
}
StatusOr<std::vector<se::DeviceMemoryBase>>
ExecutorTransferManager::ShallowCopyTupleFromDevice(
se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
const Shape& shape) {
TF_RET_CHECK(ShapeUtil::IsTuple(shape));
std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
nullptr);
int64 tuple_size = ShapeUtil::ByteSizeOf(shape, sizeof(void*));
auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
element_pointers.data());
if (!copy_status.ok()) {
return AddStatus(
Status(static_cast<tensorflow::error::Code>(copy_status.code()),
copy_status.error_message()),
"failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
}
// Create a DeviceMemoryBase from each void* pointer.
std::vector<se::DeviceMemoryBase> destination;
for (int i = 0; i < element_pointers.size(); ++i) {
if (element_pointers[i] == nullptr &&
!ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
return FailedPrecondition("tuple contains nullptr at element %d", i);
}
int64 buffer_size =
ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), sizeof(void*));
destination.emplace_back(element_pointers[i], buffer_size);
}
return std::move(destination);
}
Status ExecutorTransferManager::TransferLiteralToDevice(
se::StreamExecutor* executor, const Literal& literal,
se::DeviceMemoryBase* destination) {
const Shape& shape = literal.shape();
if (ShapeUtil::IsTuple(literal.shape())) {
std::vector<void*> tuple_elements_on_device;
for (const Literal& tuple_element : literal.tuple_literals()) {
se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>(
GetByteSizeRequirement(tuple_element.shape()));
TF_RETURN_IF_ERROR(
TransferLiteralToDevice(executor, tuple_element, &allocation));
tuple_elements_on_device.push_back(allocation.opaque());
}
return TransferBufferToDevice(
executor, tuple_elements_on_device.size() * sizeof(void*),
tuple_elements_on_device.data(), destination);
}
return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
literal.InternalData(),
destination);
}
Status ExecutorTransferManager::TransferLiteralToInfeed(
se::StreamExecutor* executor, const Literal& literal) {
const Shape& shape = literal.shape();
VLOG(1) << "transferring literal shape to infeed: "
<< ShapeUtil::HumanString(shape);
return Status::OK();
}
Status ExecutorTransferManager::TransferBufferToInfeed(
se::StreamExecutor* executor, int64 size, const void* source) {
return Unimplemented("Transfer to Infeed");
}
Status ExecutorTransferManager::TransferLiteralFromOutfeed(
perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
Literal* literal) {
const Shape& shape = literal->shape();
VLOG(1) << "transferring literal shape from outfeed: "
<< ShapeUtil::HumanString(shape);
return Status::OK();
}
Status ExecutorTransferManager::ResetDevices(
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
executors) {
return Unimplemented("Device reset not supported");
}
int64 ExecutorTransferManager::GetByteSizeRequirement(const Shape& shape) {
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
} // namespace executorplugin
} // namespace xla
static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() {
return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>();
}
static bool InitModule() {
xla::TransferManager::RegisterTransferManager(sep::kExecutorPlatformId,
&CreateExecutorTransferManager);
return true;
}
static bool module_initialized = InitModule();


@ -1,77 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include <vector>
namespace se = ::perftools::gputools;
namespace xla {
namespace executorplugin {
class ExecutorTransferManager : public TransferManager {
public:
ExecutorTransferManager();
~ExecutorTransferManager() override {}
se::Platform::Id PlatformId() const override;
StatusOr<std::vector<se::DeviceMemoryBase>> ShallowCopyTupleFromDevice(
se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
const Shape& shape) override;
Status TransferLiteralFromDevice(se::StreamExecutor* executor,
const se::DeviceMemoryBase& source,
const Shape& device_shape,
const Shape& literal_shape,
Literal* literal) override;
Status TransferLiteralToDevice(se::StreamExecutor* executor,
const Literal& literal,
se::DeviceMemoryBase* destination) override;
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const Literal& literal) override;
Status TransferBufferToInfeed(se::StreamExecutor* executor,
int64 size, const void* source) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
Literal* literal) override;
Status ResetDevices(
tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
int64 GetByteSizeRequirement(const Shape& shape) override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorTransferManager);
};
} // namespace executorplugin
} // namespace xla
#endif // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_


@ -529,6 +529,17 @@ cc_library(
], ],
) )
cc_library(
name = "interpreter_plugin",
deps = [
":interpreter_transfer_manager",
":service",
"//tensorflow/compiler/xla/service/interpreter:compiler",
"//tensorflow/compiler/xla/service/interpreter:platform",
"//tensorflow/core:stream_executor_no_cuda",
],
)
cc_library( cc_library(
name = "shaped_buffer", name = "shaped_buffer",
srcs = ["shaped_buffer.cc"], srcs = ["shaped_buffer.cc"],
@ -1152,6 +1163,7 @@ cc_library(
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/interpreter:platform_id",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
], ],
@ -1200,6 +1212,27 @@ cc_library(
alwayslink = True, # Contains per-platform transfer manager registration alwayslink = True, # Contains per-platform transfer manager registration
) )
cc_library(
name = "interpreter_transfer_manager",
srcs = ["interpreter_transfer_manager.cc"],
hdrs = ["interpreter_transfer_manager.h"],
deps = [
":generic_transfer_manager",
":transfer_manager",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/interpreter:platform_id",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
alwayslink = True, # Contains per-platform transfer manager registration
)
cc_test( cc_test(
name = "transfer_manager_test", name = "transfer_manager_test",
srcs = ["transfer_manager_test.cc"], srcs = ["transfer_manager_test.cc"],


@ -20,6 +20,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
@ -36,19 +37,16 @@ namespace xla {
GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id) GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id)
: platform_id_(platform_id) { : platform_id_(platform_id) {
// We currently only support kHostPlatformId for CPU and kCudaPlatformId for // We currently only support kHostPlatformId for CPU, kCudaPlatformId for
// GPU. Before supporting other platforms, we need to test this transfer // GPU and kInterpreterPlatformId for Interpreter. Before supporting other
// manager on them. // platforms, we need to test this transfer manager on them.
CHECK(platform_id_ == se::host::kHostPlatformId || CHECK(platform_id_ == se::host::kHostPlatformId ||
platform_id_ == se::interpreter::kInterpreterPlatformId ||
platform_id_ == se::cuda::kCudaPlatformId); platform_id_ == se::cuda::kCudaPlatformId);
} }
se::Platform::Id GenericTransferManager::PlatformId() const { se::Platform::Id GenericTransferManager::PlatformId() const {
if (platform_id_ == se::cuda::kCudaPlatformId || return platform_id_;
platform_id_ == se::host::kHostPlatformId) {
return platform_id_;
}
CHECK(false) << "GenericTransferManager::platform_id_ is invalid";
} }
Status GenericTransferManager::TransferLiteralFromDevice( Status GenericTransferManager::TransferLiteralFromDevice(


@ -75,7 +75,7 @@ class GenericTransferManager : public TransferManager {
private: private:
// The platform this transfer manager targets. // The platform this transfer manager targets.
perftools::gputools::Platform::Id platform_id_; const perftools::gputools::Platform::Id platform_id_;
TF_DISALLOW_COPY_AND_ASSIGN(GenericTransferManager); TF_DISALLOW_COPY_AND_ASSIGN(GenericTransferManager);
}; };


@ -0,0 +1,113 @@
licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "compiler",
srcs = ["compiler.cc"],
hdrs = ["compiler.h"],
deps = [
":executable",
":platform_id",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_cse",
"//tensorflow/compiler/xla/service:hlo_dce",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor",
],
alwayslink = True, # Contains compiler registration
)
cc_library(
name = "platform_id",
srcs = ["platform_id.cc"],
hdrs = ["platform_id.h"],
deps = [
"//tensorflow/core:stream_executor_headers_lib",
"@nsync//:nsync_headers",
"@protobuf_archive//:protobuf_headers",
"@protobuf_archive//:protoc_lib",
],
)
cc_library(
name = "executable",
srcs = ["executable.cc"],
hdrs = ["executable.h"],
deps = [
":executor",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
)
cc_library(
name = "platform",
srcs = ["platform.cc"],
hdrs = ["platform.h"],
deps = [
":executor",
":platform_id",
"//tensorflow/core:stream_executor_headers_lib",
],
alwayslink = True, # Registers itself with the MultiPlatformManager.
)
cc_library(
name = "executor",
srcs = ["executor.cc"],
hdrs = ["executor.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_headers_lib",
"//tensorflow/core:stream_executor_no_cuda",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
)


@ -0,0 +1,19 @@
# XLA Interpreter Backend
The XLA Interpreter backend operates at the HLO level by ingesting an HloModule and
evaluating the result of the HLO graph directly with the HloEvaluator, without
lowering it further (to LLVM IR, for example) before execution, as other backends
(CPU and GPU, for example) do. A minimal sketch of this evaluation flow follows the
component list below.
Its key components are:
* [`InterpreterCompiler`]: despite the inherited "compiler" naming, all
`InterpreterCompiler` really does is the following:
  1. Runs certain HLO optimization passes on the given HLO graph.
  2. Generates an `InterpreterExecutable` from the optimized HLO graph.
  3. Registers itself in the global compiler factory registry.
* [`InterpreterExecutable`]: responsible for running the input HLO graph through
the `HloEvaluator`, allocating the output buffer, and finally copying the
evaluated `Literal` result over.
* [`HloEvaluator`]: traverses an HLO graph and evaluates each node in DFS order
along the way.
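As a concrete illustration of the flow described above, here is a minimal, hypothetical helper (not part of the commit) that hands an already-optimized HloModule's entry computation to HloEvaluator, the same call that InterpreterExecutable::ExecuteOnStream makes in a later hunk of this diff. The helper name and its simplified signature are assumptions.

```cpp
// Illustrative only; mirrors the 2017-era HloEvaluator usage visible in this diff.
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Hypothetical helper: evaluate a module's entry computation on host literals.
StatusOr<std::unique_ptr<Literal>> EvaluateEntryComputation(
    const HloModule& module, const std::vector<Literal*>& arg_literals) {
  // No lowering step: the optimized HLO graph is interpreted directly,
  // node by node in DFS order, by HloEvaluator.
  HloEvaluator evaluator;
  return evaluator.Evaluate(*module.entry_computation(), arg_literals);
}

}  // namespace xla
```

This covers the evaluation portion only; the real InterpreterExecutable additionally allocates the output device buffer and copies the evaluated Literal into it, as shown in the executable.cc hunk below.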


@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <stdlib.h> #include "tensorflow/compiler/xla/service/interpreter/compiler.h"
#include <fstream>
#include "tensorflow/compiler/plugin/executor/compiler.h" #include <string>
#include "tensorflow/compiler/plugin/executor/executable.h" #include <utility>
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@ -28,26 +29,27 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/strcat.h"
namespace xla { namespace xla {
namespace executorplugin { namespace interpreter {
namespace se = ::perftools::gputools; namespace se = ::perftools::gputools;
namespace sep = ::perftools::gputools::executorplugin; namespace sep = ::perftools::gputools::interpreter;
/* /*
* Run optimization passes on the module. The graph is transformed by * Run optimization passes on the module. The graph is transformed by
* each pass in the optimization pipeline. The service subdirectory * each pass in the optimization pipeline. The service subdirectory
* contains useful optimization passes. * contains useful optimization passes.
*/ */
Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) { Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
HloPassPipeline pipeline("Executor"); HloPassPipeline pipeline("Interpreter");
pipeline.AddPass<Inliner>(); pipeline.AddPass<Inliner>();
pipeline.AddPass<HloSubcomputationUnification>(); pipeline.AddPass<HloSubcomputationUnification>();
pipeline.AddPass<HloCSE>(false); pipeline.AddPass<HloCSE>(false);
@ -65,9 +67,8 @@ Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) {
return pipeline.Run(hlo_module).status(); return pipeline.Run(hlo_module).status();
} }
StatusOr<std::unique_ptr<Executable>> ExecutorCompiler::Compile( StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::Compile(
std::unique_ptr<HloModule> hlo_module, std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec) {
se::StreamExecutor* stream_exec) {
TF_RET_CHECK(stream_exec != nullptr); TF_RET_CHECK(stream_exec != nullptr);
VLOG(1) << "Generate graph " << hlo_module->name(); VLOG(1) << "Generate graph " << hlo_module->name();
@ -75,53 +76,54 @@ StatusOr<std::unique_ptr<Executable>> ExecutorCompiler::Compile(
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
// Typically you would visit the HLO graph, building up a compiled equivalent // Typically you would visit the HLO graph, building up a compiled equivalent
// In this case we are using an Hlo evaluator at execution time, so we don't // In this case we are using an HloEvaluator at execution time, so we don't
// need to compile anything // need to compile anything
// Create executable from only the Hlo module // Create executable from only the Hlo module.
std::unique_ptr<Executable> executable; std::unique_ptr<Executable> executable =
executable.reset(new ExecutorExecutable(std::move(hlo_module))); xla::MakeUnique<InterpreterExecutable>(std::move(hlo_module));
return std::move(executable); return std::move(executable);
} }
StatusOr<std::vector<std::unique_ptr<Executable>>> ExecutorCompiler::Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
std::vector<std::unique_ptr<HloModule>> hlo_modules, std::vector<std::unique_ptr<HloModule>> /*hlo_modules*/,
std::vector<se::StreamExecutor*> stream_execs) { std::vector<se::StreamExecutor*> /*stream_execs*/) {
return tensorflow::errors::Unimplemented( return tensorflow::errors::Unimplemented(
"Compilation of multiple HLO modules is not supported on Executor."); "Compilation of multiple HLO modules is not supported on Interpreter.");
} }
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
ExecutorCompiler::CompileAheadOfTime( InterpreterCompiler::CompileAheadOfTime(
std::vector<std::unique_ptr<HloModule>> hlo_modules, std::vector<std::unique_ptr<HloModule>> hlo_modules,
const AotCompilationOptions& aot_options) { const AotCompilationOptions& aot_options) {
return tensorflow::errors::InvalidArgument( return tensorflow::errors::InvalidArgument(
"AOT compilation not supported on Executor"); "AOT compilation not supported on Interpreter");
} }
se::Platform::Id ExecutorCompiler::PlatformId() const { se::Platform::Id InterpreterCompiler::PlatformId() const {
return sep::kExecutorPlatformId; return sep::kInterpreterPlatformId;
} }
HloCostAnalysis::ShapeSizeFunction HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
ExecutorCompiler::ShapeSizeBytesFunction() const { const {
return ExecutorExecutable::ShapeSizeBytes; return InterpreterExecutable::ShapeSizeBytes;
} }
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() { static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
return xla::MakeUnique<xla::ComputationPlacer>(); return xla::MakeUnique<xla::ComputationPlacer>();
} }
REGISTER_MODULE_INITIALIZER(executor_compiler, { static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() { xla::Compiler::RegisterCompilerFactory(sep::kInterpreterPlatformId, []() {
return xla::MakeUnique<xla::executorplugin::ExecutorCompiler>(); return xla::MakeUnique<xla::interpreter::InterpreterCompiler>();
}); });
xla::ComputationPlacer::RegisterComputationPlacer(sep::kExecutorPlatformId, xla::ComputationPlacer::RegisterComputationPlacer(sep::kInterpreterPlatformId,
&CreateComputationPlacer); &CreateComputationPlacer);
}); return true;
}
} // namespace executorplugin static bool module_initialized = InitModule();
} // namespace interpreter
} // namespace xla } // namespace xla


@ -13,38 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_COMPILER_H_
#define TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_COMPILER_H_
#include <memory> #include <memory>
#include <vector>
#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/plugin/executor/platform_id.h" #include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace xla { namespace xla {
namespace executorplugin { namespace interpreter {
class ExecutorCompiler : public Compiler { // Despite the inherited "compiler" naming, InterpreterCompiler does not
// perform any lowering as other backends do. It operates at HLO-level for
// and is responsible for generating an InterpreterExecutable.
// Refer to interpreter/README.md for more.
class InterpreterCompiler : public Compiler {
public: public:
ExecutorCompiler() {} InterpreterCompiler() {}
~ExecutorCompiler() override {} ~InterpreterCompiler() override {}
StatusOr<std::unique_ptr<Executable>> Compile( StatusOr<std::unique_ptr<Executable>> Compile(
std::unique_ptr<HloModule> hlo_module, std::unique_ptr<HloModule> hlo_modules,
perftools::gputools::StreamExecutor* stream_exec) override; perftools::gputools::StreamExecutor* stream_exec) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::vector<std::unique_ptr<HloModule>> hlo_module, std::vector<std::unique_ptr<HloModule>> hlo_modules,
std::vector<perftools::gputools::StreamExecutor*> stream_exec) override; std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime( CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> hlo_modules,
std::vector<std::unique_ptr<HloModule>> module, const AotCompilationOptions& aot_options) override;
const AotCompilationOptions& options) override;
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
@ -53,10 +62,10 @@ class ExecutorCompiler : public Compiler {
private: private:
Status RunHloOptimization(HloModule* hlo_module); Status RunHloOptimization(HloModule* hlo_module);
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorCompiler); TF_DISALLOW_COPY_AND_ASSIGN(InterpreterCompiler);
}; };
} // namespace executorplugin } // namespace interpreter
} // namespace xla } // namespace xla
#endif // TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_ #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_COMPILER_H_


@ -13,25 +13,41 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/plugin/executor/executable.h" #include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/plugin/executor/executor.h"
#include <algorithm>
#include <cstring>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla { namespace xla {
namespace executorplugin { namespace interpreter {
namespace se = ::perftools::gputools; namespace se = ::perftools::gputools;
namespace sep = ::perftools::gputools::executorplugin; namespace sep = ::perftools::gputools::interpreter;
ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module) InterpreterExecutable::InterpreterExecutable(
std::unique_ptr<HloModule> hlo_module)
: Executable(std::move(hlo_module)) {} : Executable(std::move(hlo_module)) {}
ExecutorExecutable::~ExecutorExecutable() {} InterpreterExecutable::~InterpreterExecutable() {}
static se::DeviceMemoryBase AllocateSingleOutput( static se::DeviceMemoryBase AllocateSingleOutput(
sep::ExecutorExecutor* executor, const Literal& literal) { sep::InterpreterExecutor* executor, const Literal& literal) {
int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape())); int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape()));
void* buf = executor->Allocate(size); void* buf = executor->Allocate(size);
const void* src = literal.InternalData(); const void* src = literal.InternalData();
@ -40,7 +56,7 @@ static se::DeviceMemoryBase AllocateSingleOutput(
} }
static se::DeviceMemoryBase AllocateOutputBuffer( static se::DeviceMemoryBase AllocateOutputBuffer(
sep::ExecutorExecutor* executor, const Literal& literal) { sep::InterpreterExecutor* executor, const Literal& literal) {
const Shape& shape = literal.shape(); const Shape& shape = literal.shape();
if (shape.element_type() != xla::TUPLE) { if (shape.element_type() != xla::TUPLE) {
return AllocateSingleOutput(executor, literal); return AllocateSingleOutput(executor, literal);
@ -58,7 +74,7 @@ static se::DeviceMemoryBase AllocateOutputBuffer(
} }
} }
StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream( StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments, tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
HloExecutionProfile* hlo_execution_profile) { HloExecutionProfile* hlo_execution_profile) {
@ -82,7 +98,7 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
// Create the arguments as an vector of XLA literals // Create the arguments as an vector of XLA literals
std::vector<std::unique_ptr<Literal>> arg_literals; std::vector<std::unique_ptr<Literal>> arg_literals;
std::vector<Literal*> arg_literals_ptrs; std::vector<Literal*> arg_literals_ptrs;
for (int64 p = 0; p < computation->num_parameters(); p++) { for (int64 p = 0; p < computation->num_parameters(); ++p) {
// Create the input literal for the parameter // Create the input literal for the parameter
HloInstruction* param = computation->parameter_instruction(p); HloInstruction* param = computation->parameter_instruction(p);
arg_literals.emplace_back(Literal::CreateFromShape(param->shape())); arg_literals.emplace_back(Literal::CreateFromShape(param->shape()));
@ -94,18 +110,18 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
ShapeUtil::ByteSizeOf(param->shape())); ShapeUtil::ByteSizeOf(param->shape()));
} }
// Execute the graph using the evaluator // Execute the graph using the HloEvaluator.
HloEvaluator evaluator; HloEvaluator evaluator;
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> output, TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> output,
evaluator.Evaluate(*computation, arg_literals_ptrs)); evaluator.Evaluate(*computation, arg_literals_ptrs));
// Copy the result into the return buffer // Copy the result into the return buffer
perftools::gputools::StreamExecutor* executor(stream->parent()); perftools::gputools::StreamExecutor* executor(stream->parent());
sep::ExecutorExecutor* executorExecutor( sep::InterpreterExecutor* interpreter_executor(
static_cast<sep::ExecutorExecutor*>(executor->implementation())); static_cast<sep::InterpreterExecutor*>(executor->implementation()));
se::DeviceMemoryBase ret = se::DeviceMemoryBase ret =
AllocateOutputBuffer(executorExecutor, *(output.get())); AllocateOutputBuffer(interpreter_executor, *(output.get()));
uint64 end_micros = tensorflow::Env::Default()->NowMicros(); uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@ -118,32 +134,32 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
return ret; return ret;
} }
StatusOr<std::unique_ptr<ShapedBuffer>> ExecutorExecutable::ExecuteOnStream( StatusOr<std::unique_ptr<ShapedBuffer>> InterpreterExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) { HloExecutionProfile* hlo_execution_profile) {
return tensorflow::errors::Unimplemented( return tensorflow::errors::Unimplemented(
"ExecuteOnStream is not yet supported on Executor."); "ExecuteOnStream is not yet supported on Interpreter.");
} }
StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream( StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) { tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
return tensorflow::errors::Unimplemented( return tensorflow::errors::Unimplemented(
"ExecuteAsyncOnStream is not yet supported on Executor."); "ExecuteAsyncOnStream is not yet supported on Interpreter.");
} }
/*static*/ int64 ExecutorExecutable::ShapeSizeBytes(const Shape& shape) { /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
if (ShapeUtil::IsOpaque(shape)) { if (ShapeUtil::IsOpaque(shape)) {
return sizeof(void*); return sizeof(void*);
} }
return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
} }
std::unique_ptr<HloCostAnalysis> ExecutorExecutable::CreateCostAnalysis() std::unique_ptr<HloCostAnalysis> InterpreterExecutable::CreateCostAnalysis()
const { const {
return MakeUnique<HloCostAnalysis>(ShapeSizeBytes); return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
} }
} // namespace executorplugin } // namespace interpreter
} // namespace xla } // namespace xla


@ -13,29 +13,35 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_H_
#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_H_
#include <cstddef>
#include <memory> #include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
namespace xla { namespace xla {
namespace executorplugin { namespace interpreter {
class ExecutorExecutable : public Executable { // Responsible for running a HLO graph through the HloEvaluator and output
// buffer allocation. Refer to interpreter/README.md for more.
class InterpreterExecutable : public Executable {
public: public:
ExecutorExecutable(std::unique_ptr<HloModule> hlo_module); InterpreterExecutable(std::unique_ptr<HloModule> hlo_module);
~ExecutorExecutable() override; ~InterpreterExecutable() override;
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream( StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
@ -58,10 +64,10 @@ class ExecutorExecutable : public Executable {
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override; std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
private: private:
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorExecutable); TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
}; };
} // namespace executorplugin } // namespace interpreter
} // namespace xla } // namespace xla
#endif // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_ #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTABLE_H_


@ -13,117 +13,110 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/plugin/executor/executor.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h"
#include <stdlib.h> #include <cstring>
#include <string.h>
#include "tensorflow/compiler/plugin/executor/platform_id.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
namespace perftools { namespace perftools {
namespace gputools { namespace gputools {
namespace executorplugin { namespace interpreter {
host::HostStream *AsExecutorStream(Stream *stream) { host::HostStream *AsExecutorStream(Stream *stream) {
DCHECK(stream != nullptr); DCHECK(stream != nullptr);
return dynamic_cast<host::HostStream *>(stream->implementation()); return dynamic_cast<host::HostStream *>(stream->implementation());
} }
ExecutorExecutor::ExecutorExecutor(const PluginConfig &plugin_config) InterpreterExecutor::InterpreterExecutor(const PluginConfig &plugin_config)
: plugin_config_(plugin_config) {} : plugin_config_(plugin_config) {}
ExecutorExecutor::~ExecutorExecutor() {} InterpreterExecutor::~InterpreterExecutor() {}
void *ExecutorExecutor::Allocate(uint64 size) { return new char[size]; } void *InterpreterExecutor::Allocate(uint64 size) { return new char[size]; }
void *ExecutorExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, void *InterpreterExecutor::AllocateSubBuffer(DeviceMemoryBase *parent,
uint64 offset_bytes, uint64 offset_bytes,
uint64 size_bytes) { uint64 /*size_bytes*/) {
return parent + offset_bytes; return parent + offset_bytes;
} }
void ExecutorExecutor::Deallocate(DeviceMemoryBase *mem) { void InterpreterExecutor::Deallocate(DeviceMemoryBase *mem) {
if (!mem->is_sub_buffer()) { if (!mem->is_sub_buffer()) {
delete[] static_cast<char *>(mem->opaque()); delete[] static_cast<char *>(mem->opaque());
} }
} }
bool ExecutorExecutor::Memcpy(Stream *stream, void *host_dst, bool InterpreterExecutor::Memcpy(Stream *stream, void *host_dst,
const DeviceMemoryBase &dev_src, uint64 size) { const DeviceMemoryBase &dev_src, uint64 size) {
AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
}); });
return true; return true;
} }
bool ExecutorExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, bool InterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
const void *host_src, uint64 size) { const void *host_src, uint64 size) {
AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
}); });
return true; return true;
} }
port::Status ExecutorExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst, port::Status InterpreterExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst,
const void *host_src, const void *host_src,
uint64 size) { uint64 size) {
memcpy(dev_dst->opaque(), host_src, size); memcpy(dev_dst->opaque(), host_src, size);
return port::Status::OK(); return port::Status::OK();
} }
port::Status ExecutorExecutor::SynchronousMemcpy(void *host_dst, port::Status InterpreterExecutor::SynchronousMemcpy(
const DeviceMemoryBase &dev_src, void *host_dst, const DeviceMemoryBase &dev_src, uint64 size) {
uint64 size) {
memcpy(host_dst, dev_src.opaque(), size); memcpy(host_dst, dev_src.opaque(), size);
return port::Status::OK(); return port::Status::OK();
} }
bool ExecutorExecutor::HostCallback(Stream *stream, bool InterpreterExecutor::HostCallback(Stream *stream,
std::function<void()> callback) { std::function<void()> callback) {
AsExecutorStream(stream)->EnqueueTask(callback); AsExecutorStream(stream)->EnqueueTask(callback);
return true; return true;
} }
bool ExecutorExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { bool InterpreterExecutor::CreateStreamDependency(Stream *dependent,
Stream *other) {
AsExecutorStream(dependent)->EnqueueTask( AsExecutorStream(dependent)->EnqueueTask(
[other]() { other->BlockHostUntilDone(); }); [other]() { other->BlockHostUntilDone(); });
AsExecutorStream(dependent)->BlockUntilDone(); AsExecutorStream(dependent)->BlockUntilDone();
return true; return true;
} }
bool ExecutorExecutor::StartTimer(Stream *stream, Timer *timer) { bool InterpreterExecutor::StartTimer(Stream *stream, Timer *timer) {
dynamic_cast<host::HostTimer *>(timer->implementation())->Start(stream); dynamic_cast<host::HostTimer *>(timer->implementation())->Start(stream);
return true; return true;
} }
bool ExecutorExecutor::StopTimer(Stream *stream, Timer *timer) { bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) {
dynamic_cast<host::HostTimer *>(timer->implementation())->Stop(stream); dynamic_cast<host::HostTimer *>(timer->implementation())->Stop(stream);
return true; return true;
} }
bool ExecutorExecutor::BlockHostUntilDone(Stream *stream) { bool InterpreterExecutor::BlockHostUntilDone(Stream *stream) {
AsExecutorStream(stream)->BlockUntilDone(); AsExecutorStream(stream)->BlockUntilDone();
return true; return true;
} }
DeviceDescription *ExecutorExecutor::PopulateDeviceDescription() const { DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const {
internal::DeviceDescriptionBuilder builder; internal::DeviceDescriptionBuilder builder;
builder.set_device_address_bits(64); builder.set_device_address_bits(64);
builder.set_name("Executor"); builder.set_name("Interpreter");
builder.set_device_vendor("VectorName");
builder.set_platform_version("1.0");
builder.set_driver_version("1.0");
builder.set_runtime_version("1.0");
builder.set_pci_bus_id("1");
builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024); builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);
builder.set_clock_rate_ghz(static_cast<float>(CLOCKS_PER_SEC) / 1e9); builder.set_clock_rate_ghz(static_cast<float>(CLOCKS_PER_SEC) / 1e9);
return builder.Build().release(); return builder.Build().release();
} }
} // namespace executorplugin } // namespace interpreter
} // namespace gputools } // namespace gputools
} // namespace perftools } // namespace perftools


@ -13,38 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// Declares the ExecutorExecutor class, which is a CPU-only implementation of // Declares the InterpreterExecutor class, which is a CPU-only implementation of
// the StreamExecutor interface. For now, this is used for testing and to // the StreamExecutor interface. For now, this is used for testing and to
// examine the performance of host-based StreamExecutor code. // examine the performance of host-based StreamExecutor code.
#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
#include "tensorflow/stream_executor/host/host_stream.h" #include <functional>
#include "tensorflow/stream_executor/host/host_timer.h" #include <memory>
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/device_description.h"
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/host/host_stream.h"
#include "tensorflow/stream_executor/host/host_timer.h"
#include "tensorflow/stream_executor/kernel.h"
#include "tensorflow/stream_executor/kernel_spec.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/plugin.h"
#include "tensorflow/stream_executor/rng.h" #include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/timer.h"
#include <list>
#include <mutex>
namespace perftools { namespace perftools {
namespace gputools { namespace gputools {
namespace executorplugin { namespace interpreter {
using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>; using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;
class ExecutorExecutor : public internal::StreamExecutorInterface { class InterpreterExecutor : public internal::StreamExecutorInterface {
public: public:
explicit ExecutorExecutor(const PluginConfig &plugin_config); explicit InterpreterExecutor(const PluginConfig &plugin_config);
~ExecutorExecutor() override; ~InterpreterExecutor() override;
port::Status Init(int device_ordinal, DeviceOptions device_options) override { port::Status Init(int device_ordinal, DeviceOptions device_options) override {
return port::Status::OK(); return port::Status::OK();
@ -194,9 +203,6 @@ class ExecutorExecutor : public internal::StreamExecutorInterface {
return std::unique_ptr<internal::TimerInterface>(new host::HostTimer()); return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
} }
port::StatusOr<DeviceMemoryBase> ExecuteGraph(const xla::Shape &shape,
Args args);
private: private:
DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape); DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
@ -206,8 +212,8 @@ class ExecutorExecutor : public internal::StreamExecutorInterface {
const PluginConfig plugin_config_; const PluginConfig plugin_config_;
}; };
} // namespace executorplugin } // namespace interpreter
} // namespace gputools } // namespace gputools
} // namespace perftools } // namespace perftools
#endif // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_ #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
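Editor's note: the header comment above describes InterpreterExecutor as a host-only StreamExecutorInterface implementation; normally the platform (next file) constructs it for you. Below is a hedged sketch of the equivalent manual construction, mirroring InterpreterPlatform::GetUncachedExecutor(); it is illustrative only and assumes the APIs visible in this diff.

// Sketch only: wrap an InterpreterExecutor in the public StreamExecutor
// facade by hand. The constructor arguments mirror
// InterpreterPlatform::GetUncachedExecutor() in platform.cc below.
#include <memory>

#include "tensorflow/compiler/xla/service/interpreter/executor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = ::perftools::gputools;
namespace sei = ::perftools::gputools::interpreter;

std::unique_ptr<se::StreamExecutor> MakeInterpreterStreamExecutor() {
  se::Platform *platform =
      se::MultiPlatformManager::PlatformWithName("Interpreter").ValueOrDie();
  auto executor = se::port::MakeUnique<se::StreamExecutor>(
      platform,
      se::port::MakeUnique<sei::InterpreterExecutor>(se::PluginConfig()));
  CHECK(
      executor->Init(/*device_ordinal=*/0, se::DeviceOptions::Default()).ok());
  return executor;
}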

View File

@ -13,37 +13,39 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/plugin/executor/platform.h" #include "tensorflow/compiler/xla/service/interpreter/platform.h"
#include "tensorflow/compiler/plugin/executor/executor.h"
#include "tensorflow/compiler/plugin/executor/platform_id.h"
#include "tensorflow/stream_executor/lib/error.h" #include <utility>
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/lib/status_macros.h"
#include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
namespace se = ::perftools::gputools; namespace se = ::perftools::gputools;
namespace sep = ::perftools::gputools::executorplugin; namespace sep = ::perftools::gputools::interpreter;
namespace perftools { namespace perftools {
namespace gputools { namespace gputools {
namespace executorplugin { namespace interpreter {
PLATFORM_DEFINE_ID(kExecutorPlatformId); InterpreterPlatform::InterpreterPlatform() : name_("Interpreter") {}
ExecutorPlatform::ExecutorPlatform() : name_("Executor") {} InterpreterPlatform::~InterpreterPlatform() {}
ExecutorPlatform::~ExecutorPlatform() {} Platform::Id InterpreterPlatform::id() const { return kInterpreterPlatformId; }
Platform::Id ExecutorPlatform::id() const { return kExecutorPlatformId; } int InterpreterPlatform::VisibleDeviceCount() const { return 1; }
int ExecutorPlatform::VisibleDeviceCount() const { return 1; } const string& InterpreterPlatform::Name() const { return name_; }
const string& ExecutorPlatform::Name() const { return name_; } port::StatusOr<StreamExecutor*> InterpreterPlatform::ExecutorForDevice(
port::StatusOr<StreamExecutor*> ExecutorPlatform::ExecutorForDevice(
int ordinal) { int ordinal) {
StreamExecutorConfig config; StreamExecutorConfig config;
config.ordinal = ordinal; config.ordinal = ordinal;
@ -53,7 +55,7 @@ port::StatusOr<StreamExecutor*> ExecutorPlatform::ExecutorForDevice(
} }
port::StatusOr<StreamExecutor*> port::StatusOr<StreamExecutor*>
ExecutorPlatform::ExecutorForDeviceWithPluginConfig( InterpreterPlatform::ExecutorForDeviceWithPluginConfig(
int device_ordinal, const PluginConfig& plugin_config) { int device_ordinal, const PluginConfig& plugin_config) {
StreamExecutorConfig config; StreamExecutorConfig config;
config.ordinal = device_ordinal; config.ordinal = device_ordinal;
@ -62,16 +64,16 @@ ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
return GetExecutor(config); return GetExecutor(config);
} }
port::StatusOr<StreamExecutor*> ExecutorPlatform::GetExecutor( port::StatusOr<StreamExecutor*> InterpreterPlatform::GetExecutor(
const StreamExecutorConfig& config) { const StreamExecutorConfig& config) {
return executor_cache_.GetOrCreate( return executor_cache_.GetOrCreate(
config, [&]() { return GetUncachedExecutor(config); }); config, [&]() { return GetUncachedExecutor(config); });
} }
port::StatusOr<std::unique_ptr<StreamExecutor>> port::StatusOr<std::unique_ptr<StreamExecutor>>
ExecutorPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { InterpreterPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
auto executor = port::MakeUnique<StreamExecutor>( auto executor = port::MakeUnique<StreamExecutor>(
this, port::MakeUnique<ExecutorExecutor>(config.plugin_config)); this, port::MakeUnique<InterpreterExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options); auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) { if (!init_status.ok()) {
return port::Status{ return port::Status{
@ -84,27 +86,30 @@ ExecutorPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
return std::move(executor); return std::move(executor);
} }
void ExecutorPlatform::RegisterTraceListener( void InterpreterPlatform::RegisterTraceListener(
std::unique_ptr<TraceListener> listener) { std::unique_ptr<TraceListener> listener) {
LOG(FATAL) << "not yet implemented: register executor trace listener"; LOG(FATAL) << "not yet implemented: register executor trace listener";
} }
void ExecutorPlatform::UnregisterTraceListener(TraceListener* listener) { void InterpreterPlatform::UnregisterTraceListener(TraceListener* listener) {
LOG(FATAL) << "not yet implemented: unregister executor trace listener"; LOG(FATAL) << "not yet implemented: unregister executor trace listener";
} }
static void InitializeExecutorPlatform() { static void InitializeInterpreterPlatform() {
std::unique_ptr<se::Platform> platform(new sep::ExecutorPlatform); std::unique_ptr<se::Platform> platform(new sep::InterpreterPlatform);
SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform))); SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform)));
} }
} // namespace executorplugin } // namespace interpreter
} // namespace gputools } // namespace gputools
} // namespace perftools } // namespace perftools
REGISTER_MODULE_INITIALIZER(executor_platform, sep::InitializeExecutorPlatform()); REGISTER_MODULE_INITIALIZER(interpreter_platform,
sep::InitializeInterpreterPlatform());
DECLARE_MODULE_INITIALIZER(multi_platform_manager); DECLARE_MODULE_INITIALIZER(multi_platform_manager);
// Note that module initialization sequencing is not supported in the // Note that module initialization sequencing is not supported in the
// open-source project, so this will be a no-op there. // open-source project, so this will be a no-op there.
REGISTER_MODULE_INITIALIZER_SEQUENCE(executor_platform, multi_platform_manager); REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform,
multi_platform_manager);
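Editor's note: the module-initializer sequencing comment above means the platform registers itself at load time; in the open-source build the sequencing macro is a no-op, so the target containing this file must simply be linked in. Below is a hedged sketch of the consumer side once InitializeInterpreterPlatform() has run; the symbol names come from this change and error handling is simplified.

// Sketch only: after registration, the platform is discoverable through
// MultiPlatformManager either by name or by the id defined in
// platform_id.cc below.
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"

namespace se = ::perftools::gputools;
namespace sep = ::perftools::gputools::interpreter;

se::Platform *GetInterpreterPlatformOrDie() {
  // Lookup by id; PlatformWithName("Interpreter") would be equivalent here.
  return se::MultiPlatformManager::PlatformWithId(sep::kInterpreterPlatformId)
      .ValueOrDie();
}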

View File

@ -12,38 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_H_
#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_H_
#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/plugin.h"
#include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/trace_listener.h" #include "tensorflow/stream_executor/trace_listener.h"
namespace perftools { namespace perftools {
namespace gputools { namespace gputools {
namespace executorplugin { namespace interpreter {
class ExecutorPlatform : public Platform { class InterpreterPlatform : public Platform {
public: public:
ExecutorPlatform(); InterpreterPlatform();
~ExecutorPlatform() override; ~InterpreterPlatform() override;
Platform::Id id() const override; Platform::Id id() const override;
// Device count is less clear-cut for CPUs than accelerators. This call
// currently returns the number of thread units in the host, as reported by
// base::NumCPUs().
int VisibleDeviceCount() const override; int VisibleDeviceCount() const override;
const string& Name() const override; const string& Name() const override;
@ -70,11 +60,11 @@ class ExecutorPlatform : public Platform {
// Cache of created StreamExecutors. // Cache of created StreamExecutors.
ExecutorCache executor_cache_; ExecutorCache executor_cache_;
SE_DISALLOW_COPY_AND_ASSIGN(ExecutorPlatform); SE_DISALLOW_COPY_AND_ASSIGN(InterpreterPlatform);
}; };
} // namespace executorplugin } // namespace interpreter
} // namespace gputools } // namespace gputools
} // namespace perftools } // namespace perftools
#endif // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_ #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_H_

View File

@ -0,0 +1,25 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
namespace perftools {
namespace gputools {
namespace interpreter {
PLATFORM_DEFINE_ID(kInterpreterPlatformId);
} // namespace interpreter
} // namespace gputools
} // namespace perftools

View File

@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_ID_H_
#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_ID_H_
#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform.h"
namespace perftools { namespace perftools {
namespace gputools { namespace gputools {
namespace executorplugin { namespace interpreter {
extern const Platform::Id kExecutorPlatformId; extern const Platform::Id kInterpreterPlatformId;
} // namespace executorplugin } // namespace interpreter
} // namespace gputools } // namespace gputools
} // namespace perftools } // namespace perftools
#endif // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_ #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_ID_H_

View File

@ -0,0 +1,44 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/interpreter_transfer_manager.h"
#include <memory>
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
namespace sei = ::perftools::gputools::interpreter;
namespace xla {
InterpreterTransferManager::InterpreterTransferManager()
: GenericTransferManager(sei::kInterpreterPlatformId) {}
} // namespace xla
static std::unique_ptr<xla::TransferManager>
CreateInterpreterTransferManager() {
return xla::MakeUnique<xla::InterpreterTransferManager>();
}
static bool InitModule() {
xla::TransferManager::RegisterTransferManager(
sei::kInterpreterPlatformId, &CreateInterpreterTransferManager);
return true;
}
static bool module_initialized = InitModule();
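Editor's note: RegisterTransferManager() above installs a factory keyed by the interpreter platform id; XLA later resolves it per platform. A hedged sketch of that consumer side follows (it assumes the platform registration from the earlier files in this change, with error handling simplified).

// Sketch only: resolve the transfer manager that the registration above
// installed for the interpreter platform.
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"

namespace se = ::perftools::gputools;

xla::TransferManager *GetInterpreterTransferManagerOrDie() {
  se::Platform *platform =
      se::MultiPlatformManager::PlatformWithName("Interpreter").ValueOrDie();
  return xla::TransferManager::GetForPlatform(platform).ValueOrDie();
}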

View File

@ -0,0 +1,36 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
// An implementation of the XLA GenericTransferManager for the interpreter backend.
class InterpreterTransferManager : public GenericTransferManager {
public:
InterpreterTransferManager();
~InterpreterTransferManager() override = default;
private:
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterTransferManager);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_

View File

@ -107,15 +107,10 @@ cc_binary(
) )
cc_binary( cc_binary(
name = "replay_computation_hlo_evaluator", name = "replay_computation_interpreter",
deps = [ deps = [
":replay_computation_library", ":replay_computation_library",
"//tensorflow/compiler/plugin/executor:plugin_lib", "//tensorflow/compiler/xla/service:interpreter_plugin",
# TODO: This dependency is a workaround for linking error with clang.
# Without it, linker complains about missing symbols from
# 'xla_device_launch_op'. This dependency should be propagated from
# plugin_lib instead, but no targets other than this break without it.
"//tensorflow/compiler/jit",
], ],
) )