Merge pull request #5267 from benoitsteiner/master

Initial support for OpenCL

This commit is contained in commit f179e0a16e.

configure (vendored, 104 changes)
@@ -116,6 +116,17 @@ GEN_GIT_SOURCE=tensorflow/tools/git/gen_git_source.py
chmod a+x ${GEN_GIT_SOURCE}
"${PYTHON_BIN_PATH}" ${GEN_GIT_SOURCE} --configure "${SOURCE_BASE_DIR}"

## Set up SYCL-related environment settings
while [ "$TF_NEED_OPENCL" == "" ]; do
  read -p "Do you wish to build TensorFlow with OpenCL support? [y/N] " INPUT
  case $INPUT in
    [Yy]* ) echo "OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=1;;
    [Nn]* ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
    "" ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
    * ) echo "Invalid selection: " $INPUT;;
  esac
done

## Set up Cuda-related environment settings

while [ "$TF_NEED_CUDA" == "" ]; do
@@ -129,12 +140,14 @@ while [ "$TF_NEED_CUDA" == "" ]; do
done

export TF_NEED_CUDA
-if [ "$TF_NEED_CUDA" == "0" ]; then
+export TF_NEED_SYCL
+if [[ "$TF_NEED_CUDA" == "0" ]] && [[ "$TF_NEED_OPENCL" == "0" ]]; then
  echo "Configuration finished"
  bazel_clean_and_fetch
  exit
fi

if [ "$TF_NEED_CUDA" == "1" ]; then
# Set up which gcc nvcc should use as the host compiler
while true; do
  fromuser=""
@@ -336,6 +349,95 @@ EOF
TF_CUDA_COMPUTE_CAPABILITIES=""
done

# end of if "$TF_NEED_CUDA" == "1"
fi

# OpenCL configuration

if [ "$TF_NEED_OPENCL" == "1" ]; then

# Determine which C++ compiler should be used as the host compiler
while true; do
  fromuser=""
  if [ -z "$HOST_CXX_COMPILER" ]; then
    default_cxx_host_compiler=$(which g++ || true)
    read -p "Please specify which C++ compiler should be used as the host C++ compiler. [Default is $default_cxx_host_compiler]: " HOST_CXX_COMPILER
    fromuser="1"
    if [ -z "$HOST_CXX_COMPILER" ]; then
      HOST_CXX_COMPILER=$default_cxx_host_compiler
    fi
  fi
  if [ -e "$HOST_CXX_COMPILER" ]; then
    export HOST_CXX_COMPILER
    break
  fi
  echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2
  if [ -z "$fromuser" ]; then
    exit 1
  fi
  HOST_CXX_COMPILER=""
  # Retry
done

# Determine which C compiler should be used as the host compiler
while true; do
  fromuser=""
  if [ -z "$HOST_C_COMPILER" ]; then
    default_c_host_compiler=$(which gcc || true)
    read -p "Please specify which C compiler should be used as the host C compiler. [Default is $default_c_host_compiler]: " HOST_C_COMPILER
    fromuser="1"
    if [ -z "$HOST_C_COMPILER" ]; then
      HOST_C_COMPILER=$default_c_host_compiler
    fi
  fi
  if [ -e "$HOST_C_COMPILER" ]; then
    export HOST_C_COMPILER
    break
  fi
  echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2
  if [ -z "$fromuser" ]; then
    exit 1
  fi
  HOST_C_COMPILER=""
  # Retry
done

while true; do
  # Configure the OPENCL version to use.
  TF_OPENCL_VERSION="1.2"

  # Point to ComputeCpp root
  if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
    default_computecpp_toolkit_path=/usr/local/computecpp
    read -p "Please specify the location where ComputeCpp $TF_OPENCL_VERSION is installed. Refer to README.md for more details. [Default is $default_computecpp_toolkit_path]: " COMPUTECPP_TOOLKIT_PATH
    fromuser="1"
    if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
      COMPUTECPP_TOOLKIT_PATH=$default_computecpp_toolkit_path
    fi
  fi

  if [ "$OSNAME" == "Linux" ]; then
    SYCL_RT_LIB_PATH="lib/libComputeCpp.so"
  fi

  if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then
    export COMPUTECPP_TOOLKIT_PATH
    break
  fi
  echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found"

  if [ -z "$fromuser" ]; then
    exit 1
  fi
  # Retry
  TF_OPENCL_VERSION=""
  COMPUTECPP_TOOLKIT_PATH=""
done

export TF_NEED_OPENCL
# end of if "$TF_NEED_OPENCL" == "1"
fi

bazel_clean_and_fetch

echo "Configuration finished"
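Because each prompt above only fires while its variable is empty, the whole OpenCL section can be driven non-interactively by exporting `TF_NEED_OPENCL=1`, `HOST_CXX_COMPILER`, `HOST_C_COMPILER`, and `COMPUTECPP_TOOLKIT_PATH` before running `./configure`.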
@@ -510,6 +510,7 @@ cc_library(
    deps = [
        ":core_cpu",
        ":gpu_runtime",
        ":sycl_runtime",
    ],
)
@@ -1387,6 +1388,33 @@ tf_cuda_library(
    alwayslink = 1,
)

cc_library(
    name = "sycl_runtime",
    srcs = if_not_windows([
        "common_runtime/sycl/sycl_device.cc",
        "common_runtime/sycl/sycl_device_context.cc",
        "common_runtime/sycl/sycl_device_factory.cc",
    ]),
    hdrs = if_not_windows([
        "common_runtime/sycl/sycl_device.h",
        "common_runtime/sycl/sycl_device_context.h",
    ]),
    copts = tf_copts(),
    linkstatic = 1,
    deps = [
        ":core_cpu",
        ":core_cpu_internal",
        ":framework",
        ":framework_internal",
        ":lib",
        ":lib_internal",
        ":protos_all_cc",
        "//third_party/eigen3",
        "@local_config_sycl//sycl:sycl",
    ],
    alwayslink = 1,
)

# -----------------------------------------------------------------------------
# Tests
@@ -68,12 +68,17 @@ TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
      (std::vector<DeviceType>{DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
      types());

  AddDevice("SYCL", "/job:a/replica:0/task:0/device:sycl:0");
  EXPECT_EQ(
      (std::vector<DeviceType>{DeviceType(DEVICE_SYCL), DeviceType(DEVICE_GPU),
                               DeviceType(DEVICE_CPU)}),
      types());

  AddDevice("T1", "/job:a/replica:0/task:0/device:T1:0");
  AddDevice("T1", "/job:a/replica:0/task:0/device:T1:1");
  AddDevice("T2", "/job:a/replica:0/task:0/device:T2:0");
  EXPECT_EQ(
-     (std::vector<DeviceType>{DeviceType("T1"), DeviceType("T2"),
-                              DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
+     (std::vector<DeviceType>{DeviceType(DEVICE_SYCL), DeviceType("T1"),
+                              DeviceType("T2"), DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
      types());
}
@@ -818,6 +818,8 @@ class BlockingOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc("");

REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_SYCL), BlockingOp);

static void TestSessionInterOpThreadsImpl(bool use_function_lib) {
  FunctionDefLibrary library_graph_def;
  if (use_function_lib) {
tensorflow/core/common_runtime/sycl/sycl_device.cc (new file, 88 lines)

@@ -0,0 +1,88 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if TENSORFLOW_USE_SYCL

#include "tensorflow/core/common_runtime/sycl/sycl_device.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/platform/tracing.h"

namespace tensorflow {

cl::sycl::gpu_selector s;
cl::sycl::queue q(s);

SYCLDevice::SYCLDevice(const SessionOptions& options, const string& name,
                       Bytes memory_limit, const DeviceLocality& locality,
                       const string& physical_device_desc, Allocator* allocator)
    : LocalDevice(options,
                  Device::BuildDeviceAttributes(name, DEVICE_SYCL, memory_limit,
                                                locality, physical_device_desc),
                  allocator),
      allocator_(allocator),
      device_context_(new SYCLDeviceContext()),
      device_(q) {
  set_eigen_sycl_device(&device_);
}

SYCLDevice::~SYCLDevice() { device_context_->Unref(); }

void SYCLDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  assert(context);
  if (port::Tracing::IsActive()) {
    // TODO(pbar) We really need a useful identifier of the graph node.
    const uint64 id = Hash64(op_kernel->name());
    port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
                                         id);
  }
  op_kernel->Compute(context);
}

Allocator* SYCLDevice::GetAllocator(AllocatorAttributes attr) {
  return allocator_;
}

Status SYCLDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                       const AllocatorAttributes alloc_attrs,
                                       Tensor* tensor) {
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   ProtoDebugString(tensor_proto));
  }
  *tensor = std::move(parsed);
  return Status::OK();
}

Status SYCLDevice::FillContextMap(const Graph* graph,
                                  DeviceContextMap* device_context_map) {
  // Fill in the context map. It is OK for this map to contain
  // duplicate DeviceContexts so long as we increment the refcount.
  device_context_map->resize(graph->num_node_ids());
  for (Node* n : graph->nodes()) {
    device_context_->Ref();
    (*device_context_map)[n->id()] = device_context_;
  }

  return Status::OK();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_SYCL
tensorflow/core/common_runtime/sycl/sycl_device.h (new file, 62 lines)

@@ -0,0 +1,62 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if !TENSORFLOW_USE_SYCL
#error This file must only be included when building TensorFlow with SYCL support
#endif

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_

#define EIGEN_USE_SYCL

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

class SYCLDevice : public LocalDevice {
 public:
  SYCLDevice(const SessionOptions& options, const string& name,
             Bytes memory_limit, const DeviceLocality& locality,
             const string& physical_device_desc, Allocator* allocator);
  ~SYCLDevice() override;

  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  Allocator* GetAllocator(AllocatorAttributes attr) override;
  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override;

  Status FillContextMap(const Graph* graph,
                        DeviceContextMap* device_context_map) override;

  Status Sync() override { return Status::OK(); }
  static string GetShortDeviceDescription(/*int device_id,
                                            const DeviceDescription& desc*/) {
    return strings::StrCat("device: 0, name SYCL, pci bus id: 0");
  }

 private:
  Allocator* allocator_;  // Not owned
  SYCLDeviceContext* device_context_;
  Eigen::SyclDevice device_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
tensorflow/core/common_runtime/sycl/sycl_device_context.cc (new file, 46 lines)

@@ -0,0 +1,46 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
#include "tensorflow/core/common_runtime/dma_helper.h"

namespace tensorflow {

void SYCLDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                              Device* device,
                                              Tensor* device_tensor,
                                              StatusCallback done) const {
  const int64 total_bytes = cpu_tensor->TotalBytes();
  if (total_bytes > 0) {
    const void* src_ptr = DMAHelper::base(cpu_tensor);
    void* dst_ptr = DMAHelper::base(device_tensor);
    ::memcpy(dst_ptr, src_ptr, total_bytes);
  }
  done(Status::OK());
}

void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                              StringPiece edge_name,
                                              Device* device, Tensor* cpu_tensor,
                                              StatusCallback done) {
  const int64 total_bytes = device_tensor->TotalBytes();
  if (total_bytes > 0) {
    const void* src_ptr = DMAHelper::base(device_tensor);
    void* dst_ptr = DMAHelper::base(cpu_tensor);
    ::memcpy(dst_ptr, src_ptr, total_bytes);
  }
  done(Status::OK());
}

}  // namespace tensorflow
tensorflow/core/common_runtime/sycl/sycl_device_context.h (new file, 42 lines)

@@ -0,0 +1,42 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_
#define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"

namespace tensorflow {

class SYCLDeviceContext : public DeviceContext {
 public:
  SYCLDeviceContext() {}

  ~SYCLDeviceContext() override {}

  void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                             Tensor* device_tensor,
                             StatusCallback done) const override;

  void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name,
                             Device* device, Tensor* cpu_tensor,
                             StatusCallback done) override;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_
tensorflow/core/common_runtime/sycl/sycl_device_factory.cc (new file, 44 lines)

@@ -0,0 +1,44 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if TENSORFLOW_USE_SYCL

#include "tensorflow/core/common_runtime/sycl/sycl_device.h"

namespace tensorflow {

class SYCLDeviceFactory : public DeviceFactory {
 public:
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<Device*>* devices) override {
    int n = 1;
    auto iter = options.config.device_count().find("SYCL");
    if (iter != options.config.device_count().end()) {
      n = iter->second;
    }
    for (int i = 0; i < n; i++) {
      string name = strings::StrCat(name_prefix, "/device:SYCL:", i);
      devices->push_back(new SYCLDevice(
          options, name, Bytes(256 << 20), DeviceLocality(),
          SYCLDevice::GetShortDeviceDescription(), cpu_allocator()));
    }
    return Status::OK();
  }
};

REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory);
}

#endif  // TENSORFLOW_USE_SYCL
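The factory registered above honors the `device_count` map in the session configuration and defaults to one SYCL device. A minimal sketch of client code that requests two SYCL devices (hypothetical usage written against the factory's lookup, not code from this PR):

```cpp
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

// Sketch: ask the runtime for two SYCL devices. This mirrors the
// options.config.device_count().find("SYCL") lookup performed by
// SYCLDeviceFactory::CreateDevices above.
tensorflow::Session* NewSyclSession() {
  tensorflow::SessionOptions options;
  (*options.config.mutable_device_count())["SYCL"] = 2;  // default is 1
  tensorflow::Session* session = nullptr;
  TF_CHECK_OK(tensorflow::NewSession(options, &session));
  return session;
}
```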
@@ -30,6 +30,9 @@ limitations under the License.

namespace Eigen {
struct ThreadPoolDevice;
#ifdef TENSORFLOW_USE_SYCL
struct SyclDevice;
#endif
}  // end namespace Eigen

namespace perftools {

@@ -145,6 +148,10 @@ class DeviceBase {
    eigen_cpu_device_ = d;
  }

#ifdef TENSORFLOW_USE_SYCL
  void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; }
#endif

  // Return the Allocator implementation to use based on the allocator
  // attributes requested. See allocator.h for more details.
  virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {

@@ -167,6 +174,13 @@ class DeviceBase {
    return eigen_cpu_device_;
  }

#ifdef TENSORFLOW_USE_SYCL
  const Eigen::SyclDevice* eigen_sycl_device() const {
    CHECK(eigen_sycl_device_ != nullptr);
    return eigen_sycl_device_;
  }
#endif

  // Caller owns the return value. The OpKernelContext calls this even
  // for devices that do not implement an eigen_gpu_device. Overridden
  // by GPU devices to return a derived type.

@@ -203,6 +217,9 @@ class DeviceBase {
  CpuWorkerThreads* cpu_worker_threads_ = nullptr;
  GpuDeviceInfo* gpu_device_info_ = nullptr;
  Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr;
#ifdef TENSORFLOW_USE_SYCL
  Eigen::SyclDevice* eigen_sycl_device_ = nullptr;
#endif
};

}  // namespace tensorflow
@@ -949,6 +949,13 @@ const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
  return eigen_gpu_device();
}

#ifdef TENSORFLOW_USE_SYCL
template <>
const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
  return eigen_sycl_device();
}
#endif

void OpKernelConstruction::CtxFailure(Status s) {
  VLOG(1) << s;
  SetStatus(s);
@@ -53,6 +53,7 @@ limitations under the License.

namespace Eigen {
struct ThreadPoolDevice;
struct GpuDevice;
struct SyclDevice;
}  // end namespace Eigen

namespace tensorflow {

@@ -891,6 +892,11 @@ class OpKernelContext {
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
#ifdef TENSORFLOW_USE_SYCL
  const Eigen::SyclDevice& eigen_sycl_device() const {
    return *device()->eigen_sycl_device();
  }
#endif
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;
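With the specialization in op_kernel.cc and the accessor above, device-generic kernel code can fetch the SYCL device through the same `eigen_device<>()` entry point it already uses for CPU and GPU. A minimal sketch (hypothetical functor, not part of this PR):

```cpp
// Sketch: a device-parameterized functor body. Device may now be
// Eigen::SyclDevice in addition to Eigen::ThreadPoolDevice (CPU) and
// Eigen::GpuDevice, assuming TENSORFLOW_USE_SYCL is defined.
template <typename Device, typename T>
struct AddOneFunctor {
  void operator()(OpKernelContext* ctx, typename TTypes<T>::ConstFlat in,
                  typename TTypes<T>::Flat out) {
    const Device& d = ctx->eigen_device<Device>();
    out.device(d) = in + in.constant(T(1));  // runs on CPU, GPU, or SYCL
  }
};
```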
@@ -37,6 +37,7 @@ std::ostream& operator<<(std::ostream& os, const DeviceType& d) {

const char* const DEVICE_CPU = "CPU";
const char* const DEVICE_GPU = "GPU";
const char* const DEVICE_SYCL = "SYCL";

string DataTypeString(DataType dtype) {
  if (IsRefType(dtype)) {
@@ -68,8 +68,9 @@ class DeviceType {

std::ostream& operator<<(std::ostream& os, const DeviceType& d);

// Convenient constants that can be passed to a DeviceType constructor
-extern const char* const DEVICE_CPU;  // "CPU"
-extern const char* const DEVICE_GPU;  // "GPU"
+extern const char* const DEVICE_CPU;   // "CPU"
+extern const char* const DEVICE_GPU;   // "GPU"
+extern const char* const DEVICE_SYCL;  // "SYCL"

typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice;
@@ -25,6 +25,7 @@ namespace {

TEST(TypesTest, DeviceTypeName) {
  EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU)));
  EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU)));
  EXPECT_EQ("SYCL", DeviceTypeString(DeviceType(DEVICE_SYCL)));
}

TEST(TypesTest, kDataTypeRefOffset) {
@@ -51,6 +51,17 @@ ConstantOp::~ConstantOp() {}

REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp);

#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(TYPE)        \
  REGISTER_KERNEL_BUILDER(                \
      Name("Const")                       \
          .Device(DEVICE_SYCL)            \
          .TypeConstraint<TYPE>("dtype"), \
      ConstantOp);
TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#endif

#if GOOGLE_CUDA
#define REGISTER_KERNEL(D, TYPE) \
  REGISTER_KERNEL_BUILDER(       \
@@ -18,6 +18,14 @@ limitations under the License.

namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Round", functor::round, Eigen::half, float, double,
          int32, int64);

#ifdef TENSORFLOW_USE_SYCL
REGISTER(UnaryOp, SYCL, "Round", functor::round, float);
namespace functor {
DEFINE_UNARY1(round, float);
}  // namespace functor
#endif

#if GOOGLE_CUDA
REGISTER5(UnaryOp, GPU, "Round", functor::round, Eigen::half, float, double,
          int32, int64);
@@ -20,6 +20,10 @@ limitations under the License.

#define EIGEN_USE_THREADS

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
#endif

#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"

@@ -33,6 +37,9 @@ namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif

class BinaryOpShared : public OpKernel {
 public:

@@ -96,45 +103,45 @@ class BinaryOp : public BinaryOpShared {
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
-           eigen_device, out_flat, in0.flat<Tin>(), in1.scalar<Tin>(),
-           error_ptr);
+           eigen_device, out_flat, in0.template flat<Tin>(),
+           in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
-           eigen_device, out_flat, in0.scalar<Tin>(), in1.flat<Tin>(),
-           error_ptr);
+           eigen_device, out_flat, in0.template scalar<Tin>(),
+           in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
-           eigen_device, out_flat, in0.flat<Tin>(), in1.flat<Tin>(),
-           error_ptr);
+           eigen_device, out_flat, in0.template flat<Tin>(),
+           in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
-         in0.shaped<Tin, 2>(bcast->x_reshape()),
+         in0.template shaped<Tin, 2>(bcast->x_reshape()),
          BCast::ToIndexArray<2>(bcast->x_bcast()),
-         in1.shaped<Tin, 2>(bcast->y_reshape()),
+         in1.template shaped<Tin, 2>(bcast->y_reshape()),
          BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
-         in0.shaped<Tin, 3>(bcast->x_reshape()),
+         in0.template shaped<Tin, 3>(bcast->x_reshape()),
          BCast::ToIndexArray<3>(bcast->x_bcast()),
-         in1.shaped<Tin, 3>(bcast->y_reshape()),
+         in1.template shaped<Tin, 3>(bcast->y_reshape()),
          BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
-         in0.shaped<Tin, 4>(bcast->x_reshape()),
+         in0.template shaped<Tin, 4>(bcast->x_reshape()),
          BCast::ToIndexArray<4>(bcast->x_bcast()),
-         in1.shaped<Tin, 4>(bcast->y_reshape()),
+         in1.template shaped<Tin, 4>(bcast->y_reshape()),
          BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
-         in0.shaped<Tin, 5>(bcast->x_reshape()),
+         in0.template shaped<Tin, 5>(bcast->x_reshape()),
          BCast::ToIndexArray<5>(bcast->x_bcast()),
-         in1.shaped<Tin, 5>(bcast->y_reshape()),
+         in1.template shaped<Tin, 5>(bcast->y_reshape()),
          BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
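The only change in the dispatch logic above is the added `template` disambiguator on the `flat`/`scalar`/`shaped` calls, presumably required by the stricter parsing of the SYCL device compiler. In standard C++ the disambiguator is mandatory whenever a member template is invoked on an object whose type depends on a template parameter; a self-contained illustration:

```cpp
template <typename T>
struct Holder {
  template <int N>
  int get() const { return N; }
};

template <typename T>
int Use(const Holder<T>& h) {
  // return h.get<2>();      // ill-formed: parsed as (h.get < 2) > ...
  return h.template get<2>();  // OK: "template" marks get as a member template
}
```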
tensorflow/core/kernels/cwise_ops_sycl_common.h (new file, 138 lines)

@@ -0,0 +1,138 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if !TENSORFLOW_USE_SYCL
#error This file must only be included when building TensorFlow with SYCL support
#endif

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_

#define EIGEN_USE_SYCL

#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace functor {

typedef Eigen::SyclDevice SYCLDevice;

template <typename OUT, typename RHS>
void Assign(const SYCLDevice& d, OUT out, RHS rhs) {
  out.device(d) = rhs;
}

// Partial specialization of UnaryFunctor<Device=SYCLDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<SYCLDevice, Functor> {
  void operator()(const SYCLDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    To32Bit(out).device(d) = To32Bit(in).unaryExpr(typename Functor::func());
  }
};

// Partial specialization of BinaryFunctor<Device=SYCLDevice, Functor>.
template <typename Functor, int NDIMS, bool has_errors>
struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> {
  void operator()(const SYCLDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const SYCLDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    LOG(FATAL) << "BinaryFunctor::Left NOT IMPLEMENTED ! ";
  }

  void Right(const SYCLDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const SYCLDevice& d,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    LOG(FATAL) << "BinaryFunctor::BCast NOT IMPLEMENTED ";
  }
};

// Macros to explicitly instantiate kernels on SYCL for multiple types
// (T0, T1, etc.) for UnaryFunctor (e.g., functor::sqrt).
#define DEFINE_UNARY1(F, T) template struct UnaryFunctor<SYCLDevice, F<T> >
#define DEFINE_UNARY2(F, T0, T1) \
  DEFINE_UNARY1(F, T0);          \
  DEFINE_UNARY1(F, T1)
#define DEFINE_UNARY3(F, T0, T1, T2) \
  DEFINE_UNARY2(F, T0, T1);          \
  DEFINE_UNARY1(F, T2)
#define DEFINE_UNARY4(F, T0, T1, T2, T3) \
  DEFINE_UNARY2(F, T0, T1);              \
  DEFINE_UNARY2(F, T2, T3)
#define DEFINE_UNARY5(F, T0, T1, T2, T3, T4) \
  DEFINE_UNARY2(F, T0, T1);                  \
  DEFINE_UNARY3(F, T2, T3, T4)

// Macros to explicitly instantiate kernels on SYCL for multiple types
// (T0, T1, etc.) for BinaryFunctor.
#define DEFINE_BINARY1(F, T)                           \
  template struct BinaryFunctor<SYCLDevice, F<T>, 1>;  \
  template struct BinaryFunctor<SYCLDevice, F<T>, 2>;  \
  template struct BinaryFunctor<SYCLDevice, F<T>, 3>
#define DEFINE_BINARY2(F, T0, T1) \
  DEFINE_BINARY1(F, T0);          \
  DEFINE_BINARY1(F, T1)
#define DEFINE_BINARY3(F, T0, T1, T2) \
  DEFINE_BINARY2(F, T0, T1);          \
  DEFINE_BINARY1(F, T2)
#define DEFINE_BINARY4(F, T0, T1, T2, T3) \
  DEFINE_BINARY2(F, T0, T1);              \
  DEFINE_BINARY2(F, T2, T3)
#define DEFINE_BINARY5(F, T0, T1, T2, T3, T4) \
  DEFINE_BINARY2(F, T0, T1);                  \
  DEFINE_BINARY3(F, T2, T3, T4)
#define DEFINE_BINARY6(F, T0, T1, T2, T3, T4, T5) \
  DEFINE_BINARY3(F, T0, T1, T2);                  \
  DEFINE_BINARY3(F, T3, T4, T5)
#define DEFINE_BINARY7(F, T0, T1, T2, T3, T4, T5, T6) \
  DEFINE_BINARY3(F, T0, T1, T2);                      \
  DEFINE_BINARY4(F, T3, T4, T5, T6)
#define DEFINE_BINARY8(F, T0, T1, T2, T3, T4, T5, T6, T7) \
  DEFINE_BINARY4(F, T0, T1, T2, T3);                      \
  DEFINE_BINARY4(F, T4, T5, T6, T7)
#define DEFINE_BINARY9(F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  DEFINE_BINARY4(F, T0, T1, T2, T3);                          \
  DEFINE_BINARY5(F, T4, T5, T6, T7, T8)
#define DEFINE_BINARY10(F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) \
  DEFINE_BINARY5(F, T0, T1, T2, T3, T4);                           \
  DEFINE_BINARY5(F, T5, T6, T7, T8, T9)

}  // end namespace functor
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_
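For reference, the instantiation `DEFINE_UNARY1(round, float)` used in cwise_op_round.cc above expands (inside `namespace functor`) to an explicit instantiation of the SYCL unary functor:

```cpp
template struct UnaryFunctor<SYCLDevice, round<float> >;
```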
@@ -87,6 +87,29 @@ class RetvalOp : public OpKernel {

REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);

#if TENSORFLOW_USE_SYCL
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
#undef REGISTER
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
#undef REGISTER
#endif

#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name("_Arg").Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
@@ -34,6 +34,24 @@ REGISTER_KERNEL_BUILDER(Name("PlaceholderWithDefault").Device(DEVICE_CPU),

REGISTER_KERNEL_BUILDER(Name("RefIdentity").Device(DEVICE_CPU), IdentityOp);

#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Identity").Device(DEVICE_SYCL).TypeConstraint<type>("T"),     \
      IdentityOp);                                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("RefIdentity").Device(DEVICE_SYCL).TypeConstraint<type>("T"),  \
      IdentityOp);                                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("StopGradient").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      IdentityOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
REGISTER_SYCL_KERNEL(bfloat16);

#undef REGISTER_SYCL_KERNEL
#endif

#define REGISTER_GPU_KERNEL(type)                                    \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Identity").Device(DEVICE_GPU).TypeConstraint<type>("T"), \

@@ -50,6 +68,7 @@ REGISTER_GPU_KERNEL(bfloat16);

#undef REGISTER_GPU_KERNEL


#if GOOGLE_CUDA
// A special GPU kernel for int32 and bool.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -20,4 +20,8 @@ namespace tensorflow {

REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_CPU), NoOp);
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_GPU), NoOp);

#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_SYCL), NoOp);
#endif

}  // namespace tensorflow
@@ -78,6 +78,10 @@ void SendOp::Compute(OpKernelContext* ctx) {

REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp);

#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_SYCL), SendOp);
#endif

REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp);

@@ -136,6 +140,10 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {

REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp);

#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_SYCL), RecvOp);
#endif

REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp);
@@ -10,6 +10,7 @@ exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
load("@local_config_cuda//cuda:platform.bzl", "cuda_library_path")
load("@local_config_sycl//sycl:platform.bzl", "sycl_library_path")

cc_library(
    name = "gtest",

@@ -143,6 +144,21 @@ cc_library(
    ],
)

cc_library(
    name = "sycl",
    data = [
        "@local_config_sycl//sycl:{}".format(sycl_library_path("ComputeCpp")),
    ],
    linkopts = select({
        "//conditions:default": [
            "-Wl,-rpath,../local_config_sycl/sycl/lib",
        ],
    }),
    deps = [
        "@local_config_sycl//sycl:syclrt",
    ],
)

filegroup(
    name = "android_srcs",
    srcs = glob(["*.h"]),
|
||||
progress = true;
|
||||
}
|
||||
|
||||
// Handle legacy naming convention for cpu and gpu.
|
||||
if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/CPU:")) {
|
||||
p->has_type = true;
|
||||
|
@@ -1,11 +1,14 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.

load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")


# If TensorFlow is linked as a submodule.
# path_prefix and tf_repo_name are no longer used.
def tf_workspace(path_prefix = "", tf_repo_name = ""):
  cuda_configure(name = "local_config_cuda")
  sycl_configure(name = "local_config_sycl")
  if path_prefix:
    print("path_prefix was specified to tf_workspace but is no longer used and will be removed in the future.")
  if tf_repo_name:

@@ -14,8 +17,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
  # These lines need to be changed when updating Eigen. They are parsed from
  # this file by the cmake and make builds to determine the eigen version and
  # hash.
- eigen_version = "22b492048b2f"
- eigen_sha256 = "8b9bd14a037c1a3fe37dc5e4a71504ebe48148cf2498fd8eb6848165a7a0538f"
+ eigen_version = "3f0fb403ec4c"
+ eigen_sha256 = "9ff8301c6af2640932c5ded77ecccee5786cec8c31315311220618b312e0472b"

  native.new_http_archive(
      name = "eigen_archive",
third_party/eigen3/BUILD (vendored, 5 changes)

@@ -23,5 +23,8 @@ cc_library(
"unsupported/Eigen/CXX11/FixedPoint",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["@eigen_archive//:eigen"],
|
||||
deps = [
|
||||
"@eigen_archive//:eigen",
|
||||
"@local_config_sycl//sycl:sycl",
|
||||
],
|
||||
)
|
||||
|
third_party/sycl/BUILD (vendored, new empty file)

third_party/sycl/crosstool/BUILD (vendored, new empty file)

third_party/sycl/crosstool/BUILD.tpl (vendored executable, new file, 29 lines)

@@ -0,0 +1,29 @@
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:public"])

cc_toolchain_suite(
    name = "toolchain",
    toolchains = {
        "local|compiler": ":cc-compiler-local",
    },
)

cc_toolchain(
    name = "cc-compiler-local",
    all_files = ":empty",
    compiler_files = ":empty",
    cpu = "local",
    dwp_files = ":empty",
    dynamic_runtime_libs = [":empty"],
    linker_files = ":empty",
    objcopy_files = ":empty",
    static_runtime_libs = [":empty"],
    strip_files = ":empty",
    supports_param_files = 0,
)

filegroup(
    name = "empty",
    srcs = [],
)
third_party/sycl/crosstool/CROSSTOOL.tpl (vendored executable, new file, 82 lines)

@@ -0,0 +1,82 @@
major_version: "local"
minor_version: ""
default_target_cpu: "same_as_host"

default_toolchain {
  cpu: "k8"
  toolchain_identifier: "local_linux"
}

toolchain {
  abi_version: "local"
  abi_libc_version: "local"
  builtin_sysroot: ""
  compiler: "compiler"
  host_system_name: "local"
  needsPic: true
  supports_gold_linker: false
  supports_incremental_linker: false
  supports_fission: false
  supports_interface_shared_objects: false
  supports_normalizing_ar: false
  supports_start_end_lib: false
  supports_thin_archives: false
  target_libc: "local"
  target_cpu: "local"
  target_system_name: "local"
  toolchain_identifier: "local_linux"

  tool_path { name: "ar" path: "/usr/bin/ar" }
  tool_path { name: "compat-ld" path: "/usr/bin/ld" }
  tool_path { name: "cpp" path: "/usr/bin/cpp" }
  tool_path { name: "dwp" path: "/usr/bin/dwp" }
  tool_path { name: "gcc" path: "computecpp" }
  # Use "-std=c++11" for nvcc. For consistency, force both the host compiler
  # and the device compiler to use "-std=c++11".
  cxx_flag: "-std=c++11"
  linker_flag: "-lstdc++"
  linker_flag: "-B/usr/bin/"

  # TODO(bazel-team): In theory, the path here ought to exactly match the path
  # used by gcc. That works because bazel currently doesn't track files at
  # absolute locations and has no remote execution, yet. However, this will need
  # to be fixed, maybe with auto-detection?
  cxx_builtin_include_directory: "/usr/lib/gcc/"
  cxx_builtin_include_directory: "/usr/lib"
  cxx_builtin_include_directory: "/usr/lib64"
  cxx_builtin_include_directory: "/usr/local/include"
  cxx_builtin_include_directory: "/usr/include"

  cxx_builtin_include_directory: "%{computecpp_toolkit_path}"

  tool_path { name: "gcov" path: "/usr/bin/gcov" }

  # C(++) compiles invoke the compiler (as that is the one knowing where
  # to find libraries), but we provide LD so other rules can invoke the linker.
  tool_path { name: "ld" path: "/usr/bin/ld" }

  tool_path { name: "nm" path: "/usr/bin/nm" }
  tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
  objcopy_embed_flag: "-I"
  objcopy_embed_flag: "binary"
  tool_path { name: "objdump" path: "/usr/bin/objdump" }
  tool_path { name: "strip" path: "/usr/bin/strip" }

  # Make C++ compilation deterministic. Use linkstamping instead of these
  # compiler symbols.
  unfiltered_cxx_flag: "-Wno-builtin-macro-redefined"
  unfiltered_cxx_flag: "-D__DATE__=\"redacted\""
  unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\""
  unfiltered_cxx_flag: "-D__TIME__=\"redacted\""

  # All warnings are enabled. Maybe enable -Werror as well?
  compiler_flag: "-Wall"

  # Anticipated future default.
  linker_flag: "-Wl,-no-as-needed"
  # Stamp the binary with a unique identifier.
  linker_flag: "-Wl,--build-id=md5"
  linker_flag: "-Wl,--hash-style=gnu"

  linking_mode_flags { mode: DYNAMIC }
}
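Note that the `gcc` tool_path above points at the `computecpp` wrapper script (generated from the template that follows), so every compile action Bazel issues through this toolchain is routed through the dual-compilation driver.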
third_party/sycl/crosstool/computecpp.tpl (vendored executable, new file, 61 lines)

@@ -0,0 +1,61 @@
#!/usr/bin/env python

from argparse import ArgumentParser
import os
import subprocess
import re
import sys
import pipes

CPU_CXX_COMPILER = ('%{host_cxx_compiler}')
CPU_C_COMPILER = ('%{host_c_compiler}')

CURRENT_DIR = os.path.dirname(sys.argv[0])
COMPUTECPP_ROOT = CURRENT_DIR + "/../sycl/"
COMPUTECPP_DRIVER = COMPUTECPP_ROOT + "bin/compute++"
COMPUTECPP_INCLUDE = COMPUTECPP_ROOT + "include"

def main():
  computecpp_compiler_flags = [flag for flag in sys.argv[1:]]
  computecpp_compiler_flags = computecpp_compiler_flags + ["-D_GLIBCXX_USE_CXX11_ABI=0"]

  output_file_index = computecpp_compiler_flags.index("-o") + 1
  output_file_name = computecpp_compiler_flags[output_file_index]

  if output_file_index == 1:
    # We are linking.
    return subprocess.call([CPU_CXX_COMPILER] + computecpp_compiler_flags)

  # Find out what we are compiling.
  compiling_cpp = 0
  if "-c" in computecpp_compiler_flags:
    compiled_file_index = computecpp_compiler_flags.index("-c") + 1
    compiled_file_name = computecpp_compiler_flags[compiled_file_index]
    if compiled_file_name.endswith(('.cc', '.c++', '.cpp', '.CPP', '.C', '.cxx')):
      compiling_cpp = 1

  if compiling_cpp == 1:
    filename, file_extension = os.path.splitext(output_file_name)
    bc_out = filename + ".sycl"

    computecpp_compiler_flags = ['-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable',
        '-I', COMPUTECPP_INCLUDE, '-isystem', COMPUTECPP_INCLUDE,
        "-std=c++11", "-sycl", "-emit-llvm", "-no-serial-memop"] + computecpp_compiler_flags

    # Drop dependency-file flags; we don't want them when compiling with
    # computecpp first.
    host_compiler_flags = [flag for flag in sys.argv[1:]
                           if not flag.startswith(('-MF', '-MD',))
                           if not ".d" in flag]

    x = subprocess.call([COMPUTECPP_DRIVER] + computecpp_compiler_flags)
    if x == 0:
      # Device compilation succeeded: run the host compiler and force-include
      # the generated .sycl integration header.
      host_compiler_flags = ['-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable',
          '-I', COMPUTECPP_INCLUDE, "--include", bc_out] + host_compiler_flags
      return subprocess.call([CPU_CXX_COMPILER] + host_compiler_flags)
    return x
  else:
    # Compile as C.
    return subprocess.call([CPU_C_COMPILER] + computecpp_compiler_flags)

if __name__ == '__main__':
  sys.exit(main())
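In effect each C++ source is compiled twice: first with compute++ (`-sycl -emit-llvm`) to produce the `.sycl` integration header next to the object file, then with the host compiler, which force-includes that header via `--include` so the generated kernel stubs and the host code end up in the same object.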
third_party/sycl/sycl/BUILD (vendored, new empty file)

third_party/sycl/sycl/BUILD.tpl (vendored executable, new file, 43 lines)

@@ -0,0 +1,43 @@
licenses(["notice"])  # Apache 2.0

load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
load("platform", "sycl_library_path")

load("platform", "readlink_command")

package(default_visibility = ["//visibility:public"])

config_setting(
    name = "using_sycl",
    values = {
        "define": "using_sycl=true",
    },
)

cc_library(
    name = "sycl_headers",
    hdrs = glob([
        "**/*.h",
    ]),
    includes = [".", "include"],
)

cc_library(
    name = "syclrt",
    srcs = [
        sycl_library_path("ComputeCpp"),
    ],
    data = [
        sycl_library_path("ComputeCpp"),
    ],
    includes = ["include/"],
    linkstatic = 1,
)

cc_library(
    name = "sycl",
    deps = if_sycl([
        ":sycl_headers",
        ":syclrt",
    ]),
)
third_party/sycl/sycl/build_defs.bzl.tpl (vendored executable, new file, 13 lines)

@@ -0,0 +1,13 @@
# Macros for building SYCL code.

def if_sycl(if_true, if_false = []):
    """Shorthand for select()'ing on whether we're building with SYCL.

    Returns a select statement which evaluates to if_true if we're building
    with SYCL enabled. Otherwise, the select statement evaluates to if_false.
    """
    return select({
        "@local_config_sycl//sycl:using_sycl": if_true,
        "//conditions:default": if_false
    })
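Targets consume this the way the `sycl` library in third_party/sycl/sycl/BUILD.tpl above does, e.g. `deps = if_sycl([":syclrt"])`, which resolves to the SYCL runtime only when the build sets `--define=using_sycl=true` and matches the `using_sycl` config_setting.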
third_party/sycl/sycl/platform.bzl.tpl (vendored executable, new file, 5 lines)

@@ -0,0 +1,5 @@
def sycl_library_path(name):
    return "lib/lib{}.so".format(name)

def readlink_command():
    return "readlink"
third_party/sycl/sycl_configure.bzl (vendored, new file, 197 lines)

@@ -0,0 +1,197 @@
# -*- Python -*-
"""SYCL autoconfiguration.

`sycl_configure` depends on the following environment variables:

  * HOST_CXX_COMPILER: The host C++ compiler
  * HOST_C_COMPILER: The host C compiler
  * COMPUTECPP_TOOLKIT_PATH: The path to the ComputeCpp toolkit.
"""

_HOST_CXX_COMPILER = "HOST_CXX_COMPILER"
_HOST_C_COMPILER = "HOST_C_COMPILER"
_COMPUTECPP_TOOLKIT_PATH = "COMPUTECPP_TOOLKIT_PATH"

def _enable_sycl(repository_ctx):
  if "TF_NEED_OPENCL" in repository_ctx.os.environ:
    enable_sycl = repository_ctx.os.environ["TF_NEED_OPENCL"].strip()
    return enable_sycl == "1"
  return False

def auto_configure_fail(msg):
  """Output failure message when auto configuration fails."""
  red = "\033[0;31m"
  no_color = "\033[0m"
  fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg))
# END cc_configure common functions (see TODO above).

def find_c(repository_ctx):
  """Find host C compiler."""
  c_name = "gcc"
  if _HOST_C_COMPILER in repository_ctx.os.environ:
    c_name = repository_ctx.os.environ[_HOST_C_COMPILER].strip()
  if c_name.startswith("/"):
    return c_name
  c = repository_ctx.which(c_name)
  if c == None:
    fail("Cannot find C compiler, please correct your path.")
  return c

def find_cc(repository_ctx):
  """Find host C++ compiler."""
  cc_name = "g++"
  if _HOST_CXX_COMPILER in repository_ctx.os.environ:
    cc_name = repository_ctx.os.environ[_HOST_CXX_COMPILER].strip()
  if cc_name.startswith("/"):
    return cc_name
  cc = repository_ctx.which(cc_name)
  if cc == None:
    fail("Cannot find C++ compiler, please correct your path.")
  return cc

def find_computecpp_root(repository_ctx):
  """Find ComputeCpp compiler."""
  sycl_name = ""
  if _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ:
    sycl_name = repository_ctx.os.environ[_COMPUTECPP_TOOLKIT_PATH].strip()
  if sycl_name.startswith("/"):
    return sycl_name
  fail("Cannot find SYCL compiler, please correct your path")

def _check_lib(repository_ctx, toolkit_path, lib):
  """Checks if lib exists under sycl_toolkit_path or fail if it doesn't.

  Args:
    repository_ctx: The repository context.
    toolkit_path: The toolkit directory containing the libraries.
    lib: The library to look for under toolkit_path.
  """
  lib_path = toolkit_path + "/" + lib
  if not repository_ctx.path(lib_path).exists:
    auto_configure_fail("Cannot find %s" % lib_path)

def _check_dir(repository_ctx, directory):
  """Checks whether the directory exists and fail if it does not.

  Args:
    repository_ctx: The repository context.
    directory: The directory to check the existence of.
  """
  if not repository_ctx.path(directory).exists:
    auto_configure_fail("Cannot find dir: %s" % directory)

def _symlink_dir(repository_ctx, src_dir, dest_dir):
  """Symlinks all the files in a directory.

  Args:
    repository_ctx: The repository context.
    src_dir: The source directory.
    dest_dir: The destination directory to create the symlinks in.
  """
  files = repository_ctx.path(src_dir).readdir()
  for src_file in files:
    repository_ctx.symlink(src_file, dest_dir + "/" + src_file.basename)

def _tpl(repository_ctx, tpl, substitutions={}, out=None):
  if not out:
    out = tpl.replace(":", "/")
  repository_ctx.template(
      out,
      Label("//third_party/sycl/%s.tpl" % tpl),
      substitutions)

def _file(repository_ctx, label):
  repository_ctx.template(
      label.replace(":", "/"),
      Label("//third_party/sycl/%s.tpl" % label),
      {})

_DUMMY_CROSSTOOL_BZL_FILE = """
def error_sycl_disabled():
  fail("ERROR: Building with --config=sycl but TensorFlow is not configured " +
       "to build with SYCL support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with SYCL support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""


_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_sycl_disabled.bzl", "error_sycl_disabled")

error_sycl_disabled()
"""

def _create_dummy_repository(repository_ctx):
  # Set up BUILD file for sycl/.
  _file(repository_ctx, "sycl:build_defs.bzl")
  _file(repository_ctx, "sycl:BUILD")
  _file(repository_ctx, "sycl:platform.bzl")

  # Create dummy files for the SYCL toolkit since they are still required by
  # tensorflow/sycl/platform/default/build_config:sycl.
  repository_ctx.file("sycl/include/sycl.hpp", "")
  repository_ctx.file("sycl/lib/libComputeCpp.so", "")

  # If sycl_configure is not configured to build with SYCL support, and the user
  # attempts to build with --config=sycl, add a dummy build rule to intercept
  # this and fail with an actionable error message.
  repository_ctx.file("crosstool/error_sycl_disabled.bzl",
                      _DUMMY_CROSSTOOL_BZL_FILE)
  repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)


def _sycl_autoconf_imp(repository_ctx):
  """Implementation of the sycl_autoconf rule."""
  if not _enable_sycl(repository_ctx):
    _create_dummy_repository(repository_ctx)
  else:
    # Copy template files.
    _file(repository_ctx, "sycl:build_defs.bzl")
    _file(repository_ctx, "sycl:BUILD")
    _file(repository_ctx, "sycl:platform.bzl")
    _file(repository_ctx, "crosstool:BUILD")
    _tpl(repository_ctx, "crosstool:computecpp",
         {
             "%{host_cxx_compiler}": find_cc(repository_ctx),
             "%{host_c_compiler}": find_c(repository_ctx),
         })

    computecpp_root = find_computecpp_root(repository_ctx)
    _check_dir(repository_ctx, computecpp_root)

    _tpl(repository_ctx, "crosstool:CROSSTOOL",
         {
             "%{computecpp_toolkit_path}": computecpp_root,
         })

    # Symlink libraries.
    _check_lib(repository_ctx, computecpp_root + "/lib", "libComputeCpp.so")
    _symlink_dir(repository_ctx, computecpp_root + "/lib", "sycl/lib")
    _symlink_dir(repository_ctx, computecpp_root + "/include", "sycl/include")
    _symlink_dir(repository_ctx, computecpp_root + "/bin", "sycl/bin")

sycl_configure = repository_rule(
    implementation = _sycl_autoconf_imp,
    local = True,
)
"""Detects and configures the SYCL toolchain.

Add the following to your WORKSPACE file:

```python
sycl_configure(name = "local_config_sycl")
```

Args:
  name: A unique name for this workspace rule.
"""
@@ -1,6 +1,9 @@
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true

build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl --define=using_sycl=true

build --force_python=py$PYTHON_MAJOR_VERSION
build --host_force_python=py$PYTHON_MAJOR_VERSION
build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
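With these lines in place, `bazel build --config=sycl //tensorflow/...` selects the ComputeCpp crosstool generated by sycl_configure and sets `--define=using_sycl=true`, which in turn activates the `if_sycl` branches and the `using_sycl` config_setting shown earlier.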