Delete SYCL support

See discussion here:
https://github.com/tensorflow/tensorflow/issues/41809#issuecomment-688021592

Fixes #41809.

PiperOrigin-RevId: 331808169
Change-Id: Ib0861cf250c92c20f0e8a22adce89a4dc4d3548a
Authored by Sanjoy Das on 2020-09-15 11:02:26 -07:00; committed by TensorFlower Gardener
parent 1009006e3e
commit 3cbb507689
213 changed files with 29 additions and 7248 deletions


@@ -46,10 +46,6 @@
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
# sycl: Build with SYCL support.
# sycl_nodouble:
# sycl_asan:
# sycl_trisycl:
# mkl: Enable full mkl support.
# tensorrt: Enable Tensorrt support.
# ngraph: Enable ngraph support.
@@ -214,19 +210,6 @@ build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --action_env TF_NEED_ROCM=1
build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl --define=using_sycl=true
build:sycl --action_env TF_NEED_OPENCL_SYCL=1
build:sycl_nodouble --config=sycl
build:sycl_nodouble --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE
build:sycl_nodouble --config=sycl
build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address
build:sycl_nodouble --config=sycl
build:sycl_trisycl --define=using_trisycl=true
# Options extracted from configure script
build:ngraph --define=with_ngraph_support=true
build:numa --define=with_numa_support=true


@@ -38,9 +38,6 @@ _DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_TENSORRT_VERSION = '6'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
@@ -1114,62 +1111,6 @@ def set_host_c_compiler(environ_cp):
write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler)
def set_computecpp_toolkit_path(environ_cp):
"""Set COMPUTECPP_TOOLKIT_PATH."""
def toolkit_exists(toolkit_path):
"""Check if a computecpp toolkit path is valid."""
if is_linux():
sycl_rt_lib_path = 'lib/libComputeCpp.so'
else:
sycl_rt_lib_path = ''
sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path)
exists = os.path.exists(sycl_rt_lib_path_full)
if not exists:
print('Invalid SYCL %s library path. %s cannot be found' %
(_TF_OPENCL_VERSION, sycl_rt_lib_path_full))
return exists
computecpp_toolkit_path = prompt_loop_or_load_from_env(
environ_cp,
var_name='COMPUTECPP_TOOLKIT_PATH',
var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH,
ask_for_var=(
'Please specify the location where ComputeCpp for SYCL %s is '
'installed.' % _TF_OPENCL_VERSION),
check_success=toolkit_exists,
error_msg='Invalid SYCL compiler path. %s cannot be found.',
suppress_default_error=True)
write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
computecpp_toolkit_path)
def set_trisycl_include_dir(environ_cp):
"""Set TRISYCL_INCLUDE_DIR."""
ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
'include directory. (Use --config=sycl_trisycl '
'when building with Bazel) '
'[Default is %s]: ') % (
_DEFAULT_TRISYCL_INCLUDE_DIR)
while True:
trisycl_include_dir = get_from_env_or_user_or_default(
environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
_DEFAULT_TRISYCL_INCLUDE_DIR)
if os.path.exists(trisycl_include_dir):
break
print('Invalid triSYCL include directory, %s cannot be found' %
(trisycl_include_dir))
# Set TRISYCL_INCLUDE_DIR
environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def system_specific_test_config(environ_cp):
"""Add default build and test flags required for TF tests to bazelrc."""
write_to_bazelrc('test --flaky_test_attempts=3')
@@ -1397,8 +1338,6 @@ def main():
setup_python(environ_cp)
if is_windows():
environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
environ_cp['TF_CUDA_CLANG'] = '0'
environ_cp['TF_NEED_TENSORRT'] = '0'
@@ -1415,21 +1354,6 @@
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
write_to_bazelrc('build --config=xla')
set_action_env_var(
environ_cp,
'TF_NEED_OPENCL_SYCL',
'OpenCL SYCL',
False,
bazel_config_name='sycl')
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
set_host_cxx_compiler(environ_cp)
set_host_c_compiler(environ_cp)
set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True)
if environ_cp.get('TF_NEED_COMPUTECPP') == '1':
set_computecpp_toolkit_path(environ_cp)
else:
set_trisycl_include_dir(environ_cp)
set_action_env_var(
environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm')
if (environ_cp.get('TF_NEED_ROCM') == '1' and
@@ -1528,17 +1452,15 @@ def main():
# use it for the CPU build.
set_tf_download_clang(environ_cp)
# SYCL / ROCm / CUDA are mutually exclusive.
# ROCm / CUDA are mutually exclusive.
# At most 1 GPU platform can be configured.
gpu_platform_count = 0
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
gpu_platform_count += 1
if environ_cp.get('TF_NEED_ROCM') == '1':
gpu_platform_count += 1
if environ_cp.get('TF_NEED_CUDA') == '1':
gpu_platform_count += 1
if gpu_platform_count >= 2:
raise UserInputError('SYCL / CUDA / ROCm are mutually exclusive. '
raise UserInputError('CUDA / ROCm are mutually exclusive. '
'At most 1 GPU platform can be configured.')
set_cc_opt_flags(environ_cp)


@@ -88,7 +88,6 @@ cc_library(
deps = [
":core_cpu",
"//tensorflow/core/common_runtime/gpu:gpu_runtime",
"//tensorflow/core/common_runtime/sycl:sycl_runtime",
] + if_tpu(["//tensorflow/core/tpu:tpu_runtime"]),
)


@@ -123,7 +123,6 @@ class Registrar {
//
// The default priority values for built-in devices is:
// GPU: 210
// SYCL: 200
// GPUCompatibleCPU: 70
// ThreadPoolDevice: 60
// Default: 50
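
The SYCL entry in this priority table is removed by this commit; the claim on priority 200 came from the REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200) call deleted later in this diff. As a hedged sketch of how a backend claims a slot in the table, using a made-up "MYDEV" device rather than any real TensorFlow backend:

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Hypothetical factory, modeled on the SYCLDeviceFactory that this commit
// deletes; "MYDEV" is an illustrative device type, not a real one.
class MyDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override {
    return Status::OK();
  }
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override {
    return Status::OK();  // a real factory would append Device instances here
  }
};

// The third argument is the priority from the table above; 200 placed SYCL
// below GPU (210) and above GPUCompatibleCPU (70).
REGISTER_LOCAL_DEVICE_FACTORY("MYDEV", MyDeviceFactory, 200);

}  // namespace tensorflow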


@@ -1965,7 +1965,6 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib,
->set_constant_folding(RewriterConfig::OFF);
(*options.config.mutable_device_count())["CPU"] = 2;
(*options.config.mutable_device_count())["GPU"] = 0;
(*options.config.mutable_device_count())["SYCL"] = 0;
auto* p = options.config.add_session_inter_op_thread_pool();
if (use_global_pools) p->set_global_name("large pool");


@@ -175,16 +175,10 @@ static void TestHWAccelerator(bool enableHWTrace) {
test::FillValues<float>(&x_tensor, {1, 1});
Node* x = test::graph::Constant(&graph, x_tensor);
x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
#ifdef TENSORFLOW_USE_SYCL
x->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
#endif // TENSORFLOW_USE_SYCL
// y = A * x
Node* y = test::graph::Matmul(&graph, a, x, false, false);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
#ifdef TENSORFLOW_USE_SYCL
y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:SYCL:0");
#endif // TENSORFLOW_USE_SYCL
Node* y_neg = test::graph::Unary(&graph, "Neg", y);
y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
@@ -195,9 +189,6 @@ static void TestHWAccelerator(bool enableHWTrace) {
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 1;
(*options.config.mutable_device_count())["GPU"] = 1;
#ifdef TENSORFLOW_USE_SYCL
(*options.config.mutable_device_count())["SYCL"] = 1;
#endif // TENSORFLOW_USE_SYCL
options.config.set_allow_soft_placement(true);
options.config.mutable_graph_options()->set_build_cost_model(1);
std::unique_ptr<Session> session(NewSession(options));


@@ -48,13 +48,12 @@ struct EndpointEq {
static Status ProcessMemoryTypes(
const DeviceType& device_type, const Graph* g,
const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL) {
// On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are always
// compatible.
if (device_type != DEVICE_GPU) {
// On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible.
return Status::OK();
}
// For GPU and SYCL device, HOST_MEMORY and DEVICE_MEMORY is not
// compatible. I.e., a conversion/transfer must be done.
// For GPU, HOST_MEMORY and DEVICE_MEMORY is not compatible. I.e., a
// conversion/transfer must be done.
//
// {node id, slot id} -> memory type.
typedef std::unordered_map<Endpoint, MemoryType, EndpointHash, EndpointEq>
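
The compatibility rule this change narrows to GPU-only fits in a few lines. A hypothetical helper, not part of the file, capturing the invariant the comments above describe (assumes the file's own headers for DeviceType and MemoryType):

// Illustration only: off-GPU, HOST_MEMORY and DEVICE_MEMORY are
// interchangeable; on GPU a mismatch means a conversion/transfer (e.g. an
// inserted _HostSend/_HostRecv pair) is required.
inline bool NeedsTransfer(const DeviceType& device_type, MemoryType src_mt,
                          MemoryType dst_mt) {
  if (device_type != DEVICE_GPU) return false;  // always compatible
  return src_mt != dst_mt;
}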


@@ -34,9 +34,6 @@ TEST(MemoryTypeChecker, Int32OK) {
// There is a kernel for adding two int32s on host memory.
TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
#endif // TENSORFLOW_USE_SYCL
delete g;
}
@@ -56,15 +53,6 @@ TEST(MemoryTypeChecker, Int32NotOk) {
TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_GPU, "/device:GPU:0", g));
TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_GPU, g));
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
// There is no kernel for casting int32/host memory to float/device
// memory.
EXPECT_TRUE(errors::IsInternal(ValidateMemoryTypes(DEVICE_SYCL, g)));
// But we can insert _HostSend/_HostRecv to ensure the invariant.
TF_EXPECT_OK(EnsureMemoryTypes(DEVICE_SYCL, "/device:SYCL:0", g));
TF_EXPECT_OK(ValidateMemoryTypes(DEVICE_SYCL, g));
#endif // TENSORFLOW_USE_SYCL
delete g;
}
@@ -86,12 +74,6 @@ TEST(MemoryTypeChecker, MemoryTypeForOutput) {
// int Switch's output on GPU has HOST_MEMORY constraint.
EXPECT_EQ(memory_type, HOST_MEMORY);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
auto si = test::graph::Switch(g, test::graph::Constant(g, vi), pred);
TF_EXPECT_OK(MemoryTypeForOutput(DEVICE_SYCL, g, si, 0, &memory_type));
// int Switch's output on GPU has HOST_MEMORY constraint.
EXPECT_EQ(memory_type, HOST_MEMORY);
#endif // TENSORFLOW_USE_SYCL
delete g;
}


@@ -91,11 +91,6 @@ class RenamedDevice : public Device {
return underlying_device_->has_eigen_cpu_device();
}
#ifdef TENSORFLOW_USE_SYCL
const Eigen::SyclDevice* eigen_sycl_device() const override {
return underlying_device_->eigen_sycl_device();
}
#endif
PerOpGpuDevice* MakeGpuDevice() override {
return underlying_device_->MakeGpuDevice();


@@ -1,46 +0,0 @@
load(
"//tensorflow:tensorflow.bzl",
"if_not_windows",
"tf_copts",
)
load(
"//tensorflow/core/platform:rules_cc.bzl",
"cc_library",
)
package(
default_visibility = [
"//tensorflow:internal",
],
features = ["-parse_headers"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "sycl_runtime",
srcs = if_not_windows([
"sycl_allocator.cc",
"sycl_device.cc",
"sycl_device_context.cc",
"sycl_device_factory.cc",
]),
hdrs = if_not_windows([
"sycl_allocator.h",
"sycl_device.h",
"sycl_util.h",
"sycl_device_context.h",
]),
copts = tf_copts(),
linkstatic = 0,
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/common_runtime:core_cpu",
"//tensorflow/core/common_runtime:core_cpu_internal",
"//third_party/eigen3",
"@local_config_sycl//sycl",
],
alwayslink = 0,
)


@@ -1,92 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_allocator.h"
namespace tensorflow {
SYCLAllocator::SYCLAllocator(Eigen::QueueInterface* queue)
: sycl_device_(new Eigen::SyclDevice(queue)) {
cl::sycl::queue& sycl_queue = sycl_device_->sycl_queue();
const cl::sycl::device& device = sycl_queue.get_device();
stats_.bytes_limit =
device.get_info<cl::sycl::info::device::max_mem_alloc_size>();
}
SYCLAllocator::~SYCLAllocator() {
if (sycl_device_) {
delete sycl_device_;
}
}
string SYCLAllocator::Name() { return "device:SYCL"; }
void* SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
mutex_lock lock(mu_);
assert(sycl_device_);
if (num_bytes == 0) {
// Cannot allocate no bytes in SYCL, so instead allocate a single byte
num_bytes = 1;
}
auto p = sycl_device_->allocate(num_bytes);
const auto& allocated_buffer = sycl_device_->get_sycl_buffer(p);
const std::size_t bytes_allocated = allocated_buffer.get_range().size();
++stats_.num_allocs;
stats_.bytes_in_use += bytes_allocated;
stats_.max_bytes_in_use =
std::max<int64>(stats_.max_bytes_in_use, stats_.bytes_in_use);
stats_.max_alloc_size =
std::max<int64>(stats_.max_alloc_size, bytes_allocated);
return p;
}
void SYCLAllocator::DeallocateRaw(void* ptr) {
mutex_lock lock(mu_);
if (sycl_device_) {
const auto& buffer_to_delete = sycl_device_->get_sycl_buffer(ptr);
const std::size_t dealloc_size = buffer_to_delete.get_range().size();
stats_.bytes_in_use -= dealloc_size;
sycl_device_->deallocate(ptr);
}
}
void SYCLAllocator::GetStats(AllocatorStats* stats) {
mutex_lock lock(mu_);
*stats = stats_;
}
void SYCLAllocator::ClearStats() override {
mutex_lock l(mu_);
stats_.num_allocs = 0;
stats_.max_bytes_in_use = stats_.bytes_in_use;
stats_.max_alloc_size = 0;
}
size_t SYCLAllocator::RequestedSize(const void* ptr) const {
mutex_lock lock(mu_);
if (!sycl_device_) {
return 0;
}
const auto& buffer = sycl_device_->get_sycl_buffer(ptr);
return buffer.get_size();
}
} // namespace tensorflow
#endif // TENSORFLOW_USE_SYCL
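
The deleted allocator leans on get_sycl_buffer() for size lookup, but its stats bookkeeping is independent of SYCL. A standalone sketch of that pattern, using a hypothetical class with an explicit size map in place of the SYCL buffer query:

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <mutex>
#include <unordered_map>

// Hypothetical class (not the TF Allocator interface): every allocation and
// deallocation updates stats under a mutex, as AllocateRaw/DeallocateRaw do
// above.
class TrackingAllocator {
 public:
  void* Allocate(size_t num_bytes) {
    std::lock_guard<std::mutex> lock(mu_);
    if (num_bytes == 0) num_bytes = 1;  // same zero-byte workaround as above
    void* p = std::malloc(num_bytes);
    sizes_[p] = num_bytes;
    ++num_allocs_;
    bytes_in_use_ += static_cast<int64_t>(num_bytes);
    max_bytes_in_use_ = std::max(max_bytes_in_use_, bytes_in_use_);
    return p;
  }
  void Deallocate(void* p) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = sizes_.find(p);
    if (it == sizes_.end()) return;
    bytes_in_use_ -= static_cast<int64_t>(it->second);  // mirrors DeallocateRaw
    sizes_.erase(it);
    std::free(p);
  }

 private:
  std::mutex mu_;
  std::unordered_map<void*, size_t> sizes_;  // SYCL buffers track this natively
  int64_t num_allocs_ = 0;
  int64_t bytes_in_use_ = 0;
  int64_t max_bytes_in_use_ = 0;
};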


@@ -1,75 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if !TENSORFLOW_USE_SYCL
#error This file must only be included when building TensorFlow with SYCL support
#endif
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class SYCLAllocator : public Allocator {
public:
SYCLAllocator(Eigen::QueueInterface* queue);
~SYCLAllocator() override;
string Name() override;
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
bool ShouldAllocateEmptyTensors() const final { return true; }
void Synchronize() {
mutex_lock lock(mu_);
if (sycl_device_) {
sycl_device_->synchronize();
}
}
bool Ok() const { return sycl_device_ && sycl_device_->ok(); }
void GetStats(AllocatorStats* stats) override;
void ClearStats() override;
// The SYCL buffers keep track of their size, so we already have tracking.
bool TracksAllocationSizes() const override { return true; }
// Get the size of the corresponding SYCL buffer.
// Implementing this also provides an implementation of
// AllocatedSize(void* ptr) by default.
size_t RequestedSize(const void* ptr) const override;
Eigen::SyclDevice* getSyclDevice() { return sycl_device_; }
// Clear the SYCL device used by the Allocator
void ClearSYCLDevice() {
mutex_lock lock(mu_);
if (sycl_device_) {
delete sycl_device_;
sycl_device_ = nullptr;
}
}
private:
mutable mutex mu_;
Eigen::SyclDevice* sycl_device_ TF_GUARDED_BY(mu_); // owned
AllocatorStats stats_ TF_GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(SYCLAllocator);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_


@@ -1,94 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_device.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
SYCLDevice::~SYCLDevice() {}
void SYCLDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
assert(context);
// When ThreadScape profiling is off (which is the default), constructing the
// following code is simple enough that its overhead is negligible.
tracing::ScopedRegion region(tracing::EventCategory::kCompute,
op_kernel->name());
op_kernel->Compute(context);
}
Allocator* SYCLDevice::GetAllocator(AllocatorAttributes attr) {
if (attr.on_host())
return cpu_allocator_;
else
return sycl_allocator_;
}
Status SYCLDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) {
AllocatorAttributes attr;
attr.set_on_host(true);
Allocator* host_alloc = GetAllocator(attr);
Tensor parsed(tensor_proto.dtype());
if (!parsed.FromProto(host_alloc, tensor_proto)) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
tensor_proto.DebugString());
}
Status status;
if (alloc_attrs.on_host()) {
*tensor = parsed;
} else {
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
// If the tensor is not initialized, we likely ran out of memory.
if (!copy.IsInitialized()) {
return errors::ResourceExhausted(
"OOM when allocating tensor of shape ", parsed.shape().DebugString(),
" and type ", DataTypeString(parsed.dtype()));
}
device_context_->CopyCPUTensorToDevice(
&parsed, this, &copy, [&status](const Status& s) { status = s; });
*tensor = copy;
}
return status;
}
Status SYCLDevice::TryGetDeviceContext(DeviceContext** out_context) {
device_context_->Ref();
*out_context = device_context_;
return Status::OK();
}
Status SYCLDevice::Sync() {
sycl_allocator_->Synchronize();
if (sycl_allocator_->Ok()) {
return Status::OK();
} else {
return errors::Internal("Unknown error detected on device ", name());
}
}
} // namespace tensorflow
#endif // TENSORFLOW_USE_SYCL


@@ -1,231 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if !TENSORFLOW_USE_SYCL
#error This file must only be included when building TensorFlow with SYCL support
#endif
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/sycl/sycl_allocator.h"
#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class GSYCLInterface {
std::vector<Eigen::QueueInterface*> m_queue_interface_; // owned
std::vector<Allocator*> m_cpu_allocator_; // not owned
std::vector<SYCLAllocator*> m_sycl_allocator_; // owned
std::vector<SYCLDeviceContext*> m_sycl_context_; // ref counted
GSYCLInterface() {
bool found_device = false;
auto device_list = Eigen::get_sycl_supported_devices();
// Obtain list of supported devices from Eigen
for (const auto& device : device_list) {
if (device.is_gpu()) {
// returns first found GPU
AddDevice(device);
found_device = true;
}
}
if (!found_device) {
// Currently Intel GPU is not supported
LOG(WARNING) << "No OpenCL GPU found that is supported by "
<< "ComputeCpp/triSYCL, trying OpenCL CPU";
}
for (const auto& device : device_list) {
if (device.is_cpu()) {
// returns first found CPU
AddDevice(device);
found_device = true;
}
}
if (!found_device) {
LOG(WARNING) << "No OpenCL CPU found that is supported by "
<< "ComputeCpp/triSYCL, checking for host sycl device";
}
for (const auto& device : device_list) {
// triSYCL only supports the host device for now
if (device.is_host()) {
LOG(WARNING) << "Found SYCL host device";
AddDevice(device);
found_device = true;
}
}
if (!found_device) {
// Currently Intel GPU is not supported
LOG(FATAL) << "No SYCL host and no OpenCL GPU nor CPU"
<< " supported by ComputeCPP/triSYCL was found";
} else {
LOG(INFO) << "Found following OpenCL devices:";
for (int i = 0; i < device_list.size(); i++) {
LOG(INFO) << GetShortDeviceDescription(i);
}
}
}
~GSYCLInterface() {
m_cpu_allocator_.clear();
for (auto p : m_sycl_allocator_) {
p->Synchronize();
p->ClearSYCLDevice();
// Cannot delete the Allocator instances, as the Allocator lifetime
// needs to exceed any Tensor created by it. There is no way of
// knowing when all Tensors have been deallocated, as they are
// RefCounted and wait until all instances of a Tensor have been
// destroyed before calling Allocator.Deallocate. This could happen at
// program exit, which can set up a race condition between destroying
// Tensors and Allocators when the program is cleaning up.
}
m_sycl_allocator_.clear();
for (auto p : m_sycl_context_) {
p->Unref();
}
m_sycl_context_.clear();
for (auto p : m_queue_interface_) {
p->deallocate_all();
delete p;
}
m_queue_interface_.clear();
}
void AddDevice(const cl::sycl::device& d) {
m_queue_interface_.push_back(new Eigen::QueueInterface(d));
m_cpu_allocator_.push_back(cpu_allocator());
m_sycl_allocator_.push_back(new SYCLAllocator(m_queue_interface_.back()));
m_sycl_context_.push_back(new SYCLDeviceContext());
}
public:
static const GSYCLInterface* instance() {
// c++11 guarantees that this will be constructed in a thread safe way
static const GSYCLInterface instance;
return &instance;
}
Eigen::QueueInterface* GetQueueInterface(size_t i = 0) const {
if (!m_queue_interface_.empty()) {
return m_queue_interface_[i];
} else {
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
SYCLAllocator* GetSYCLAllocator(size_t i = 0) const {
if (!m_sycl_allocator_.empty()) {
return m_sycl_allocator_[i];
} else {
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
Allocator* GetCPUAllocator(size_t i = 0) const {
if (!m_cpu_allocator_.empty()) {
return m_cpu_allocator_[i];
} else {
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
SYCLDeviceContext* GetSYCLContext(size_t i = 0) const {
if (!m_sycl_context_.empty()) {
return m_sycl_context_[i];
} else {
std::cerr << "No cl::sycl::device has been added" << std::endl;
return nullptr;
}
}
string GetShortDeviceDescription(int device_id = 0) const {
Eigen::QueueInterface* queue_ptr = GetQueueInterface(device_id);
if (!queue_ptr) {
LOG(ERROR)
<< "Device name cannot be given after Eigen QueueInterface destroyed";
return "";
}
auto device = queue_ptr->sycl_queue().get_device();
auto name = device.get_info<cl::sycl::info::device::name>();
auto vendor = device.get_info<cl::sycl::info::device::vendor>();
auto profile = device.get_info<cl::sycl::info::device::profile>();
std::string type;
if (device.is_host()) {
type = "Host";
} else if (device.is_cpu()) {
type = "CPU";
} else if (device.is_gpu()) {
type = "GPU";
} else if (device.is_accelerator()) {
type = "Accelerator";
} else {
type = "Unknown";
}
return strings::StrCat(
"id: ", device_id, ", type: ", type, ", name: ", name.c_str(),
", vendor: ", vendor.c_str(), ", profile: ", profile.c_str());
}
};
class SYCLDevice : public LocalDevice {
public:
SYCLDevice(const SessionOptions& options, const string& name,
Bytes memory_limit, const DeviceLocality& locality,
const string& physical_device_desc, SYCLAllocator* sycl_allocator,
Allocator* cpu_allocator, SYCLDeviceContext* ctx)
: LocalDevice(options, Device::BuildDeviceAttributes(
name, DEVICE_SYCL, memory_limit, locality,
physical_device_desc)),
cpu_allocator_(cpu_allocator),
sycl_allocator_(sycl_allocator),
device_context_(ctx) {
set_eigen_sycl_device(sycl_allocator->getSyclDevice());
}
~SYCLDevice() override;
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override;
Status TryGetDeviceContext(DeviceContext** out_context) override;
Status Sync() override;
private:
Allocator* cpu_allocator_; // not owned
SYCLAllocator* sycl_allocator_; // not owned
SYCLDeviceContext* device_context_; // not owned
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
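
GSYCLInterface::instance() above relies on the C++11 "magic static" rule its comment cites. A self-contained illustration of that singleton pattern, detached from any TensorFlow types:

#include <iostream>

// Standalone sketch of the pattern used by GSYCLInterface::instance():
// a function-local static is initialized exactly once, and C++11 guarantees
// that initialization is thread safe even under concurrent first calls.
class Registry {
 public:
  static const Registry* instance() {
    static const Registry instance;  // constructed once, lazily
    return &instance;
  }

 private:
  Registry() { std::cout << "constructed exactly once\n"; }
};

int main() {
  const Registry* a = Registry::instance();
  const Registry* b = Registry::instance();
  std::cout << (a == b) << "\n";  // prints 1: both calls yield one object
}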


@@ -1,181 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if TENSORFLOW_USE_SYCL
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
namespace tensorflow {
void SYCLDeviceContext::CopyCPUTensorToDevice(const Tensor *cpu_tensor,
Device *device,
Tensor *device_tensor,
StatusCallback done) const {
const int64 total_bytes = cpu_tensor->TotalBytes();
if (total_bytes > 0) {
const void *src_ptr = DMAHelper::base(cpu_tensor);
void *dst_ptr = DMAHelper::base(device_tensor);
switch (cpu_tensor->dtype()) {
case DT_FLOAT:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
total_bytes);
break;
case DT_DOUBLE:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<double *>(dst_ptr),
static_cast<const double *>(src_ptr), total_bytes);
break;
case DT_INT32:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
total_bytes);
break;
case DT_INT64:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
total_bytes);
break;
case DT_HALF:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<Eigen::half *>(dst_ptr),
static_cast<const Eigen::half *>(src_ptr), total_bytes);
break;
case DT_COMPLEX64:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<std::complex<float> *>(dst_ptr),
static_cast<const std::complex<float> *>(src_ptr), total_bytes);
break;
case DT_COMPLEX128:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<std::complex<double> *>(dst_ptr),
static_cast<const std::complex<double> *>(src_ptr), total_bytes);
break;
case DT_INT8:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
total_bytes);
break;
case DT_INT16:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
total_bytes);
break;
case DT_UINT8:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
total_bytes);
break;
case DT_UINT16:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<uint16 *>(dst_ptr),
static_cast<const uint16 *>(src_ptr), total_bytes);
break;
case DT_BOOL:
device->eigen_sycl_device()->memcpyHostToDevice(
static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
total_bytes);
break;
default:
assert(false && "unsupported type");
}
}
device->eigen_sycl_device()->synchronize();
done(Status::OK());
}
void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor,
StringPiece edge_name,
Device *device,
Tensor *cpu_tensor,
StatusCallback done) {
const int64 total_bytes = device_tensor->TotalBytes();
if (total_bytes > 0) {
const void *src_ptr = DMAHelper::base(device_tensor);
void *dst_ptr = DMAHelper::base(cpu_tensor);
switch (device_tensor->dtype()) {
case DT_FLOAT:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
total_bytes);
break;
case DT_DOUBLE:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<double *>(dst_ptr),
static_cast<const double *>(src_ptr), total_bytes);
break;
case DT_INT32:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
total_bytes);
break;
case DT_INT64:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
total_bytes);
break;
case DT_HALF:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<Eigen::half *>(dst_ptr),
static_cast<const Eigen::half *>(src_ptr), total_bytes);
break;
case DT_COMPLEX64:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<std::complex<float> *>(dst_ptr),
static_cast<const std::complex<float> *>(src_ptr), total_bytes);
break;
case DT_COMPLEX128:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<std::complex<double> *>(dst_ptr),
static_cast<const std::complex<double> *>(src_ptr), total_bytes);
break;
case DT_INT8:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
total_bytes);
break;
case DT_INT16:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
total_bytes);
break;
case DT_UINT8:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
total_bytes);
break;
case DT_UINT16:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<uint16 *>(dst_ptr),
static_cast<const uint16 *>(src_ptr), total_bytes);
break;
case DT_BOOL:
device->eigen_sycl_device()->memcpyDeviceToHost(
static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
total_bytes);
break;
default:
assert(false && "unsupported type");
}
}
device->eigen_sycl_device()->synchronize();
done(Status::OK());
}
} // namespace tensorflow
#endif // TENSORFLOW_USE_SYCL


@@ -1,45 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if !TENSORFLOW_USE_SYCL
#error This file must only be included when building TensorFlow with SYCL support
#endif
#ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_
#define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
namespace tensorflow {
class SYCLDeviceContext : public DeviceContext {
public:
SYCLDeviceContext() {}
~SYCLDeviceContext() override {}
void CopyCPUTensorToDevice(const Tensor *cpu_tensor, Device *device,
Tensor *device_tensor,
StatusCallback done) const override;
void CopyDeviceTensorToCPU(const Tensor *device_tensor, StringPiece edge_name,
Device *device, Tensor *cpu_tensor,
StatusCallback done) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_


@@ -1,57 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/sycl/sycl_device.h"
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
namespace tensorflow {
class SYCLDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override {
return tensorflow::Status::OK();
}
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override {
auto syclInterface = GSYCLInterface::instance();
size_t n = 1;
auto iter = options.config.device_count().find("SYCL");
if (iter != options.config.device_count().end()) {
n = iter->second;
}
for (int i = 0; i < n; i++) {
string name = strings::StrCat(name_prefix, "/device:SYCL:", i);
devices->push_back(new SYCLDevice(
options, name, Bytes(256 << 20), DeviceLocality(),
syclInterface->GetShortDeviceDescription(i),
syclInterface->GetSYCLAllocator(i), syclInterface->GetCPUAllocator(i),
syclInterface->GetSYCLContext(i)));
}
return Status::OK();
}
};
REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200);
} // namespace tensorflow
#endif // TENSORFLOW_USE_SYCL


@@ -1,80 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if !TENSORFLOW_USE_SYCL
#error This file must only be included when building TensorFlow with SYCL support
#endif
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
// For DMA helper
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
inline void const* GetBase(const Tensor* src) { return DMAHelper::base(src); }
inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
inline void SYCLmemcpy(Eigen::SyclDevice const& device,
Tensor const& src_tensor, Tensor* dst_tensor) {
const size_t size = src_tensor.TotalBytes();
void* dst_ptr = GetBase(dst_tensor);
void const* src_ptr = GetBase(&src_tensor);
#define COPY_WITH_TYPE(T) \
device.memcpy(dst_ptr, static_cast<T const*>(src_ptr), size);
switch (src_tensor.dtype()) {
case DT_COMPLEX128:
COPY_WITH_TYPE(cl::sycl::cl_ulong2);
break;
case DT_DOUBLE:
case DT_COMPLEX64:
case DT_INT64:
COPY_WITH_TYPE(cl::sycl::cl_ulong);
break;
case DT_FLOAT:
case DT_INT32:
case DT_QINT32:
COPY_WITH_TYPE(cl::sycl::cl_uint);
break;
case DT_INT16:
case DT_UINT16:
case DT_BFLOAT16:
case DT_QINT16:
case DT_QUINT16:
case DT_HALF:
COPY_WITH_TYPE(cl::sycl::cl_ushort);
break;
case DT_BOOL:
COPY_WITH_TYPE(bool);
break;
case DT_UINT8:
case DT_INT8:
case DT_QINT8:
case DT_QUINT8:
COPY_WITH_TYPE(cl::sycl::cl_uchar);
break;
default:
LOG(FATAL) << "Unknown data type " << src_tensor.dtype();
break;
}
#undef COPY_WITH_TYPE
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_


@@ -283,12 +283,10 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) {
DeleteDumpDir();
} else {
// CUDA and SYCL devices do not have an Identity op for strings
// The CUDA device does not have an Identity op for strings
LOG(ERROR) << "Error: " << s;
ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
(a_dev.device_type() == DEVICE_SYCL) ||
(b_dev.device_type() == DEVICE_GPU) ||
(b_dev.device_type() == DEVICE_SYCL));
(b_dev.device_type() == DEVICE_GPU));
ASSERT_FALSE(s.ok());
}
}


@@ -32,9 +32,6 @@ limitations under the License.
namespace Eigen {
struct ThreadPoolDevice;
#ifdef TENSORFLOW_USE_SYCL
struct SyclDevice;
#endif
} // end namespace Eigen
namespace stream_executor {
@@ -176,9 +173,6 @@ class DeviceBase {
// Does not take ownership.
void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d);
#ifdef TENSORFLOW_USE_SYCL
void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; }
#endif
// Return the Allocator implementation to use based on the allocator
// attributes requested. See allocator.h for more details.
@@ -210,12 +204,6 @@ class DeviceBase {
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device();
#ifdef TENSORFLOW_USE_SYCL
virtual const Eigen::SyclDevice* eigen_sycl_device() const {
CHECK(eigen_sycl_device_ != nullptr);
return eigen_sycl_device_;
}
#endif
// Caller owns the return value. The OpKernelContext calls this even
// for devices that do not implement an eigen_gpu_device. Overridden
@@ -290,9 +278,6 @@ class DeviceBase {
GpuDeviceInfo* gpu_device_info_ = nullptr;
thread::ThreadPool* device_thread_pool_ = nullptr;
std::vector<Eigen::ThreadPoolDevice*> eigen_cpu_devices_;
#ifdef TENSORFLOW_USE_SYCL
Eigen::SyclDevice* eigen_sycl_device_ = nullptr;
#endif
};
// Methods to create and check for Symbolic execution devices.


@@ -114,10 +114,9 @@ OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
context->graph_def_version()));
// Kernels executing on GPU/SYCL tie very few resources on the CPU where the
// Kernels executing on GPU tie very few resources on the CPU where the
// scheduler runs: we consider them as inexpensive.
expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
context->device_type() != DeviceType(DEVICE_SYCL);
expensive_ = context->device_type() != DeviceType(DEVICE_GPU);
}
OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
@@ -141,10 +140,9 @@ OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
OP_REQUIRES_OK(context, CheckOpDeprecation(*props_->op_def,
context->graph_def_version()));
// Kernels executing on GPU/SYCL tie very few resources on the CPU where the
// Kernels executing on GPU tie very few resources on the CPU where the
// scheduler runs: we consider them as inexpensive.
expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
context->device_type() != DeviceType(DEVICE_SYCL);
expensive_ = context->device_type() != DeviceType(DEVICE_GPU);
}
OpKernel::~OpKernel() {}
@@ -1722,12 +1720,6 @@ const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
return eigen_gpu_device();
}
#ifdef TENSORFLOW_USE_SYCL
template <>
const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
return eigen_sycl_device();
}
#endif
void OpKernelConstruction::CtxFailure(const Status& s) {
VLOG(1) << s;


@@ -58,7 +58,6 @@ limitations under the License.
namespace Eigen {
struct ThreadPoolDevice;
struct GpuDevice;
struct SyclDevice;
} // end namespace Eigen
namespace tensorflow {
@@ -1149,11 +1148,6 @@ class OpKernelContext {
const Eigen::GpuDevice& eigen_gpu_device() const {
return params_->eigen_gpu_device->device();
}
#ifdef TENSORFLOW_USE_SYCL
const Eigen::SyclDevice& eigen_sycl_device() const {
return *device()->eigen_sycl_device();
}
#endif
template <typename EigenDeviceType>
const EigenDeviceType& eigen_device() const;
@@ -1336,10 +1330,6 @@ const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const;
template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const;
#ifdef TENSORFLOW_USE_SYCL
template <>
const Eigen::SyclDevice& OpKernelContext::eigen_device() const;
#endif
// Register your OpKernel by specifying the Op's name, the device the
// kernel runs on, any type attr constraints for this kernel, any
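
The eigen_device<EigenDeviceType>() specializations above (CPU and GPU remain; the SYCL one is deleted) let device-templated kernels fetch the matching Eigen device generically. A hypothetical kernel sketching that dispatch, not an op that exists in TensorFlow:

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical kernel for illustration: templated on Device, it resolves the
// right Eigen device through the specializations above and hands it to an
// Eigen expression via .device().
template <typename Device, typename T>
class DoubleOp : public OpKernel {
 public:
  explicit DoubleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    const Device& d = ctx->eigen_device<Device>();  // resolved per Device
    output->flat<T>().device(d) = input.flat<T>() * static_cast<T>(2);
  }
};

}  // namespace tensorflow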


@@ -211,16 +211,4 @@ limitations under the License.
TF_CALL_COMPLEX_TYPES(m) \
TF_CALL_QUANTIZED_TYPES(m) TF_CALL_bool(m) TF_CALL_tstring(m)
#ifdef TENSORFLOW_SYCL_NO_DOUBLE
#define TF_CALL_SYCL_double(m)
#else // TENSORFLOW_SYCL_NO_DOUBLE
#define TF_CALL_SYCL_double(m) TF_CALL_double(m)
#endif // TENSORFLOW_SYCL_NO_DOUBLE
#ifdef __ANDROID_TYPES_SLIM__
#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m)
#else // __ANDROID_TYPES_SLIM__
#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) TF_CALL_SYCL_double(m)
#endif // __ANDROID_TYPES_SLIM__
#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
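
The TF_CALL_* macros above are type-list appliers: each invokes its argument macro m once per supported type, which is how the deleted TF_CALL_SYCL_NUMBER_TYPES expanded to float alone or float plus double. A standalone sketch of the pattern with made-up MY_CALL_* macros, illustration only:

// Hypothetical stand-ins for the TF_CALL_* machinery.
template <typename T>
struct Functor {
  T Twice(T x) { return x + x; }
};

#define MY_CALL_float(m) m(float)
#define MY_CALL_double(m) m(double)
#define MY_CALL_REAL_TYPES(m) MY_CALL_float(m) MY_CALL_double(m)

// Applying a macro over the list expands it once per type:
#define INSTANTIATE(T) template struct Functor<T>;
MY_CALL_REAL_TYPES(INSTANTIATE)  // Functor<float>; Functor<double>;
#undef INSTANTIATE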


@@ -21,9 +21,6 @@ limitations under the License.
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
@@ -74,16 +71,6 @@ struct proxy_type_pod<GPUDevice, 1> {
typedef ::tensorflow::int8 type;
};
#ifdef TENSORFLOW_USE_SYCL
template <>
struct proxy_type_pod<SYCLDevice, 8> {
typedef double type;
};
template <>
struct proxy_type_pod<SYCLDevice, 4> {
typedef float type;
};
#endif // TENSORFLOW_USE_SYCL
/// If POD we use proxy_type_pod, otherwise this maps to identity.
template <typename Device, typename T>
@@ -101,10 +88,6 @@ struct proxy_type {
#define TF_CALL_GPU_PROXY_TYPES(m) \
TF_CALL_double(m) TF_CALL_float(m) TF_CALL_half(m) TF_CALL_int32(m) \
TF_CALL_int8(m)
#ifdef TENSORFLOW_USE_SYCL
#define TF_CALL_SYCL_PROXY_TYPES(m) \
TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m)
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
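
The point of proxy_type above is that POD types of equal byte width collapse to one canonical per-device type, so a single kernel instantiation can serve several dtypes. A compile-time sketch of the invariant, assuming the trait's usual definition and placement inside namespace tensorflow:

// Illustration only: width is always preserved, and 4-byte PODs share one
// GPU proxy (compare TF_CALL_GPU_PROXY_TYPES above).
static_assert(sizeof(proxy_type<GPUDevice, float>::type) == 4,
              "a proxy type preserves the width of the original type");
static_assert(sizeof(proxy_type<GPUDevice, int32>::type) == 4,
              "float and int32 both map to the 4-byte GPU proxy");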


@@ -38,7 +38,6 @@ std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
const char* const DEVICE_DEFAULT = "DEFAULT";
const char* const DEVICE_CPU = "CPU";
const char* const DEVICE_GPU = "GPU";
const char* const DEVICE_SYCL = "SYCL";
const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM";
const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
@@ -46,9 +45,6 @@ const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
const std::string DeviceName<Eigen::SyclDevice>::value = DEVICE_SYCL;
#endif // TENSORFLOW_USE_SYCL
namespace {
string DataTypeStringInternal(DataType dtype) {


@@ -74,7 +74,6 @@ std::ostream& operator<<(std::ostream& os, const DeviceType& d);
TF_EXPORT extern const char* const DEVICE_DEFAULT; // "DEFAULT"
TF_EXPORT extern const char* const DEVICE_CPU; // "CPU"
TF_EXPORT extern const char* const DEVICE_GPU; // "GPU"
TF_EXPORT extern const char* const DEVICE_SYCL; // "SYCL"
TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM; // "TPU_SYSTEM"
template <typename Device>
@@ -93,12 +92,6 @@ struct DeviceName<Eigen::GpuDevice> {
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
template <>
struct DeviceName<Eigen::SyclDevice> {
static const std::string value;
};
#endif // TENSORFLOW_USE_SYCL
typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice;


@@ -26,7 +26,6 @@ namespace {
TEST(TypesTest, DeviceTypeName) {
EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU)));
EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU)));
EXPECT_EQ("SYCL", DeviceTypeString(DeviceType(DEVICE_SYCL)));
}
TEST(TypesTest, kDataTypeRefOffset) {


@@ -15,7 +15,6 @@ load(
"tf_kernel_library",
"tf_opts_nortti_if_lite_protos",
)
load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
load("//tensorflow/core/kernels/mlir_generated:build_defs.bzl", "if_mlir_generated_gpu_kernels_enabled")
# buildifier: disable=same-origin-load
@@ -922,7 +921,7 @@ ARRAY_DEPS = [
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
] + if_sycl(["//tensorflow/core/common_runtime/sycl:sycl_runtime"])
]
tf_kernel_library(
name = "immutable_constant_op",
@@ -1240,7 +1239,6 @@ tf_kernel_library(
"tile_functor_cpu_uint64.cc",
"tile_functor_cpu_uint8.cc",
"tile_functor_cpu_variant.cc",
"tile_functor_sycl.cc",
],
hdrs = ["tile_functor.h"],
gpu_srcs = [
@@ -4206,7 +4204,7 @@ tf_kernel_library(
"maxpooling_op.h",
"pooling_ops_3d.h",
"pooling_ops_common.h",
] + if_sycl(["pooling_ops_3d_sycl.h"]),
],
gpu_srcs = [
"avgpooling_op.h",
"avgpooling_op_gpu.cu.cc",
@@ -4872,7 +4870,7 @@ STATE_DEPS = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
] + if_sycl(["//tensorflow/core/common_runtime/sycl:sycl_runtime"])
]
tf_kernel_library(
name = "count_up_to_op",
@@ -6393,7 +6391,6 @@ filegroup(
"unicode_script_op.cc",
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
"xsmm_*",
"cwise_ops_sycl_common.h",
"nextafter_op.cc",
] + ANDROID_TEXTUAL_HDRS,
) + [


@@ -28,9 +28,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
#define REGISTER_ADDN(type, dev) \
REGISTER_KERNEL_BUILDER( \
@@ -67,21 +64,6 @@ REGISTER_KERNEL_BUILDER(
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
REGISTER_ADDN(float, SYCL);
REGISTER_ADDN(double, SYCL);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(
Name("AddN")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.HostMemory("inputs")
.HostMemory("sum"),
AddNOp<CPUDevice, int32, OpKernel, OpKernelConstruction, OpKernelContext>);
#endif // TENSORFLOW_USE_SYCL
#undef REGISTER_ADDN


@@ -23,9 +23,6 @@ limitations under the License.
typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
namespace tensorflow {
@@ -137,114 +134,6 @@ struct Add9Functor<CPUDevice, T> {
}
};
#ifdef TENSORFLOW_USE_SYCL
// Partial specializations for a SYCLDevice, that uses the Eigen implementation
// from AddNEigenImpl.
template <typename T>
struct Add2Functor<SYCLDevice, T> {
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstFlat in1,
typename TTypes<T>::ConstFlat in2) {
Add2EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2);
}
};
template <typename T>
struct Add3Functor<SYCLDevice, T> {
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstFlat in1,
typename TTypes<T>::ConstFlat in2,
typename TTypes<T>::ConstFlat in3) {
Add3EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3);
}
};
template <typename T>
struct Add4Functor<SYCLDevice, T> {
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstFlat in1,
typename TTypes<T>::ConstFlat in2,
typename TTypes<T>::ConstFlat in3,
typename TTypes<T>::ConstFlat in4) {
Add4EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4);
}
};
template <typename T>
struct Add5Functor<SYCLDevice, T> {
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstFlat in1,
typename TTypes<T>::ConstFlat in2,
typename TTypes<T>::ConstFlat in3,
typename TTypes<T>::ConstFlat in4,
typename TTypes<T>::ConstFlat in5) {
Add5EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
}
};
template <typename T>
struct Add6Functor<SYCLDevice, T> {
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstFlat in1,
typename TTypes<T>::ConstFlat in2,
typename TTypes<T>::ConstFlat in3,
typename TTypes<T>::ConstFlat in4,
typename TTypes<T>::ConstFlat in5,
typename TTypes<T>::ConstFlat in6) {
Add6EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
}
};
template <typename T>
struct Add7Functor<SYCLDevice, T> {
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstFlat in1,
typename TTypes<T>::ConstFlat in2,
typename TTypes<T>::ConstFlat in3,
typename TTypes<T>::ConstFlat in4,
typename TTypes<T>::ConstFlat in5,
typename TTypes<T>::ConstFlat in6,
typename TTypes<T>::ConstFlat in7) {
Add7EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
in7);
}
};
template <typename T>
struct Add8Functor<SYCLDevice, T> {
void operator()(
const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
Add8EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
in7, in8);
}
};
template <typename T>
struct Add8pFunctor<SYCLDevice, T> {
void operator()(
const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
Add8pEigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
in7, in8);
}
};
template <typename T>
struct Add9Functor<SYCLDevice, T> {
void operator()(
const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
typename TTypes<T>::ConstFlat in9) {
Add9EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
in7, in8, in9);
}
};
#endif // TENSORFLOW_USE_SYCL
} // namespace functor


@@ -50,9 +50,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
namespace {
@@ -632,48 +629,6 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
template <typename Scalar>
struct ParallelMatMulKernelSYCL {
static void Run(const OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out,
int start, int limit) {
auto Tx = in_x.tensor<Scalar, 3>();
auto Ty = in_y.tensor<Scalar, 3>();
auto Tz = out->tensor<Scalar, 3>();
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
auto d = context->eigen_sycl_device();
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
for (int64 i = start; i < limit; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto x = Tx.template chip<0>(x_batch_index);
auto y = Ty.template chip<0>(y_batch_index);
auto z = Tz.template chip<0>(i);
z.device(d) = x.contract(y, contract_pairs);
}
}
};
template <typename Scalar>
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
// Number of matrix multiplies i.e. size of the batch.
const int64 batch_size = bcast.output_batch_size();
ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
trans_x, trans_y, bcast, out, 0,
batch_size);
}
};
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename Scalar>
class BaseBatchMatMulOp : public OpKernel {
@@ -826,15 +781,6 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Scalar> {
Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMulV2Op<GPUDevice, TYPE>)
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_BATCH_MATMUL_SYCL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BatchMatMulOp<SYCLDevice, TYPE>); \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMulV2").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BatchMatMulV2Op<SYCLDevice, TYPE>)
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
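
The deleted ParallelMatMulKernelSYCL reduced batched matmul to one Eigen contract() call per batch slice. A self-contained example of that contraction, assuming Eigen's unsupported Tensor module is on the include path:

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

// Contracting dimension 1 of x against dimension 0 of y is an ordinary
// matrix multiply -- the same IndexPair the removed kernel built from its
// adjoint/transpose flags.
int main() {
  Eigen::Tensor<float, 2> x(2, 3), y(3, 4);
  x.setRandom();
  y.setRandom();
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {{Eigen::IndexPair<int>(1, 0)}};
  Eigen::Tensor<float, 2> z = x.contract(y, dims);
  std::cout << z.dimension(0) << "x" << z.dimension(1) << "\n";  // 2x4
  return 0;
}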


@@ -34,8 +34,4 @@ TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
TF_CALL_float(REGISTER_BATCH_MATMUL_SYCL);
TF_CALL_double(REGISTER_BATCH_MATMUL_SYCL);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@@ -28,9 +28,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class BatchNormOp : public OpKernel {
@@ -208,17 +205,6 @@ TF_CALL_float(REGISTER_GPU_KERNEL);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
.Device(DEVICE_SYCL) \
.TypeConstraint<T>("T"), \
BatchNormOp<SYCLDevice, T>);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#endif // TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
@@ -267,17 +253,5 @@ TF_CALL_float(REGISTER_GPU_KERNEL);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
.Device(DEVICE_SYCL) \
.TypeConstraint<T>("T"), \
BatchNormGradOp<SYCLDevice, T>);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@@ -29,9 +29,6 @@ namespace concat_split_util {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T. Writes to


@@ -145,22 +145,6 @@ REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
.HostMemory("r0"),
BCastArgsOp<int64>);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.HostMemory("s0")
.HostMemory("s1")
.HostMemory("r0"),
BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
.Device(DEVICE_SYCL)
.TypeConstraint<int64>("T")
.HostMemory("s0")
.HostMemory("s1")
.HostMemory("r0"),
BCastArgsOp<int32>);
#endif
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.Device(DEVICE_CPU)
@ -195,22 +179,4 @@ REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.HostMemory("r1"),
BCastGradArgsOp<int64>);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.HostMemory("s0")
.HostMemory("s1")
.HostMemory("r0")
.HostMemory("r1"),
BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.Device(DEVICE_SYCL)
.TypeConstraint<int64>("T")
.HostMemory("s0")
.HostMemory("s1")
.HostMemory("r0")
.HostMemory("r1"),
BCastGradArgsOp<int64>);
#endif
} // end namespace tensorflow


@ -39,9 +39,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
namespace {
@ -216,20 +213,6 @@ class BiasOp : public BinaryOp<T> {
TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("BiasAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
BiasOp<SYCLDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("BiasAddV1").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
BiasOp<SYCLDevice, type>);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class BiasGradOp : public OpKernel {
@ -308,17 +291,6 @@ class BiasGradOp : public OpKernel {
TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("BiasAddGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
BiasGradOp<SYCLDevice, type>);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>


@ -34,9 +34,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
#define CURRY_TYPES2(FN, arg0) \
FN(arg0, bool); \
@ -253,50 +250,6 @@ REGISTER_CAST_GPU(bfloat16, float);
#undef REGISTER_CAST_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
class SyclCastOp : public CastOpBase {
public:
explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
OP_REQUIRES_OK(ctx, Prepare());
}
private:
Status Prepare() {
if (external_src_dtype_ == external_dst_dtype_) {
work_ = nullptr; // Identity
return Status::OK();
}
if (src_dtype_ == DT_BOOL) {
work_ = GetSyclCastFromBool(dst_dtype_);
} else if (src_dtype_ == DT_INT32) {
work_ = GetSyclCastFromInt32(dst_dtype_);
} else if (src_dtype_ == DT_INT64) {
work_ = GetSyclCastFromInt64(dst_dtype_);
} else if (src_dtype_ == DT_FLOAT) {
work_ = GetSyclCastFromFloat(dst_dtype_);
} else if (src_dtype_ == DT_DOUBLE) {
work_ = GetSyclCastFromDouble(dst_dtype_);
}
return work_ == nullptr ? Unimplemented() : Status::OK();
}
};
#define REGISTER_CAST_SYCL(srctype, dsttype) \
REGISTER_KERNEL_BUILDER(Name("Cast") \
.TypeConstraint<srctype>("SrcT") \
.TypeConstraint<dsttype>("DstT") \
.Device(DEVICE_SYCL), \
SyclCastOp)
CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
CURRY_TYPES2(REGISTER_CAST_SYCL, float);
CURRY_TYPES2(REGISTER_CAST_SYCL, double);
#undef REGISTER_CAST_SYCL
#endif // TENSORFLOW_USE_SYCL
#undef CURRY_TYPES2
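Editor's note: the removed SyclCastOp follows the same shape as the other device cast ops: Prepare() selects a cast functor from the source dtype, and a null functor signals an unimplemented pairing. A standalone sketch of that dispatch idiom (all types and names below are invented for illustration):

#include <cstdint>
#include <functional>

enum class DType { kBool, kInt32, kFloat };

// One loop per (Src, Dst) pairing, selected once at op-construction time.
using CastFn = std::function<void(const void* src, void* dst, int64_t n)>;

template <typename Src, typename Dst>
void CastLoop(const void* src, void* dst, int64_t n) {
  const Src* s = static_cast<const Src*>(src);
  Dst* d = static_cast<Dst*>(dst);
  for (int64_t i = 0; i < n; ++i) d[i] = static_cast<Dst>(s[i]);
}

// Mirrors the GetSyclCastFromFloat() pattern: nullptr = "unimplemented".
CastFn GetCastFromFloat(DType dst) {
  switch (dst) {
    case DType::kBool:  return CastLoop<float, bool>;
    case DType::kInt32: return CastLoop<float, int32_t>;
    default:            return nullptr;
  }
}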


@ -27,9 +27,6 @@ namespace functor {
CAST_FUNCTORS(Eigen::ThreadPoolDevice);
#ifdef TENSORFLOW_USE_SYCL
CAST_FUNCTORS(Eigen::SyclDevice);
#endif // TENSORFLOW_USE_SYCL
} // namespace functor
@ -134,27 +131,6 @@ CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
CastFunctorType GetSyclCastFromBool(DataType dst_dtype);
CastFunctorType GetSyclCastFromUint8(DataType dst_dtype);
CastFunctorType GetSyclCastFromUint16(DataType dst_dtype);
CastFunctorType GetSyclCastFromUint32(DataType dst_dtype);
CastFunctorType GetSyclCastFromUint64(DataType dst_dtype);
CastFunctorType GetSyclCastFromInt16(DataType dst_dtype);
CastFunctorType GetSyclCastFromInt32(DataType dst_dtype);
CastFunctorType GetSyclCastFromInt64(DataType dst_dtype);
CastFunctorType GetSyclCastFromFloat(DataType dst_dtype);
CastFunctorType GetSyclCastFromDouble(DataType dst_dtype);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromBool(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromBool(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, bool);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromDouble(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, double);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -35,12 +35,5 @@ CastFunctorType GetGpuCastFromFloat(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, float);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromInt16(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int16);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromInt32(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int32);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromInt64(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int64);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromInt8(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int8);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromUint16(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint16);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromUint32(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint32);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromUint64(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint64);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,12 +33,5 @@ CastFunctorType GetGpuCastFromUint8(DataType dst_dtype) {
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
CastFunctorType GetSyclCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint8);
return nullptr;
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -138,9 +138,6 @@ static void BM_gpu_float_int64(int iters, int num) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
test::Benchmark("gpu", Cast<float, int64>(num)).Run(iters);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
test::Benchmark("sycl", Cast<float, int64>(num)).Run(iters);
#endif // TENSORFLOW_USE_SYCL
}
BENCHMARK(BM_gpu_float_int64)->Arg(64 << 10)->Arg(32 << 20);
@ -161,9 +158,6 @@ static void BM_gpu_bool_float(int iters, int num) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
test::Benchmark("gpu", Cast<bool, float>(num)).Run(iters);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
test::Benchmark("sycl", Cast<bool, float>(num)).Run(iters);
#endif // TENSORFLOW_USE_SYCL
}
BENCHMARK(BM_gpu_bool_float)->Arg(64 << 10)->Arg(32 << 20);


@ -73,14 +73,6 @@ TF_CALL_GPU_ALL_TYPES(REGISTER);
#undef REGISTER
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
void ConcatSYCL(
const Eigen::SyclDevice& d,
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
inputs,
typename TTypes<T, 2>::Matrix* output);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_


@ -127,24 +127,4 @@ REGISTER(tstring);
// !defined(SUPPORT_SELECTIVE_REGISTRATION) &&
// !defined(__ANDROID_TYPES_FULL__)
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
void ConcatSYCL(
const Eigen::SyclDevice& d,
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
inputs,
typename TTypes<T, 2>::Matrix* output) {
ConcatSYCLImpl<T>(d, inputs, sizeof(T) /* cost_per_unit */, MemCpyCopier<T>(),
output);
}
#define REGISTER_SYCL(T) \
template void ConcatSYCL<T>( \
const Eigen::SyclDevice&, \
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&, \
typename TTypes<T, 2>::Matrix* output);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL)
#undef REGISTER_SYCL
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -130,41 +130,6 @@ void ConcatCPUImpl(
cost_per_unit, work);
}
#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename ElementCopier>
void ConcatSYCLImpl(
const Eigen::SyclDevice& d,
const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
inputs,
int64 cost_per_unit, ElementCopier copier,
typename TTypes<T, 2>::Matrix* output) {
size_t num_inputs = inputs.size();
std::vector<ptrdiff_t> sizes;
sizes.reserve(num_inputs);
int64 row_size = 0;
for (const auto& input : inputs) {
sizes.push_back(input->dimension(1));
row_size += sizes.back();
}
T* out = &(*output)(0, 0);
std::vector<const T*> inp;
inp.reserve(num_inputs);
for (const auto& input : inputs) {
inp.push_back(&(*input)(0, 0));
}
const int64 dim0 = output->dimension(0);
for (int64 i = 0; i < dim0; ++i) {
for (int64 j = 0; j < num_inputs; ++j) {
auto size = sizes[j];
d.memcpy(out, inp[j], size * sizeof(T));
out += size;
inp[j] += size;
}
}
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
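Editor's note: ConcatSYCLImpl above is a row-interleaved copy: for every output row it memcpy's one row from each input, advancing a per-input read cursor. The same logic on plain host memory, as a standalone sketch (function and parameter names invented):

#include <cstddef>
#include <cstring>
#include <vector>

// Concatenate 2-D row-major inputs (all with `rows` rows) along columns.
// cols[j] is the column count of inputs[j]; `out` has sum(cols) columns.
void ConcatRows(const std::vector<const float*>& inputs,
                const std::vector<std::ptrdiff_t>& cols,
                std::ptrdiff_t rows, float* out) {
  std::vector<const float*> in = inputs;  // per-input read cursors
  for (std::ptrdiff_t r = 0; r < rows; ++r) {
    for (std::size_t j = 0; j < in.size(); ++j) {
      std::memcpy(out, in[j], cols[j] * sizeof(float));
      out += cols[j];
      in[j] += cols[j];
    }
  }
}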


@ -35,9 +35,6 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef Eigen::GpuDevice GPUDevice;
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
@ -168,12 +165,6 @@ class ConcatBaseOp : public OpKernel {
return;
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
if (std::is_same<Device, SYCLDevice>::value) {
ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat);
return;
}
#endif // TENSORFLOW_USE_SYCL
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
}
}
@ -251,38 +242,6 @@ REGISTER_KERNEL_BUILDER(Name("ConcatV2")
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type) \
REGISTER_KERNEL_BUILDER(Name("Concat") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.HostMemory("concat_dim"), \
ConcatOp<SYCLDevice, type>) \
REGISTER_KERNEL_BUILDER(Name("ConcatV2") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.HostMemory("axis"), \
ConcatV2Op<SYCLDevice, type>)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
REGISTER_KERNEL_BUILDER(Name("Concat")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.HostMemory("concat_dim")
.HostMemory("values")
.HostMemory("output"),
ConcatOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ConcatV2")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.HostMemory("values")
.HostMemory("axis")
.HostMemory("output"),
ConcatV2Op<CPUDevice, int32>);
#undef REGISTER_SYCL
#endif // TENSORFLOW_USE_SYCL
class ConcatOffsetOp : public OpKernel {
public:
@ -370,12 +329,4 @@ REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
.HostMemory("offset"),
ConcatOffsetOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
.Device(DEVICE_SYCL)
.HostMemory("concat_dim")
.HostMemory("shape")
.HostMemory("offset"),
ConcatOffsetOp);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -39,9 +39,6 @@ limitations under the License.
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/macros.h"
#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif // TENSORFLOW_USE_SYCL
namespace tensorflow {
@ -127,33 +124,9 @@ REGISTER_KERNEL(GPU, Variant);
#undef REGISTER_KERNEL
#endif
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(D, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("Const").Device(DEVICE_##D).TypeConstraint<TYPE>("dtype"), \
ConstantOp);
REGISTER_SYCL_KERNEL(SYCL, float);
REGISTER_SYCL_KERNEL(SYCL, double);
REGISTER_SYCL_KERNEL(SYCL, uint8);
REGISTER_SYCL_KERNEL(SYCL, int8);
REGISTER_SYCL_KERNEL(SYCL, qint8);
REGISTER_SYCL_KERNEL(SYCL, uint16);
REGISTER_SYCL_KERNEL(SYCL, int16);
REGISTER_SYCL_KERNEL(SYCL, qint16);
REGISTER_SYCL_KERNEL(SYCL, quint16);
REGISTER_SYCL_KERNEL(SYCL, uint32);
REGISTER_SYCL_KERNEL(SYCL, qint32);
REGISTER_SYCL_KERNEL(SYCL, int64);
REGISTER_SYCL_KERNEL(SYCL, uint64);
REGISTER_SYCL_KERNEL(SYCL, bool);
#undef REGISTER_SYCL_KERNEL
#endif
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T, typename Index>
class FillOp : public OpKernel {
@ -216,25 +189,6 @@ REGISTER_KERNEL(CPU, qint8);
REGISTER_KERNEL(CPU, qint16);
#undef REGISTER_CPU_KERNEL
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL(SYCL, float);
REGISTER_KERNEL(SYCL, double);
REGISTER_KERNEL(SYCL, uint8);
REGISTER_KERNEL(SYCL, int8);
REGISTER_KERNEL(SYCL, uint16);
REGISTER_KERNEL(SYCL, int16);
REGISTER_KERNEL(SYCL, int64);
REGISTER_KERNEL_BUILDER(Name("Fill")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.TypeConstraint<int32>("index_type")
.HostMemory("dims")
.HostMemory("value")
.HostMemory("output"),
FillOp<CPUDevice, int32, int32>);
#undef REGISTER_KERNEL_SYCL
#endif // TENSORFLOW_USE_SYCL
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
@ -309,17 +263,6 @@ TF_CALL_POD_STRING_TYPES(REGISTER_CPU);
REGISTER_CPU(Variant);
#undef REGISTER_CPU
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL(bool, SYCL);
REGISTER_KERNEL(float, SYCL);
REGISTER_KERNEL(double, SYCL);
REGISTER_KERNEL(int64, SYCL);
REGISTER_KERNEL_BUILDER(Name("ZerosLike")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.HostMemory("y"),
ZerosLikeOp<CPUDevice, int32>);
#endif // TENSORFLOW_USE_SYCL
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
@ -365,15 +308,6 @@ class OnesLikeOp : public OpKernel {
TF_CALL_POD_TYPES(REGISTER_CPU);
#undef REGISTER_CPU
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL(float, SYCL);
REGISTER_KERNEL(bool, SYCL);
REGISTER_KERNEL_BUILDER(Name("OnesLike")
.Device(DEVICE_SYCL)
.TypeConstraint<int32>("T")
.HostMemory("y"),
OnesLikeOp<CPUDevice, int32>);
#endif // TENSORFLOW_USE_SYCL
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)


@ -156,57 +156,6 @@ REGISTER_GPU_HOST_KERNEL(ResourceHandle);
#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \
.Device(DEVICE_SYCL) \
.HostMemory("pred") \
.TypeConstraint<type>("T"), \
SwitchOp)
TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_SWITCH);
#define REGISTER_SYCL_REF_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
.Device(DEVICE_SYCL) \
.HostMemory("pred") \
.TypeConstraint<type>("T"), \
SwitchOp)
TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_SWITCH);
#undef REGISTER_SYCL_SWITCH
#undef REGISTER_SYCL_REF_SWITCH
#define REGISTER_SYCL_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \
.Device(DEVICE_SYCL) \
.HostMemory("data") \
.HostMemory("pred") \
.HostMemory("output_false") \
.HostMemory("output_true") \
.TypeConstraint<type>("T"), \
SwitchOp)
REGISTER_SYCL_HOST_KERNEL(bool);
REGISTER_SYCL_HOST_KERNEL(tstring);
REGISTER_SYCL_HOST_KERNEL(int32);
#define REGISTER_SYCL_HOST_REF_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
.Device(DEVICE_SYCL) \
.HostMemory("data") \
.HostMemory("pred") \
.HostMemory("output_false") \
.HostMemory("output_true") \
.TypeConstraint<type>("T"), \
SwitchOp)
REGISTER_SYCL_HOST_REF_KERNEL(int32);
REGISTER_SYCL_HOST_REF_KERNEL(bool);
REGISTER_SYCL_HOST_REF_KERNEL(tstring);
#undef REGISTER_SYCL_HOST_KERNEL
#undef REGISTER_SYCL_HOST_REF_KERNEL
#endif // TENSORFLOW_USE_SYCL
class RefSelectOp : public OpKernel {
public:
@ -316,28 +265,6 @@ TF_CALL_variant(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Merge") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.HostMemory("value_index"), \
MergeOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
#define REGISTER_SYCL_REF_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("RefMerge") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \
.HostMemory("value_index"), \
MergeOp);
REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#endif // TENSORFLOW_USE_SYCL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@ -364,29 +291,6 @@ REGISTER_GPU_HOST_KERNEL(ResourceHandle);
#undef REGISTER_GPU_HOST_KERNEL
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Merge") \
.Device(DEVICE_SYCL) \
.HostMemory("inputs") \
.HostMemory("output") \
.HostMemory("value_index") \
.TypeConstraint<type>("T"), \
MergeOp); \
REGISTER_KERNEL_BUILDER(Name("RefMerge") \
.Device(DEVICE_SYCL) \
.HostMemory("inputs") \
.HostMemory("output") \
.HostMemory("value_index") \
.TypeConstraint<type>("T"), \
MergeOp)
REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(tstring);
REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
#undef REGISTER_SYCL_HOST_KERNEL
#endif // TENSORFLOW_USE_SYCL
void EnterOp::Compute(OpKernelContext* context) {
if (IsRefType(context->input_dtype(0))) {
@ -416,46 +320,6 @@ TF_CALL_variant(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Enter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
#define REGISTER_SYCL_REF_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("RefEnter").Device(DEVICE_SYCL).TypeConstraint<type>("T"), EnterOp)
REGISTER_SYCL_REF_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL);
#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#define REGISTER_SYCL_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Enter") \
.Device(DEVICE_SYCL) \
.HostMemory("data") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
EnterOp)
#define REGISTER_SYCL_HOST_REF_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("RefEnter") \
.Device(DEVICE_SYCL) \
.HostMemory("data") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
EnterOp)
REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_REF_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(tstring);
REGISTER_SYCL_HOST_REF_KERNEL(tstring);
REGISTER_SYCL_HOST_KERNEL(ResourceHandle);
#undef REGISTER_SYCL_HOST_KERNEL
#undef REGISTER_SYCL_HOST_REF_KERNEL
#endif // TENSORFLOW_USE_SYCL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@ -513,36 +377,6 @@ TF_CALL_variant(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Exit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp); \
REGISTER_KERNEL_BUILDER( \
Name("RefExit").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ExitOp);
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#undef REGISTER_SYCL_REF_KERNEL
#define REGISTER_SYCL_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Exit") \
.Device(DEVICE_SYCL) \
.HostMemory("data") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
ExitOp); \
REGISTER_KERNEL_BUILDER(Name("RefExit") \
.Device(DEVICE_SYCL) \
.HostMemory("data") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
ExitOp)
REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(tstring);
#undef REGISTER_SYCL_HOST_KERNEL
#endif // TENSORFLOW_USE_SYCL
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@ -619,37 +453,6 @@ REGISTER_GPU_HOST_KERNEL(ResourceHandle);
#undef REGISTER_GPU_HOST_KERNEL
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
NextIterationOp); \
REGISTER_KERNEL_BUILDER( \
Name("RefNextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
NextIterationOp)
REGISTER_SYCL_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#define REGISTER_SYCL_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("NextIteration") \
.Device(DEVICE_SYCL) \
.HostMemory("data") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
NextIterationOp); \
REGISTER_KERNEL_BUILDER(Name("RefNextIteration") \
.Device(DEVICE_SYCL) \
.HostMemory("data") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
NextIterationOp)
REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(tstring);
#undef REGISTER_SYCL_HOST_KERNEL
#endif // TENSORFLOW_USE_SYCL
LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
LoopCondOp::~LoopCondOp() = default;


@ -39,13 +39,4 @@ REGISTER_KERNEL_BUILDER(Name("Abs")
#endif
#endif
#if TENSORFLOW_USE_SYCL
REGISTER3(UnaryOp, SYCL, "Abs", functor::abs, float, double, int64);
REGISTER_KERNEL_BUILDER(Name("Abs")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.TypeConstraint<int32>("T"),
UnaryOp<CPUDevice, functor::abs<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -22,7 +22,4 @@ REGISTER2(UnaryOp, CPU, "Acos", functor::acos, float, double);
REGISTER2(UnaryOp, GPU, "Acos", functor::acos, float, double);
#endif
#if TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Acos", functor::acos, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
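Editor's note: the cwise-op hunks in this diff mostly delete one-line REGISTERn(...) calls, which fan out into one kernel registration per listed element type. A toy, standalone illustration of that fan-out (the macros below are invented for this sketch, not TensorFlow's):

#include <iostream>
#include <string>

void RegisterKernel(const std::string& op, const std::string& device,
                    const std::string& type) {
  std::cout << "register " << op << " on " << device << " for " << type
            << "\n";
}

#define REGISTER(N, D, T) RegisterKernel(N, #D, #T);
#define REGISTER2(N, D, T0, T1) REGISTER(N, D, T0) REGISTER(N, D, T1)

int main() {
  // Analogous to REGISTER2(UnaryOp, SYCL, "Acos", functor::acos, float,
  // double): one registration per type in the list.
  REGISTER2("Acos", SYCL, float, double)
  return 0;
}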


@ -20,9 +20,6 @@ namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Acosh", functor::acosh, float, double, complex64,
complex128);
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Acosh", functor::acosh, float, double);
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER2(UnaryOp, GPU, "Acosh", functor::acosh, float, double);


@ -44,26 +44,4 @@ REGISTER_KERNEL_BUILDER(Name("AddV2")
BinaryOp<CPUDevice, functor::add<int32>>);
#endif
#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(type) \
REGISTER(BinaryOp, SYCL, "Add", functor::add, type); \
REGISTER(BinaryOp, SYCL, "AddV2", functor::add, type);
TF_CALL_SYCL_NUMBER_TYPES(REGISTER_KERNEL);
REGISTER_KERNEL_BUILDER(Name("Add")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::add<int32>>);
REGISTER_KERNEL_BUILDER(Name("AddV2")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::add<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -22,7 +22,4 @@ REGISTER2(UnaryOp, CPU, "Asin", functor::asin, float, double);
REGISTER2(UnaryOp, GPU, "Asin", functor::asin, float, double);
#endif
#if TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Asin", functor::asin, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -20,9 +20,6 @@ namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Asinh", functor::asinh, float, double, complex64,
complex128);
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Asinh", functor::asinh, float, double);
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER2(UnaryOp, GPU, "Asinh", functor::asinh, float, double);


@ -22,7 +22,4 @@ REGISTER2(UnaryOp, CPU, "Atan", functor::atan, float, double);
REGISTER2(UnaryOp, GPU, "Atan", functor::atan, float, double);
#endif
#if TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Atan", functor::atan, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -20,9 +20,6 @@ namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double, complex64,
complex128);
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Atanh", functor::atanh, float, double);
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER2(UnaryOp, GPU, "Atanh", functor::atanh, float, double);


@ -19,22 +19,6 @@ namespace tensorflow {
REGISTER8(BinaryOp, CPU, "BitwiseAnd", functor::bitwise_and, int8, int16, int32,
int64, uint8, uint16, uint32, uint64);
#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BitwiseAnd").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::bitwise_and<TYPE>>);
REGISTER_SYCL_KERNEL(int8);
REGISTER_SYCL_KERNEL(int16);
REGISTER_SYCL_KERNEL(int32);
REGISTER_SYCL_KERNEL(int64);
REGISTER_SYCL_KERNEL(uint8);
REGISTER_SYCL_KERNEL(uint16);
REGISTER_SYCL_KERNEL(uint32);
REGISTER_SYCL_KERNEL(uint64);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER8(BinaryOp, GPU, "BitwiseAnd", functor::bitwise_and, int8, int16, int32,


@ -19,22 +19,6 @@ namespace tensorflow {
REGISTER8(BinaryOp, CPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32,
int64, uint8, uint16, uint32, uint64);
#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BitwiseOr").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::bitwise_or<TYPE>>);
REGISTER_SYCL_KERNEL(int8);
REGISTER_SYCL_KERNEL(int16);
REGISTER_SYCL_KERNEL(int32);
REGISTER_SYCL_KERNEL(int64);
REGISTER_SYCL_KERNEL(uint8);
REGISTER_SYCL_KERNEL(uint16);
REGISTER_SYCL_KERNEL(uint32);
REGISTER_SYCL_KERNEL(uint64);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER8(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32,


@ -19,22 +19,6 @@ namespace tensorflow {
REGISTER8(BinaryOp, CPU, "BitwiseXor", functor::bitwise_xor, int8, int16, int32,
int64, uint8, uint16, uint32, uint64);
#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BitwiseXor").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::bitwise_xor<TYPE>>);
REGISTER_SYCL_KERNEL(int8);
REGISTER_SYCL_KERNEL(int16);
REGISTER_SYCL_KERNEL(int32);
REGISTER_SYCL_KERNEL(int64);
REGISTER_SYCL_KERNEL(uint8);
REGISTER_SYCL_KERNEL(uint16);
REGISTER_SYCL_KERNEL(uint32);
REGISTER_SYCL_KERNEL(uint64);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER8(BinaryOp, GPU, "BitwiseXor", functor::bitwise_xor, int8, int16, int32,


@ -23,7 +23,4 @@ REGISTER4(UnaryOp, CPU, "Ceil", functor::ceil, float, Eigen::half, bfloat16,
REGISTER3(UnaryOp, GPU, "Ceil", functor::ceil, float, Eigen::half, double);
#endif
#if TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Ceil", functor::ceil, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -23,7 +23,4 @@ REGISTER6(UnaryOp, CPU, "Cos", functor::cos, float, Eigen::half, bfloat16,
REGISTER3(UnaryOp, GPU, "Cos", functor::cos, float, Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Cos", functor::cos, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -19,15 +19,6 @@ namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Cosh", functor::cosh, float, double, bfloat16,
complex64, complex128);
#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("Cosh").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
UnaryOp<SYCLDevice, functor::cosh<TYPE>>);
REGISTER_SYCL_KERNEL(float);
REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER2(UnaryOp, GPU, "Cosh", functor::cosh, float, double);


@ -50,15 +50,4 @@ REGISTER_KERNEL_BUILDER(Name("Div")
BinaryOp<CPUDevice, functor::safe_div<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(BinaryOp, SYCL, "Div", functor::div, float, double);
REGISTER2(BinaryOp, SYCL, "RealDiv", functor::div, float, double);
REGISTER_KERNEL_BUILDER(Name("Div")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::safe_div<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -47,16 +47,5 @@ REGISTER_KERNEL_BUILDER(Name("Equal")
BinaryOp<CPUDevice, functor::equal_to<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER5(BinaryOp, SYCL, "Equal", functor::equal_to, float, double, uint8,
int8, int16);
REGISTER_KERNEL_BUILDER(Name("Equal")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::equal_to<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -24,7 +24,4 @@ REGISTER5(UnaryOp, GPU, "Exp", functor::exp, float, Eigen::half, double,
complex64, complex128);
#endif
#if TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Exp", functor::exp, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -21,7 +21,4 @@ REGISTER6(UnaryOp, CPU, "Expm1", functor::expm1, float, Eigen::half, bfloat16,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(UnaryOp, GPU, "Expm1", functor::expm1, float, Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Expm1", functor::expm1, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -22,7 +22,4 @@ REGISTER4(UnaryOp, CPU, "Floor", functor::floor, float, Eigen::half, bfloat16,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(UnaryOp, GPU, "Floor", functor::floor, float, Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Floor", functor::floor, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -41,13 +41,4 @@ REGISTER_KERNEL_BUILDER(Name("FloorDiv")
BinaryOp<CPUDevice, functor::safe_floor_div<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("FloorDiv")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::safe_floor_div<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -34,13 +34,4 @@ REGISTER_KERNEL_BUILDER(Name("FloorMod")
BinaryOp<CPUDevice, functor::safe_floor_mod<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("FloorMod")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::safe_floor_mod<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -33,15 +33,4 @@ REGISTER_KERNEL_BUILDER(Name("Greater")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::greater<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(BinaryOp, SYCL, "Greater", functor::greater, float, double);
REGISTER_KERNEL_BUILDER(Name("Greater")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::greater<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -34,16 +34,4 @@ REGISTER_KERNEL_BUILDER(Name("GreaterEqual")
BinaryOp<CPUDevice, functor::greater_equal<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(BinaryOp, SYCL, "GreaterEqual", functor::greater_equal, float,
double);
REGISTER_KERNEL_BUILDER(Name("GreaterEqual")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::greater_equal<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -19,10 +19,6 @@ namespace tensorflow {
REGISTER8(UnaryOp, CPU, "Invert", functor::invert, int8, int16, int32, int64,
uint8, uint16, uint32, uint64);
#ifdef TENSORFLOW_USE_SYCL
REGISTER6(UnaryOp, SYCL, "Invert", functor::invert, int8, int16, int32, int64,
uint8, uint16, uint32, uint64);
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER8(UnaryOp, GPU, "Invert", functor::invert, int8, int16, int32, int64,


@ -24,7 +24,4 @@ REGISTER3(UnaryOp, GPU, "IsFinite", functor::isfinite, float, Eigen::half,
double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "IsFinite", functor::isfinite, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -23,7 +23,4 @@ REGISTER4(UnaryOp, CPU, "IsInf", functor::isinf, float, Eigen::half, bfloat16,
REGISTER3(UnaryOp, GPU, "IsInf", functor::isinf, float, Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "IsInf", functor::isinf, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -23,7 +23,4 @@ REGISTER4(UnaryOp, CPU, "IsNan", functor::isnan, float, Eigen::half, double,
REGISTER3(UnaryOp, GPU, "IsNan", functor::isnan, float, Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "IsNan", functor::isnan, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -19,22 +19,6 @@ namespace tensorflow {
REGISTER8(BinaryOp, CPU, "LeftShift", functor::left_shift, int8, int16, int32,
int64, uint8, uint16, uint32, uint64);
#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("LeftShift").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::left_shift<TYPE>>);
REGISTER_SYCL_KERNEL(int8);
REGISTER_SYCL_KERNEL(int16);
REGISTER_SYCL_KERNEL(int32);
REGISTER_SYCL_KERNEL(int64);
REGISTER_SYCL_KERNEL(uint8);
REGISTER_SYCL_KERNEL(uint16);
REGISTER_SYCL_KERNEL(uint32);
REGISTER_SYCL_KERNEL(uint64);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER8(BinaryOp, GPU, "LeftShift", functor::left_shift, int8, int16, int32,


@ -35,14 +35,4 @@ REGISTER_KERNEL_BUILDER(Name("Less")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::less<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER3(BinaryOp, SYCL, "Less", functor::less, float, double, int64);
REGISTER_KERNEL_BUILDER(Name("Less")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::less<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -37,15 +37,4 @@ REGISTER_KERNEL_BUILDER(Name("LessEqual")
BinaryOp<CPUDevice, functor::less_equal<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER6(BinaryOp, SYCL, "LessEqual", functor::less_equal, float, double,
int64, uint8, int8, int16);
REGISTER_KERNEL_BUILDER(Name("LessEqual")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::less_equal<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -23,7 +23,4 @@ REGISTER6(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double,
REGISTER3(UnaryOp, GPU, "Log", functor::log, float, Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Log", functor::log, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -23,7 +23,4 @@ REGISTER6(UnaryOp, CPU, "Log1p", functor::log1p, float, Eigen::half, bfloat16,
REGISTER3(UnaryOp, GPU, "Log1p", functor::log1p, float, Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Log1p", functor::log1p, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -34,14 +34,4 @@ REGISTER_KERNEL_BUILDER(Name("Maximum")
BinaryOp<CPUDevice, functor::maximum<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER3(BinaryOp, SYCL, "Maximum", functor::maximum, float, double, int64);
REGISTER_KERNEL_BUILDER(Name("Maximum")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::maximum<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -34,15 +34,5 @@ REGISTER_KERNEL_BUILDER(Name("Minimum")
BinaryOp<CPUDevice, functor::minimum<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER3(BinaryOp, SYCL, "Minimum", functor::minimum, float, double, int64);
REGISTER_KERNEL_BUILDER(Name("Minimum")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::minimum<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -49,14 +49,4 @@ REGISTER5(BinaryOp, GPU, "MulNoNan", functor::mul_no_nan, Eigen::half, float,
double, complex64, complex128);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER3(BinaryOp, SYCL, "Mul", functor::mul, float, double, uint8);
REGISTER_KERNEL_BUILDER(Name("Mul")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::mul<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -18,15 +18,6 @@ limitations under the License.
namespace tensorflow {
REGISTER4(UnaryOp, CPU, "Neg", functor::neg, int8, int16, int32, int64);
#ifdef TENSORFLOW_USE_SYCL
REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64);
REGISTER_KERNEL_BUILDER(Name("Neg")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.TypeConstraint<int32>("T"),
UnaryOp<CPUDevice, functor::neg<int32>>);
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(UnaryOp, GPU, "Neg", functor::neg, int8, int16, int64);


@ -35,16 +35,5 @@ REGISTER_KERNEL_BUILDER(Name("NotEqual")
BinaryOp<CPUDevice, functor::not_equal_to<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(BinaryOp, SYCL, "NotEqual", functor::not_equal_to, float, double);
REGISTER_KERNEL_BUILDER(Name("NotEqual")
.Device(DEVICE_SYCL)
.HostMemory("x")
.HostMemory("y")
.HostMemory("z")
.TypeConstraint<int32>("T"),
BinaryOp<CPUDevice, functor::not_equal_to<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -24,7 +24,4 @@ REGISTER2(BinaryOp, CPU, "Pow", functor::safe_pow, int32, int64);
REGISTER4(BinaryOp, GPU, "Pow", functor::pow, float, Eigen::half, double,
int64);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(BinaryOp, SYCL, "Pow", functor::pow, float, double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -36,9 +36,6 @@ REGISTER6(UnaryOp, CPU, "Reciprocal", functor::inverse, float, Eigen::half,
REGISTER4(UnaryOp, GPU, "Reciprocal", functor::inverse, float, Eigen::half,
double, int64);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER(UnaryOp, SYCL, "Reciprocal", functor::inverse, float);
#endif // TENSORFLOW_USE_SYCL
REGISTER6(SimpleBinaryOp, CPU, "ReciprocalGrad", functor::inverse_grad, float,
Eigen::half, bfloat16, double, complex64, complex128);
@ -46,7 +43,4 @@ REGISTER6(SimpleBinaryOp, CPU, "ReciprocalGrad", functor::inverse_grad, float,
REGISTER3(SimpleBinaryOp, GPU, "ReciprocalGrad", functor::inverse_grad, float,
Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER(SimpleBinaryOp, SYCL, "ReciprocalGrad", functor::inverse_grad, float);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -19,22 +19,6 @@ namespace tensorflow {
REGISTER8(BinaryOp, CPU, "RightShift", functor::right_shift, int8, int16, int32,
int64, uint8, uint16, uint32, uint64);
#if TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNEL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("RightShift").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
BinaryOp<SYCLDevice, functor::right_shift<TYPE>>);
REGISTER_SYCL_KERNEL(int8);
REGISTER_SYCL_KERNEL(int16);
REGISTER_SYCL_KERNEL(int32);
REGISTER_SYCL_KERNEL(int64);
REGISTER_SYCL_KERNEL(uint8);
REGISTER_SYCL_KERNEL(uint16);
REGISTER_SYCL_KERNEL(uint32);
REGISTER_SYCL_KERNEL(uint64);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER8(BinaryOp, GPU, "RightShift", functor::right_shift, int8, int16, int32,


@ -19,9 +19,6 @@ namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Round", functor::round, Eigen::half, float, double,
int32, int64);
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Round", functor::round, float, double);
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER5(UnaryOp, GPU, "Round", functor::round, Eigen::half, float, double,


@ -22,9 +22,6 @@ REGISTER6(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, bfloat16,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(UnaryOp, SYCL, "Rsqrt", functor::rsqrt, float, double);
#endif // TENSORFLOW_USE_SYCL
REGISTER6(SimpleBinaryOp, CPU, "RsqrtGrad", functor::rsqrt_grad, float,
Eigen::half, bfloat16, double, complex64, complex128);
@ -32,8 +29,4 @@ REGISTER6(SimpleBinaryOp, CPU, "RsqrtGrad", functor::rsqrt_grad, float,
REGISTER3(SimpleBinaryOp, GPU, "RsqrtGrad", functor::rsqrt_grad, float,
Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(SimpleBinaryOp, SYCL, "RsqrtGrad", functor::rsqrt_grad, float,
double);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@ -29,9 +29,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
namespace functor {
template <typename Device, typename T>
@ -294,22 +291,6 @@ REGISTER_SELECT_GPU(complex128);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
#define REGISTER_SELECT_SYCL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SelectOp<SYCLDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SelectV2").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SelectOp<SYCLDevice, type>);
REGISTER_SELECT_SYCL(float);
REGISTER_SELECT_SYCL(double);
REGISTER_SELECT_SYCL(int32);
REGISTER_SELECT_SYCL(int64);
#undef REGISTER_SELECT_SYCL
#endif // TENSORFLOW_USE_SYCL
namespace functor {
@ -326,10 +307,6 @@ struct SelectFunctorBase {
template <typename T>
struct SelectFunctor<CPUDevice, T> : SelectFunctorBase<CPUDevice, T> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
struct SelectScalarHandler {
@ -364,21 +341,6 @@ struct SelectScalarHandler<CPUDevice, T> {
}
};
#ifdef TENSORFLOW_USE_SYCL
template <typename Device, typename T>
struct SelectScalarFunctorBase {
void operator()(const Device& d, typename TTypes<T>::Flat out,
TTypes<bool>::ConstScalar cond,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat) {
out.device(d) = cond() ? then_flat : else_flat;
}
};
template <typename T>
struct SelectScalarFunctor<SYCLDevice, T>
: SelectScalarFunctorBase<SYCLDevice, T> {};
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
struct BatchSelectFunctorBase {
@ -469,16 +431,6 @@ template <typename T, int NDIMS>
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
: BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct BatchSelectFunctor<SYCLDevice, T>
: BatchSelectFunctorBase<SYCLDevice, T> {};
template <typename T, int NDIMS>
struct BCastSelectFunctor<SYCLDevice, T, NDIMS>
: BCastSelectFunctorBase<SYCLDevice, T, NDIMS> {};
#endif // TENSORFLOW_USE_SYCL
} // namespace functor
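Editor's note: the removed SYCL specializations reuse the Base functors above, whose batch path picks each output row from `then` or `else` according to a per-row predicate. A minimal host-only sketch of that semantics (names invented for illustration):

#include <cstddef>
#include <vector>

// Row-wise select: out[i] = cond[i] ? then_rows[i] : else_rows[i].
void BatchSelect(const std::vector<bool>& cond,
                 const std::vector<std::vector<float>>& then_rows,
                 const std::vector<std::vector<float>>& else_rows,
                 std::vector<std::vector<float>>* out) {
  out->resize(cond.size());
  for (std::size_t i = 0; i < cond.size(); ++i) {
    (*out)[i] = cond[i] ? then_rows[i] : else_rows[i];
  }
}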


@ -23,9 +23,6 @@ REGISTER6(UnaryOp, CPU, "Sigmoid", functor::sigmoid, bfloat16, float,
REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half,
double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER(UnaryOp, SYCL, "Sigmoid", functor::sigmoid, float);
#endif // TENSORFLOW_USE_SYCL
REGISTER6(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, bfloat16,
float, Eigen::half, double, complex64, complex128);
@ -33,8 +30,5 @@ REGISTER6(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, bfloat16,
REGISTER3(SimpleBinaryOp, GPU, "SigmoidGrad", functor::sigmoid_grad, float,
Eigen::half, double);
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER(SimpleBinaryOp, SYCL, "SigmoidGrad", functor::sigmoid_grad, float);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow

Some files were not shown because too many files have changed in this diff.