diff --git a/configure b/configure
index 933bd573578..426071e48d0 100755
--- a/configure
+++ b/configure
@@ -126,6 +126,17 @@ GEN_GIT_SOURCE=tensorflow/tools/git/gen_git_source.py
 chmod a+x ${GEN_GIT_SOURCE}
 "${PYTHON_BIN_PATH}" ${GEN_GIT_SOURCE} --configure "${SOURCE_BASE_DIR}"
 
+## Set up SYCL-related environment settings
+while [ "$TF_NEED_OPENCL" == "" ]; do
+  read -p "Do you wish to build TensorFlow with OpenCL support? [y/N] " INPUT
+  case $INPUT in
+    [Yy]* ) echo "OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=1;;
+    [Nn]* ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
+    "" ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
+    * ) echo "Invalid selection: " $INPUT;;
+  esac
+done
+
 ## Set up Cuda-related environment settings
 
 while [ "$TF_NEED_CUDA" == "" ]; do
@@ -139,12 +150,14 @@ while [ "$TF_NEED_CUDA" == "" ]; do
 done
 
 export TF_NEED_CUDA
-if [ "$TF_NEED_CUDA" == "0" ]; then
+export TF_NEED_SYCL
+if [[ "$TF_NEED_CUDA" == "0" ]] && [[ "$TF_NEED_OPENCL" == "0" ]]; then
   echo "Configuration finished"
   bazel_clean_and_fetch
   exit
 fi
 
+if [ "$TF_NEED_CUDA" == "1" ]; then
 # Set up which gcc nvcc should use as the host compiler
 while true; do
   fromuser=""
@@ -346,6 +359,65 @@ EOF
   TF_CUDA_COMPUTE_CAPABILITIES=""
 done
 
+# end of if "$TF_NEED_CUDA" == "1"
+fi
+
+# OpenCL configuration
+
+if [ "$TF_NEED_OPENCL" == "1" ]; then
+while true; do
+  # Configure the OpenCL version to use.
+  TF_OPENCL_VERSION="1.2"
+
+  # Point to the ComputeCpp root
+  if [ -z "$COMPUTECPP_PATH" ]; then
+    default_computecpp_path=/usr/local/computecpp
+    read -p "Please specify the location where ComputeCpp $TF_OPENCL_VERSION is installed. Refer to README.md for more details. [Default is $default_computecpp_path]: " COMPUTECPP_PATH
+    fromuser="1"
+    if [ -z "$COMPUTECPP_PATH" ]; then
+      COMPUTECPP_PATH=$default_computecpp_path
+    fi
+  fi
+
+  if [ "$OSNAME" == "Linux" ]; then
+    SYCL_RT_LIB_PATH="lib/libComputeCpp.so"
+  fi
+
+  if [ -e "${COMPUTECPP_PATH}/${SYCL_RT_LIB_PATH}" ]; then
+    break
+  fi
+  echo "Invalid path to SYCL $TF_OPENCL_VERSION library. ${COMPUTECPP_PATH}/${SYCL_RT_LIB_PATH} cannot be found"
+
+  if [ -z "$fromuser" ]; then
+    exit 1
+  fi
+  # Retry
+  TF_OPENCL_VERSION=""
+  COMPUTECPP_PATH=""
+done
+
+cat > third_party/sycl/sycl.config <<EOF
+# COMPUTECPP_PATH refers to the ComputeCpp toolkit.
+COMPUTECPP_PATH="$COMPUTECPP_PATH"
+
+# The OpenCL version that should be used in this build
+TF_OPENCL_VERSION=$TF_OPENCL_VERSION
+
+EOF
+
+export WARNING=$DO_NOT_SUBMIT_WARNING
+perl -pi -e "s,#cxx_builtin_include_directory: {COMPUTECPP_INCLUDE},# \$ENV{WARNING}\ncxx_builtin_include_directory: \"${COMPUTECPP_PATH}\",s" third_party/sycl/crosstool/CROSSTOOL
+
+# Configure the platform name.
+perl -pi -e "s,PLATFORM = \".*\",PLATFORM = \"$OSNAME\",s" third_party/sycl/platform.bzl
+
+
+# Invoke sycl_config.sh to set up TensorFlow's canonical view of the SYCL libraries
+(cd third_party/sycl; ./sycl_config.sh;) || exit -1
+
+# end of if "$TF_NEED_OPENCL" == "1"
+fi
+
 bazel_clean_and_fetch
 
 echo "Configuration finished"
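The configure changes above read TF_NEED_OPENCL and COMPUTECPP_PATH from the
environment before prompting, so the new questions can be pre-answered, and the
answers are recorded in third_party/sycl/sycl.config. A rough sketch, assuming
the default install location offered by the prompt:

    TF_NEED_OPENCL=1 COMPUTECPP_PATH=/usr/local/computecpp ./configure

    # resulting third_party/sycl/sycl.config:
    # COMPUTECPP_PATH refers to the ComputeCpp toolkit.
    COMPUTECPP_PATH="/usr/local/computecpp"

    # The OpenCL version that should be used in this build
    TF_OPENCL_VERSION=1.2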
+perl -pi -e "s,PLATFORM = \".*\",PLATFORM = \"$OSNAME\",s" third_party/sycl/platform.bzl + + +# Invoke the cuda_config.sh and set up the TensorFlow's canonical view of the Cuda libraries +(cd third_party/sycl; ./sycl_config.sh;) || exit -1 + +# end of if "$TF_NEED_OPENCL" == "1" +fi + bazel_clean_and_fetch echo "Configuration finished" diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h index c13f67ffcc7..8fd6597cb88 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.h +++ b/tensorflow/core/common_runtime/bfc_allocator.h @@ -295,6 +295,8 @@ class BFCAllocator : public VisitableAllocator { private: std::vector<AllocationRegion> regions_; }; + // Structures mutable after construction + mutable mutex lock_; // Returns 'bytes' rounded up to the next highest kMinAllocationSize. size_t RoundedBytes(size_t bytes); @@ -389,9 +391,6 @@ class BFCAllocator : public VisitableAllocator { std::unique_ptr<SubAllocator> suballocator_; string name_; - - // Structures mutable after construction - mutable mutex lock_; RegionManager region_manager_ GUARDED_BY(lock_); std::vector<Chunk> chunks_; diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index a4289112534..8fe4825aa6d 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -162,6 +162,8 @@ class DirectSession : public Session { protobuf::RepeatedPtrField<DebugTensorWatch> debug_tensor_watches; }; + mutex graph_def_lock_; + // Initializes the base execution state given the 'graph', // if not already initialized. Status MaybeInitializeExecutionState(const GraphDef& graph, @@ -227,7 +229,6 @@ class DirectSession : public Session { string session_handle_; bool graph_created_ GUARDED_BY(graph_def_lock_) = false; - mutex graph_def_lock_; GraphDef graph_def_ GUARDED_BY(graph_def_lock_); // The thread-pools to use for running ops. diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc index 2148f83fe57..423448773ae 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc @@ -28,6 +28,7 @@ namespace tensorflow { namespace { class FakeAllocator { + mutex mu_; public: FakeAllocator(size_t cap, int millis_to_wait) : memory_capacity_(cap), millis_to_wait_(millis_to_wait) {} @@ -57,7 +58,6 @@ class FakeAllocator { private: AllocatorRetry retry_; void* good_ptr_ = reinterpret_cast<void*>(0xdeadbeef); - mutex mu_; size_t memory_capacity_ GUARDED_BY(mu_); int millis_to_wait_; }; @@ -72,6 +72,7 @@ class FakeAllocator { // interesting part of their interaction with the allocator. This // class is the mechanism that imposes turn taking. 
 class AlternatingBarrier {
+  mutex mu_;
  public:
   explicit AlternatingBarrier(int num_users)
       : num_users_(num_users), next_turn_(0), done_(num_users, false) {}
@@ -109,7 +110,6 @@ class AlternatingBarrier {
     }
   }
 
-  mutex mu_;
   condition_variable cv_;
   int num_users_;
   int next_turn_ GUARDED_BY(mu_);
@@ -118,6 +118,7 @@ class AlternatingBarrier {
 
 class GPUAllocatorRetryTest : public ::testing::Test {
  protected:
+  mutex mu_;
   GPUAllocatorRetryTest() {}
 
   void LaunchConsumerThreads(int num_consumers, int cap_needed) {
@@ -173,7 +174,6 @@ class GPUAllocatorRetryTest : public ::testing::Test {
   std::vector<Thread*> consumers_;
   std::vector<int> consumer_count_;
   Notification notifier_;
-  mutex mu_;
   bool has_failed_ GUARDED_BY(mu_) = false;
   int count_ GUARDED_BY(mu_) = 0;
 };
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.h b/tensorflow/core/common_runtime/gpu/pool_allocator.h
index b2f0265145f..437fea91155 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator.h
@@ -45,6 +45,7 @@ class RoundUpInterface {
 // Size-limited pool of memory buffers obtained from a SubAllocator
 // instance.  Pool eviction policy is LRU.
 class PoolAllocator : public VisitableAllocator {
+  mutex mutex_;
  public:
   // "pool_size_limit" is the maximum number of returned, re-usable
   // memory buffers to keep in the pool.  If pool_size_limit == 0, the
@@ -136,7 +137,6 @@ class PoolAllocator : public VisitableAllocator {
   size_t pool_size_limit_;
   std::unique_ptr<SubAllocator> allocator_;
   std::unique_ptr<RoundUpInterface> size_rounder_;
-  mutex mutex_;
   std::multimap<const size_t, PtrRecord*> pool_ GUARDED_BY(mutex_);
   PtrRecord* lru_head_ GUARDED_BY(mutex_) = nullptr;
   PtrRecord* lru_tail_ GUARDED_BY(mutex_) = nullptr;
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index f047ddb12a1..321ace9f465 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -125,6 +125,8 @@ class OpRegistry : public OpRegistryInterface {
   void ClearDeferredRegistrations();
 
  private:
+  mutable mutex mu_;
+
   // Ensures that all the functions in deferred_ get called, their OpDef's
   // registered, and returns with deferred_ empty.  Returns true the first
   // time it is called. Prints a fatal log if any op registration fails.
@@ -141,7 +143,6 @@ class OpRegistry : public OpRegistryInterface {
   Status RegisterAlreadyLocked(OpRegistrationDataFactory op_data_factory) const
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  mutable mutex mu_;
   // Functions in deferred_ may only be called with mu_ held.
   mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
   // Values are owned.
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
index bb19f5dca04..040bb03f819 100644
--- a/tensorflow/core/framework/tracking_allocator.h
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -74,11 +74,11 @@ class TrackingAllocator : public Allocator {
   std::pair<size_t, size_t> GetSizesAndUnRef();
 
  private:
+  mutex mu_;
   ~TrackingAllocator() override {}
   bool UnRef() EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   Allocator* allocator_;  // not owned.
-  mutex mu_;
   // the number of calls to AllocateRaw that have not yet been matched
   // by a corresponding call to DeAllocateRaw, plus 1 if the Executor
   // has not yet read out the high watermark.
diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc
index 84f57517605..e91d9037cff 100644
--- a/tensorflow/core/kernels/barrier_ops.cc
+++ b/tensorflow/core/kernels/barrier_ops.cc
@@ -40,6 +40,7 @@ namespace tensorflow {
 namespace barrier {
 
 class Barrier : public ResourceBase {
+  mutex mu_;
  public:
   typedef std::vector<Tensor> Tuple;
   typedef std::function<void()> DoneCallback;
@@ -417,7 +418,6 @@ class Barrier : public ResourceBase {
  private:
   typedef std::vector<PersistentTensor> PersistentTuple;
 
-  mutex mu_;
   bool closed_ GUARDED_BY(mu_);
   bool queue_closed_ GUARDED_BY(mu_);
   bool queue_cancelled_ GUARDED_BY(mu_);
@@ -433,6 +433,7 @@ class Barrier : public ResourceBase {
 };
 
 class BarrierOp : public OpKernel {
+  mutex mu_;
  public:
   explicit BarrierOp(OpKernelConstruction* context)
       : OpKernel(context), barrier_handle_set_(false) {
@@ -511,7 +512,6 @@ class BarrierOp : public OpKernel {
   std::vector<TensorShape> value_component_shapes_;
   ContainerInfo cinfo_;
 
-  mutex mu_;
   PersistentTensor barrier_handle_ GUARDED_BY(mu_);
   bool barrier_handle_set_ GUARDED_BY(mu_);
 
@@ -611,7 +611,9 @@ class TakeManyOp : public BarrierOpKernel {
     DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT32};
     // The first output is the insertion index, the second output is the key.
     DataTypeVector expected_outputs = {DT_INT64, DT_STRING};
-    for (DataType dt : barrier->component_types()) {
+    for (auto it = barrier->component_types().begin(),
+              end = barrier->component_types().end(); it != end; it++) {
+      const DataType dt = *it;
       expected_outputs.push_back(dt);
     }
     OP_REQUIRES_OK_ASYNC(
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h
index f8c340a7691..4ee1601f342 100644
--- a/tensorflow/core/kernels/conditional_accumulator.h
+++ b/tensorflow/core/kernels/conditional_accumulator.h
@@ -65,7 +65,7 @@ class ConditionalAccumulator
   functor::SetZeroFunctor<Device, T> set_zero_functor_;
 
   Status ValidateShape(const Tensor* tensor)
-      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     // Must be compatible with accumulated gradient if available
     if (counter_ > 0) {
       if (!accum_grad_->shape().IsSameSize(tensor->shape())) {
@@ -98,7 +98,7 @@ class ConditionalAccumulator
   }
 
   void DivideAccumGradByCounter(OpKernelContext* ctx) override
-      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     Tensor c(DataTypeToEnum<T>::value, {});
     c.scalar<T>()() = TypeConverter<T, int>::ConvertUToT(this->counter_);
     this->accum_grad_->template flat<T>().device(
@@ -113,7 +113,7 @@ class ConditionalAccumulator
 
   bool GetAndValidateTensorInputForApplyGrad(OpKernelContext* ctx,
                                              const Tensor** tensor) override
-      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     // Get input gradient tensor
     const Tensor* grad_tensor;
     OP_REQUIRES_OK_BOOLEAN(ctx, ctx->input("gradient", &grad_tensor));
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h
index 05ee855daee..9992379640d 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base.h
@@ -45,6 +45,8 @@ namespace tensorflow {
  * (3) the internal global_step value (current_global_step_) is incremented by 1
  */
 class ConditionalAccumulatorBase : public ResourceBase {
+ protected:
+  mutex mu_;
  public:
   // Args:
   //   dtype: The datatype of the gradients to be accumulated.
@@ -125,7 +127,6 @@ class ConditionalAccumulatorBase : public ResourceBase {
   const DataType dtype_;
   const PartialTensorShape shape_;
   const string name_;
-  mutex mu_;
   int counter_ GUARDED_BY(mu_);
   int64 current_global_step_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h
index 33c2d596c8b..0a64a857cdb 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base_op.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h
@@ -43,6 +43,7 @@ namespace tensorflow {
 * ConditionalAccumulatorBase (via sub-class's Creator) and returns its handle.
 */
 class ConditionalAccumulatorBaseOp : public OpKernel {
+  mutex mu_;
  public:
   explicit ConditionalAccumulatorBaseOp(OpKernelConstruction* context)
       : OpKernel(context), accumulator_handle_set_(false) {
@@ -109,7 +110,6 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
     return Status::OK();
   }
 
-  mutex mu_;
   PersistentTensor accumulator_handle_ GUARDED_BY(mu_);
   bool accumulator_handle_set_ GUARDED_BY(mu_);
 };
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 5ad6b1fd4a1..3d1953d7f4c 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -82,8 +82,13 @@ class BinaryOp : public BinaryOpShared {
     if (!ctx->status().ok()) return;
     Tensor* out = state.out;
     BCast* bcast = &state.bcast;
+#if TENSORFLOW_USE_SYCL
+    decltype(state.in0) in0 = state.in0;
+    decltype(state.in1) in1 = state.in1;
+#else
     auto& in0 = state.in0;
     auto& in1 = state.in1;
+#endif
     if (state.out_num_elements == 0) {
       return;
     }
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 79b479b44b5..6b2043e5a32 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -83,6 +83,7 @@ class QueueBase : public QueueInterface {
                            int64 index);
 
  protected:
+  mutex mu_;
   enum Action { kEnqueue, kDequeue };
   enum RunResult { kNoProgress, kProgress, kComplete };
@@ -143,7 +144,6 @@ class QueueBase : public QueueInterface {
   const DataTypeVector component_dtypes_;
   const std::vector<TensorShape> component_shapes_;
   const string name_;
-  mutex mu_;
   bool closed_ GUARDED_BY(mu_);
 
   struct Attempt;
diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h
index 7694827854c..a21be2c389d 100644
--- a/tensorflow/core/kernels/queue_op.h
+++ b/tensorflow/core/kernels/queue_op.h
@@ -34,6 +34,7 @@ namespace tensorflow {
 
 // Defines a QueueOp, an abstract class for Queue construction ops.
 class QueueOp : public OpKernel {
+  mutex mu_;
  public:
   QueueOp(OpKernelConstruction* context)
       : OpKernel(context), queue_handle_set_(false) {
@@ -94,7 +95,6 @@ class QueueOp : public OpKernel {
     return Status::OK();
   }
 
-  mutex mu_;
   PersistentTensor queue_handle_ GUARDED_BY(mu_);
   bool queue_handle_set_ GUARDED_BY(mu_);
 };
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h
index 89560094af6..73bd3b47e48 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator.h
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h
@@ -83,7 +83,7 @@ class SparseConditionalAccumulator
 
   Status ValidateShape(
       std::tuple<const Tensor*, const Tensor*, const Tensor*>* tensor,
-      bool has_known_shape) EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
+      bool has_known_shape) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     const Tensor* tensor_idx = std::get<0>(*tensor);
     const Tensor* tensor_val = std::get<1>(*tensor);
     const Tensor* tensor_shape = std::get<2>(*tensor);
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index 7835fd7bbc1..96bff2c95ed 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -123,6 +123,7 @@ TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
 // multiple reads of that index in the forward phase.
 //
 class TensorArray : public ResourceBase {
+  mutex mu_;
  public:
   static std::atomic<int64> tensor_array_counter;
@@ -338,8 +339,6 @@ class TensorArray : public ResourceBase {
   const DataType dtype_;
   Tensor handle_;
 
-  mutex mu_;
-
   // Marks that the tensor_array_ has been cleared.
   bool closed_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/lib/monitoring/collection_registry.cc b/tensorflow/core/lib/monitoring/collection_registry.cc
index 47112279cff..01d643fbcca 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.cc
+++ b/tensorflow/core/lib/monitoring/collection_registry.cc
@@ -45,7 +45,9 @@ void Collector::CollectMetricDescriptor(
   metric_descriptor->name = metric_def->name().ToString();
   metric_descriptor->description = metric_def->description().ToString();
 
-  for (const StringPiece label_name : metric_def->label_descriptions()) {
+  for (auto it = metric_def->label_descriptions().begin(),
+            end = metric_def->label_descriptions().end(); it != end; it++) {
+    const StringPiece label_name = *it;
     metric_descriptor->label_names.push_back(label_name.ToString());
   }
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index 3da2439238f..ed957b9ae45 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -121,6 +121,7 @@ class MetricCollectorGetter {
 //
 // This class is thread-safe.
 class CollectionRegistry {
+  mutable mutex mu_;
  public:
   ~CollectionRegistry() = default;
@@ -176,8 +177,6 @@ class CollectionRegistry {
   // TF environment, mainly used for timestamping.
   Env* const env_;
 
-  mutable mutex mu_;
-
   // Information required for collection.
   struct CollectionInfo {
     const AbstractMetricDef* const metric_def;
@@ -227,6 +226,7 @@ inline void CollectValue(const int64& value, Point* const point) {
 //
 // This class is thread-safe.
 class Collector {
+  mutable mutex mu_;
  public:
   Collector(const uint64 collection_time_millis)
       : collected_metrics_(new CollectedMetrics()),
@@ -260,7 +260,6 @@ class Collector {
       LOCKS_EXCLUDED(mu_);
 
  private:
-  mutable mutex mu_;
   std::unique_ptr<CollectedMetrics> collected_metrics_ GUARDED_BY(mu_);
   const uint64 collection_time_millis_;
diff --git a/tensorflow/core/lib/monitoring/counter.h b/tensorflow/core/lib/monitoring/counter.h
index e76057b980a..0ea50932dd9 100644
--- a/tensorflow/core/lib/monitoring/counter.h
+++ b/tensorflow/core/lib/monitoring/counter.h
@@ -78,6 +78,7 @@ class CounterCell {
 // This class is thread-safe.
 template <int NumLabels>
 class Counter {
+  mutable mutex mu_;
  public:
   ~Counter() {
     // Deleted here, before the metric_def is destroyed.
@@ -111,8 +112,6 @@ class Counter {
         }
       })) {}
 
-  mutable mutex mu_;
-
   // The metric definition. This will be used to identify the metric when we
   // register it for collection.
   const MetricDef<MetricKind::kCumulative, int64, NumLabels> metric_def_;
diff --git a/tensorflow/stream_executor/machine_manager.h b/tensorflow/stream_executor/machine_manager.h
index 65396dd1ff5..bf95bc74713 100644
--- a/tensorflow/stream_executor/machine_manager.h
+++ b/tensorflow/stream_executor/machine_manager.h
@@ -60,6 +60,9 @@ namespace gputools {
 //
 // Thread-safe.
 class MachineManager {
+  // Mutex that guards the initialization of the machine manager static
+  // variable.
+  static mutex mu_;
  public:
   // Inspects the host to determine the preferred GPU execution platform.
   // To force OpenCL from a build target on a machine that has both OpenCL and
@@ -171,10 +174,6 @@ class MachineManager {
   // Returns the NUMA node association for the StreamExecutor.
   int ExecutorToNumaNode(const StreamExecutor *stream_exec) const;
 
-  // Mutex that guards the initialization of the machine manager static
-  // variable.
-  static mutex mu_;
-
   // Singleton MachineManager value -- assignment to this is protected by a
   // static singleton guard clause.
   static MachineManager *singleton_ GUARDED_BY(mu_);
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index ac2a22ee548..6e3def96902 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -14,7 +14,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
   # These lines need to be changed when updating Eigen. They are parsed from
   # this file by the cmake and make builds to determine the eigen version and
   # hash.
- eigen_version = "aad63574941c" + eigen_version = "ab6d16a84626" eigen_sha256 = "" native.new_http_archive( diff --git a/third_party/sycl/BUILD b/third_party/sycl/BUILD new file mode 100755 index 00000000000..4aadc808302 --- /dev/null +++ b/third_party/sycl/BUILD @@ -0,0 +1,44 @@ +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +load("//third_party/sycl:build_defs.bzl", "if_sycl") +load("platform", "sycl_library_path") +load("platform", "sycl_static_library_path") + +load("platform", "readlink_command") + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "using_sycl", + values = { + "define": "using_sycl=true", + }, +) + +cc_library( + name = "sycl_headers", + hdrs = glob([ + "**/*.h", + ]), + includes = [".", "include"], +) + +cc_library( + name = "syclrt", + srcs = [ + sycl_library_path("ComputeCpp") + ], + data = [ + sycl_library_path("ComputeCpp") + ], + includes = ["include/"], + linkstatic = 1, +) + +cc_library( + name = "sycl", + deps = if_sycl([ + ":sycl_headers", + ":syclrt", + ]), +) diff --git a/third_party/sycl/build_defs.bzl b/third_party/sycl/build_defs.bzl new file mode 100755 index 00000000000..1aaeb741a21 --- /dev/null +++ b/third_party/sycl/build_defs.bzl @@ -0,0 +1,10 @@ +# Macros for building SYCL code. +def if_sycl(if_true, if_false = []): + """Shorthand for select()'ing on whether we're building with SYCL. + Returns a select statement which evaluates to if_true if we're building + with SYCL enabled. Otherwise, the select statement evaluates to if_false. + """ + return select({ + "//third_party/sycl:using_sycl": if_true, + "//conditions:default": if_false + }) diff --git a/third_party/sycl/crosstool/BUILD b/third_party/sycl/crosstool/BUILD new file mode 100755 index 00000000000..ec0070e71da --- /dev/null +++ b/third_party/sycl/crosstool/BUILD @@ -0,0 +1,29 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_toolchain_suite( + name = "toolchain", + toolchains = { + "local|compiler": ":cc-compiler-local", + }, +) + +cc_toolchain( + name = "cc-compiler-local", + all_files = ":empty", + compiler_files = ":empty", + cpu = "local", + dwp_files = ":empty", + dynamic_runtime_libs = [":empty"], + linker_files = ":empty", + objcopy_files = ":empty", + static_runtime_libs = [":empty"], + strip_files = ":empty", + supports_param_files = 0, +) + +filegroup( + name = "empty", + srcs = [], +) diff --git a/third_party/sycl/crosstool/CROSSTOOL b/third_party/sycl/crosstool/CROSSTOOL new file mode 100755 index 00000000000..a96d33d522f --- /dev/null +++ b/third_party/sycl/crosstool/CROSSTOOL @@ -0,0 +1,82 @@ +major_version: "local" +minor_version: "" +default_target_cpu: "same_as_host" + +default_toolchain { + cpu: "k8" + toolchain_identifier: "local_linux" +} + +toolchain { + abi_version: "local" + abi_libc_version: "local" + builtin_sysroot: "" + compiler: "compiler" + host_system_name: "local" + needsPic: true + supports_gold_linker: false + supports_incremental_linker: false + supports_fission: false + supports_interface_shared_objects: false + supports_normalizing_ar: false + supports_start_end_lib: false + supports_thin_archives: false + target_libc: "local" + target_cpu: "local" + target_system_name: "local" + toolchain_identifier: "local_linux" + + tool_path { name: "ar" path: "/usr/bin/ar" } + tool_path { name: "compat-ld" path: "/usr/bin/ld" } + tool_path { name: "cpp" path: "/usr/bin/cpp" } + tool_path { name: "dwp" path: "/usr/bin/dwp" } + tool_path { name: "gcc" path: "computecpp" 
diff --git a/third_party/sycl/crosstool/BUILD b/third_party/sycl/crosstool/BUILD
new file mode 100755
index 00000000000..ec0070e71da
--- /dev/null
+++ b/third_party/sycl/crosstool/BUILD
@@ -0,0 +1,29 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+cc_toolchain_suite(
+    name = "toolchain",
+    toolchains = {
+        "local|compiler": ":cc-compiler-local",
+    },
+)
+
+cc_toolchain(
+    name = "cc-compiler-local",
+    all_files = ":empty",
+    compiler_files = ":empty",
+    cpu = "local",
+    dwp_files = ":empty",
+    dynamic_runtime_libs = [":empty"],
+    linker_files = ":empty",
+    objcopy_files = ":empty",
+    static_runtime_libs = [":empty"],
+    strip_files = ":empty",
+    supports_param_files = 0,
+)
+
+filegroup(
+    name = "empty",
+    srcs = [],
+)
diff --git a/third_party/sycl/crosstool/CROSSTOOL b/third_party/sycl/crosstool/CROSSTOOL
new file mode 100755
index 00000000000..a96d33d522f
--- /dev/null
+++ b/third_party/sycl/crosstool/CROSSTOOL
@@ -0,0 +1,82 @@
+major_version: "local"
+minor_version: ""
+default_target_cpu: "same_as_host"
+
+default_toolchain {
+  cpu: "k8"
+  toolchain_identifier: "local_linux"
+}
+
+toolchain {
+  abi_version: "local"
+  abi_libc_version: "local"
+  builtin_sysroot: ""
+  compiler: "compiler"
+  host_system_name: "local"
+  needsPic: true
+  supports_gold_linker: false
+  supports_incremental_linker: false
+  supports_fission: false
+  supports_interface_shared_objects: false
+  supports_normalizing_ar: false
+  supports_start_end_lib: false
+  supports_thin_archives: false
+  target_libc: "local"
+  target_cpu: "local"
+  target_system_name: "local"
+  toolchain_identifier: "local_linux"
+
+  tool_path { name: "ar" path: "/usr/bin/ar" }
+  tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+  tool_path { name: "cpp" path: "/usr/bin/cpp" }
+  tool_path { name: "dwp" path: "/usr/bin/dwp" }
+  tool_path { name: "gcc" path: "computecpp" }
+  # Use "-std=c++11" for compute++. For consistency, force both the host
+  # compiler and the device compiler to use "-std=c++11".
+  cxx_flag: "-std=c++11"
+  linker_flag: "-lstdc++"
+  linker_flag: "-B/usr/bin/"
+
+  # TODO(bazel-team): In theory, the path here ought to exactly match the path
+  # used by gcc. That works because bazel currently doesn't track files at
+  # absolute locations and has no remote execution, yet. However, this will need
+  # to be fixed, maybe with auto-detection?
+  cxx_builtin_include_directory: "/usr/lib/gcc/"
+  cxx_builtin_include_directory: "/usr/lib"
+  cxx_builtin_include_directory: "/usr/lib64"
+  cxx_builtin_include_directory: "/usr/local/include"
+  cxx_builtin_include_directory: "/usr/include"
+
+  #cxx_builtin_include_directory: {COMPUTECPP_INCLUDE}
+
+  tool_path { name: "gcov" path: "/usr/bin/gcov" }
+
+  # C(++) compiles invoke the compiler (as that is the one knowing where
+  # to find libraries), but we provide LD so other rules can invoke the linker.
+  tool_path { name: "ld" path: "/usr/bin/ld" }
+
+  tool_path { name: "nm" path: "/usr/bin/nm" }
+  tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+  objcopy_embed_flag: "-I"
+  objcopy_embed_flag: "binary"
+  tool_path { name: "objdump" path: "/usr/bin/objdump" }
+  tool_path { name: "strip" path: "/usr/bin/strip" }
+
+  # Make C++ compilation deterministic. Use linkstamping instead of these
+  # compiler symbols.
+  unfiltered_cxx_flag: "-Wno-builtin-macro-redefined"
+  unfiltered_cxx_flag: "-D__DATE__=\"redacted\""
+  unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\""
+  unfiltered_cxx_flag: "-D__TIME__=\"redacted\""
+
+  # All warnings are enabled. Maybe enable -Werror as well?
+  compiler_flag: "-Wall"
+
+  # Anticipated future default.
+  linker_flag: "-Wl,-no-as-needed"
+  # Stamp the binary with a unique identifier.
+  linker_flag: "-Wl,--build-id=md5"
+  linker_flag: "-Wl,--hash-style=gnu"
+
+  linking_mode_flags { mode: DYNAMIC }
+}
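The CROSSTOOL above registers the computecpp wrapper script below as the
toolchain's "gcc", so bazel routes every compile and link through it. For a C++
compile the wrapper makes two passes: first the ComputeCpp device compiler,
then the host compiler with the generated integration file forced in via
--include. Approximately, with illustrative file names:

    computecpp -c foo.cc -o foo.o
    # expands to roughly:
    #   compute++ -sycl -emit-llvm ... -c foo.cc -o foo.o   (also emits foo.sycl)
    #   clang++-3.6 --include foo.sycl ... -c foo.cc -o foo.o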
+ linker_flag: "-Wl,--build-id=md5" + linker_flag: "-Wl,--hash-style=gnu" + + linking_mode_flags { mode: DYNAMIC } +} diff --git a/third_party/sycl/crosstool/computecpp b/third_party/sycl/crosstool/computecpp new file mode 100755 index 00000000000..0b8427fd0ae --- /dev/null +++ b/third_party/sycl/crosstool/computecpp @@ -0,0 +1,61 @@ +#!/usr/bin/env python2.7 + +from argparse import ArgumentParser +import os +import subprocess +import re +import sys +import pipes + +CPU_CXX_COMPILER = ('/usr/bin/clang++-3.6') +CPU_C_COMPILER = ('/usr/bin/clang-3.6') + +CURRENT_DIR = os.path.dirname(sys.argv[0]) +COMPUTECPP_ROOT = CURRENT_DIR +"/../" +COMPUTECPP_DRIVER= COMPUTECPP_ROOT+"bin/compute++" +COMPUTECPP_INCLUDE = COMPUTECPP_ROOT+"include" + +def main(): + computecpp_compiler_flags = [""] + computecpp_compiler_flags = [flag for flag in sys.argv[1:]] + computecpp_compiler_flags = computecpp_compiler_flags + ["-D_GLIBCXX_USE_CXX11_ABI=0"] + + output_file_index = computecpp_compiler_flags.index("-o") +1 + output_file_name = computecpp_compiler_flags[output_file_index] + + if(output_file_index == 1): + # we are linking + return subprocess.call([CPU_CXX_COMPILER] +computecpp_compiler_flags ) + + # find what we compile + compiling_cpp = 0 + if("-c" in computecpp_compiler_flags): + compiled_file_index = computecpp_compiler_flags.index("-c") +1 + compited_file_name = computecpp_compiler_flags[compiled_file_index] + if(compited_file_name.endswith(('.cc', '.c++', '.cpp', '.CPP', '.C', '.cxx'))): + compiling_cpp = 1; + + if(compiling_cpp == 1): + filename, file_extension = os.path.splitext(output_file_name) + bc_out = filename + ".sycl" + + computecpp_compiler_flags = ['-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable','-I', COMPUTECPP_INCLUDE,'-isystem', + COMPUTECPP_INCLUDE, "-std=c++11", "-sycl", "-emit-llvm", "-no-serial-memop"] + computecpp_compiler_flags + + # dont want that in case of compiling with computecpp first + host_compiler_flags = [""] + host_compiler_flags = [flag for flag in sys.argv[1:] + if not flag.startswith(('-MF','-MD',)) + if not ".d" in flag] + + x = subprocess.call([COMPUTECPP_DRIVER] +computecpp_compiler_flags ) + if(x == 0): + host_compiler_flags = ['-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, "--include",bc_out] + host_compiler_flags + return subprocess.call([CPU_CXX_COMPILER] +host_compiler_flags ) + return x + else: + # compile for C + return subprocess.call([CPU_C_COMPILER] +computecpp_compiler_flags) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/third_party/sycl/platform.bzl b/third_party/sycl/platform.bzl new file mode 100755 index 00000000000..8905b496b3b --- /dev/null +++ b/third_party/sycl/platform.bzl @@ -0,0 +1,17 @@ +SYCL_VERSION = "" +PLATFORM = "" + +def sycl_sdk_version(): + return SYCL_VERSION + +def sycl_library_path(name, version = sycl_sdk_version()): + if not version: + return "lib/lib{}.so".format(name) + else: + return "lib/lib{}.so.{}".format(name, version) + +def sycl_static_library_path(name): + return "lib/lib{}_static.a".format(name) + +def readlink_command(): + return "readlink" diff --git a/third_party/sycl/sycl_config.sh b/third_party/sycl/sycl_config.sh new file mode 100755 index 00000000000..7fa93110708 --- /dev/null +++ b/third_party/sycl/sycl_config.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+# A simple script to configure the SYCL tree needed for the TensorFlow OpenCL
+# build. We need the ComputeCpp toolkit for OpenCL $TF_OPENCL_VERSION.
+# Usage:
+# * User edits sycl.config to point the ComputeCpp toolkit to its local path.
+# * Run sycl_config.sh to generate symbolic links in the source tree to reflect
+#   the file organization needed by TensorFlow.
+
+print_usage() {
+cat << EOF
+Usage: $0 [--check]
+  Configure TensorFlow's canonical view of SYCL libraries using sycl.config.
+Arguments:
+  --check: Only check that the proper SYCL dependencies have already been
+    properly configured in the source tree. It also creates symbolic links to
+    the files in the gen-tree to make bazel happy.
+EOF
+}
+
+CHECK_ONLY=0
+# Parse the arguments. Add more cases to the "case" statement when needed.
+while [[ $# -gt 0 ]]; do
+  argument="$1"
+  shift
+  case $argument in
+    --check)
+      CHECK_ONLY=1
+      ;;
+    *)
+      echo "Error: unknown arguments"
+      print_usage
+      exit -1
+      ;;
+  esac
+done
+
+source sycl.config || exit -1
+
+OUTPUTDIR=${OUTPUTDIR:-../..}
+COMPUTECPP_PATH=${COMPUTECPP_PATH:-/usr/local/computecpp}
+
+# An error message when the SYCL toolkit is not found
+function SYCLError {
+  echo ERROR: $1
+cat << EOF
+##############################################################################
+##############################################################################
+SYCL $TF_OPENCL_VERSION toolkit is missing.
+1. Download and install the ComputeCpp $TF_OPENCL_VERSION toolkit;
+2. Run configure from the root of the source tree, before rerunning bazel;
+Please refer to README.md for more details.
+##############################################################################
+##############################################################################
+EOF
+  exit -1
+}
+
+# Check that the SYCL libraries have already been properly configured in the source tree.
+# We still need to create links to the gen-tree to make bazel happy.
+function CheckAndLinkToSrcTree {
+  ERROR_FUNC=$1
+  FILE=$2
+  if test ! -e $FILE; then
+    $ERROR_FUNC "$PWD/$FILE cannot be found"
+  fi
+
+  # Link the output file to the source tree, avoiding self links if they are
+  # the same. This could happen if invoked from the source tree by accident.
+  if [ ! $($READLINK_CMD -f $PWD) == $($READLINK_CMD -f $OUTPUTDIR/third_party/sycl) ]; then
+    mkdir -p $(dirname $OUTPUTDIR/third_party/sycl/$FILE)
+    ln -sf $PWD/$FILE $OUTPUTDIR/third_party/sycl/$FILE
+  fi
+}
+
+OSNAME=`uname -s`
+if [ "$OSNAME" == "Linux" ]; then
+  SYCL_LIB_PATH="lib"
+  SYCL_RT_LIB_PATH="lib/libComputeCpp.so"
+  SYCL_RT_LIB_STATIC_PATH="lib/libComputeCpp.a"
+  READLINK_CMD="readlink"
+fi
+
+if [ "$CHECK_ONLY" == "1" ]; then
+  CheckAndLinkToSrcTree SYCLError include/SYCL/sycl.h
+  CheckAndLinkToSrcTree SYCLError $SYCL_RT_LIB_STATIC_PATH
+  CheckAndLinkToSrcTree SYCLError $SYCL_RT_LIB_PATH
+  exit 0
+fi
+
+# Actually configure the source tree for TensorFlow's canonical view of SYCL
+# libraries.
+
+if test ! -e ${COMPUTECPP_PATH}/${SYCL_RT_LIB_PATH}; then
+  SYCLError "cannot find ${COMPUTECPP_PATH}/${SYCL_RT_LIB_PATH}"
+fi
+
+# Helper function to build symbolic links for all files in a directory.
+function LinkOneDir {
+  SRC_PREFIX=$1
+  DST_PREFIX=$2
+  SRC_DIR=$3
+  DST_DIR=$(echo $SRC_DIR | sed "s,^$SRC_PREFIX,$DST_PREFIX,")
+  mkdir -p $DST_DIR
+  FILE_LIST=$(find -L $SRC_DIR -maxdepth 1 -type f)
+  if test "$FILE_LIST" != ""; then
+    ln -sf $FILE_LIST $DST_DIR/ || exit -1
+  fi
+}
+export -f LinkOneDir
+
+# Build links for all files in the directory, including subdirectories.
+function LinkAllFiles {
+  SRC_DIR=$1
+  DST_DIR=$2
+  find -L $SRC_DIR -type d | xargs -I {} bash -c "LinkOneDir $SRC_DIR $DST_DIR {}" || exit -1
+}
+
+# Set up the symbolic links for the SYCL toolkit. We link at the individual
+# file level, not at the directory level, because the external library may
+# have a different file layout from our desired structure.
+mkdir -p $OUTPUTDIR/third_party/sycl
+echo "Setting up SYCL include"
+LinkAllFiles ${COMPUTECPP_PATH}/include $OUTPUTDIR/third_party/sycl/include || exit -1
+echo "Setting up SYCL ${SYCL_LIB_PATH}"
+LinkAllFiles ${COMPUTECPP_PATH}/${SYCL_LIB_PATH} $OUTPUTDIR/third_party/sycl/${SYCL_LIB_PATH} || exit -1
+echo "Setting up SYCL bin"
+LinkAllFiles ${COMPUTECPP_PATH}/bin $OUTPUTDIR/third_party/sycl/bin || exit -1
diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template
index 58dd7434a89..bdbc88ba395 100644
--- a/tools/bazel.rc.template
+++ b/tools/bazel.rc.template
@@ -1,6 +1,9 @@
 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
 build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
 
+build:sycl --crosstool_top=//third_party/sycl/crosstool:toolchain
+build:sycl --define=using_sycl=true
+
 build --force_python=py$PYTHON_MAJOR_VERSION
 build --host_force_python=py$PYTHON_MAJOR_VERSION
 build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
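With the bazel.rc entries above in place, an OpenCL build is selected per
invocation, for example (the build target is illustrative):

    bazel build --config=sycl //tensorflow/cc:tutorials_example_trainer

--config=sycl both switches the crosstool to the computecpp wrapper toolchain
and sets --define=using_sycl=true, which flips every if_sycl() select in the
BUILD files.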