Merge changes from github.
Change: 138675832
This commit is contained in:
parent
f0e9bd3c55
commit
a771598ad8
1
.gitignore
vendored
1
.gitignore
vendored
@ -5,6 +5,7 @@ node_modules
|
||||
/third_party/py/numpy/numpy_include
|
||||
/tools/bazel.rc
|
||||
/tools/python_bin_path.sh
|
||||
/tools/git/gen
|
||||
/util/python/python_include
|
||||
/util/python/python_lib
|
||||
/pip_test
|
||||
|
105
configure
vendored
105
configure
vendored
@ -53,6 +53,7 @@ if is_windows; then
|
||||
TF_NEED_GCP=0
|
||||
TF_NEED_HDFS=0
|
||||
TF_NEED_CUDA=0
|
||||
TF_NEED_OPENCL=0
|
||||
fi
|
||||
|
||||
while [ "$TF_NEED_GCP" == "" ]; do
|
||||
@ -116,6 +117,17 @@ GEN_GIT_SOURCE=tensorflow/tools/git/gen_git_source.py
|
||||
chmod a+x ${GEN_GIT_SOURCE}
|
||||
"${PYTHON_BIN_PATH}" ${GEN_GIT_SOURCE} --configure "${SOURCE_BASE_DIR}"
|
||||
|
||||
## Set up SYCL-related environment settings
|
||||
# Ask whether TensorFlow should be built with OpenCL support. The loop is
# skipped when TF_NEED_OPENCL already has a value; an empty answer means "no".
while [ "$TF_NEED_OPENCL" == "" ]; do
  # -r keeps backslashes in the answer literal.
  read -r -p "Do you wish to build TensorFlow with OpenCL support? [y/N] " INPUT
  # INPUT is quoted everywhere so an answer such as "*" is not glob-expanded.
  case "$INPUT" in
    [Yy]* ) echo "OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=1;;
    [Nn]* ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
    "" ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
    * ) echo "Invalid selection: " "$INPUT";;
  esac
done
|
||||
|
||||
## Set up Cuda-related environment settings
|
||||
|
||||
while [ "$TF_NEED_CUDA" == "" ]; do
|
||||
@ -129,12 +141,14 @@ while [ "$TF_NEED_CUDA" == "" ]; do
|
||||
done
|
||||
|
||||
export TF_NEED_CUDA
|
||||
if [ "$TF_NEED_CUDA" == "0" ]; then
|
||||
# Export the variable that the OpenCL prompt actually sets, so child processes
# see the choice even when OpenCL is disabled but CUDA is enabled.
export TF_NEED_OPENCL
|
||||
if [[ "$TF_NEED_CUDA" == "0" ]] && [[ "$TF_NEED_OPENCL" == "0" ]]; then
|
||||
echo "Configuration finished"
|
||||
bazel_clean_and_fetch
|
||||
exit
|
||||
fi
|
||||
|
||||
if [ "$TF_NEED_CUDA" == "1" ]; then
|
||||
# Set up which gcc nvcc should use as the host compiler
|
||||
while true; do
|
||||
fromuser=""
|
||||
@ -336,6 +350,95 @@ EOF
|
||||
TF_CUDA_COMPUTE_CAPABILITIES=""
|
||||
done
|
||||
|
||||
# end of if "$TF_NEED_CUDA" == "1"
|
||||
fi
|
||||
|
||||
# OpenCL configuration
|
||||
|
||||
if [ "$TF_NEED_OPENCL" == "1" ]; then
|
||||
|
||||
# Determine which C++ compiler should be used as the host compiler.
# A HOST_CXX_COMPILER value that is already set is checked as-is and configure
# exits with status 1 if it does not exist. Otherwise the user is prompted
# (default: the g++ found on PATH) until an existing path is entered.
while true; do
  fromuser=""
  if [ -z "$HOST_CXX_COMPILER" ]; then
    # command -v is the portable lookup; "|| true" keeps a missing g++ from
    # being treated as a failure.
    default_cxx_host_compiler=$(command -v g++ || true)
    read -r -p "Please specify which C++ compiler should be used as the host C++ compiler. [Default is $default_cxx_host_compiler]: " HOST_CXX_COMPILER
    fromuser="1"
    if [ -z "$HOST_CXX_COMPILER" ]; then
      HOST_CXX_COMPILER=$default_cxx_host_compiler
    fi
  fi
  if [ -e "$HOST_CXX_COMPILER" ]; then
    export HOST_CXX_COMPILER
    break
  fi
  echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2
  # A bad value that did not come from the prompt cannot be corrected here.
  if [ -z "$fromuser" ]; then
    exit 1
  fi
  HOST_CXX_COMPILER=""
  # Retry
done
|
||||
|
||||
# Determine which C compiler should be used as the host compiler.
# A HOST_C_COMPILER value that is already set is checked as-is and configure
# exits with status 1 if it does not exist. Otherwise the user is prompted
# (default: the gcc found on PATH) until an existing path is entered.
while true; do
  fromuser=""
  if [ -z "$HOST_C_COMPILER" ]; then
    # command -v is the portable lookup; "|| true" keeps a missing gcc from
    # being treated as a failure.
    default_c_host_compiler=$(command -v gcc || true)
    read -r -p "Please specify which C compiler should be used as the host C compiler. [Default is $default_c_host_compiler]: " HOST_C_COMPILER
    fromuser="1"
    if [ -z "$HOST_C_COMPILER" ]; then
      HOST_C_COMPILER=$default_c_host_compiler
    fi
  fi
  if [ -e "$HOST_C_COMPILER" ]; then
    export HOST_C_COMPILER
    break
  fi
  echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2
  # A bad value that did not come from the prompt cannot be corrected here.
  if [ -z "$fromuser" ]; then
    exit 1
  fi
  HOST_C_COMPILER=""
  # Retry
done
|
||||
|
||||
# Locate the ComputeCpp installation. A COMPUTECPP_TOOLKIT_PATH value that is
# already set is checked as-is and configure exits with status 1 if the SYCL
# runtime library is not found under it. Otherwise the user is prompted until
# a path containing the library is entered.
while true; do
  # Reset per iteration so a value left over from an earlier prompt does not
  # make an invalid pre-set path look like user input.
  fromuser=""

  # Configure the OPENCL version to use.
  TF_OPENCL_VERSION="1.2"

  # Point to ComputeCpp root
  if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
    default_computecpp_toolkit_path=/usr/local/computecpp
    read -r -p "Please specify the location where ComputeCpp $TF_OPENCL_VERSION is installed. Refer to README.md for more details. [Default is $default_computecpp_toolkit_path]: " COMPUTECPP_TOOLKIT_PATH
    fromuser="1"
    if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
      COMPUTECPP_TOOLKIT_PATH=$default_computecpp_toolkit_path
    fi
  fi

  # NOTE(review): SYCL_RT_LIB_PATH is only assigned here for Linux; on other
  # platforms the check below tests just the toolkit directory unless the
  # variable is set elsewhere - confirm.
  if [ "$OSNAME" == "Linux" ]; then
    SYCL_RT_LIB_PATH="lib/libComputeCpp.so"
  fi

  if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then
    export COMPUTECPP_TOOLKIT_PATH
    break
  fi
  # Diagnostics go to stderr, matching the compiler checks above.
  echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found" 1>&2

  if [ -z "$fromuser" ]; then
    exit 1
  fi
  # Retry
  TF_OPENCL_VERSION=""
  COMPUTECPP_TOOLKIT_PATH=""
done
|
||||
|
||||
export TF_NEED_OPENCL
|
||||
# end of if "$TF_NEED_OPENCL" == "1"
|
||||
fi
|
||||
|
||||
bazel_clean_and_fetch
|
||||
|
||||
echo "Configuration finished"
|
||||
|
@ -15,13 +15,14 @@
|
||||
|
||||
# Bring in all of the public TensorFlow interface into this
|
||||
# module.
|
||||
# pylint: disable=wildcard-import
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python import *
|
||||
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
# Lazily import the `tf.contrib` module. This avoids loading all of the
|
||||
# dependencies of `tf.contrib` at `import tensorflow` time.
|
||||
|
@ -20,7 +20,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import,line-too-long
|
||||
# pylint: disable=unused-import,line-too-long
|
||||
from tensorflow.contrib.bayesflow.python.ops import entropy
|
||||
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
|
||||
from tensorflow.contrib.bayesflow.python.ops import special_math
|
||||
@ -29,3 +29,4 @@ from tensorflow.contrib.bayesflow.python.ops import stochastic_graph
|
||||
from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor
|
||||
from tensorflow.contrib.bayesflow.python.ops import stochastic_variables
|
||||
from tensorflow.contrib.bayesflow.python.ops import variational_inference
|
||||
# pylint: enable=unused-import,line-too-long
|
||||
|
@ -22,7 +22,8 @@ option(tensorflow_BUILD_CC_EXAMPLE "Build the C++ tutorial example" ON)
|
||||
option(tensorflow_BUILD_PYTHON_BINDINGS "Build the Python bindings" ON)
|
||||
option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
|
||||
option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON)
|
||||
|
||||
option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF)
|
||||
option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF)
|
||||
|
||||
#Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option for
|
||||
# targets that link ${CMAKE_THREAD_LIBS_INIT}.
|
||||
@ -74,6 +75,9 @@ include(jsoncpp)
|
||||
include(farmhash)
|
||||
include(highwayhash)
|
||||
include(protobuf)
|
||||
if (tensorflow_BUILD_CC_TESTS)
|
||||
include(googletest)
|
||||
endif()
|
||||
|
||||
set(tensorflow_EXTERNAL_LIBRARIES
|
||||
${zlib_STATIC_LIBRARIES}
|
||||
@ -194,7 +198,6 @@ include(tf_core_kernels.cmake)
|
||||
if(tensorflow_ENABLE_GRPC_SUPPORT)
|
||||
include(tf_core_distributed_runtime.cmake)
|
||||
endif()
|
||||
|
||||
include(tf_cc_ops.cmake)
|
||||
if(tensorflow_BUILD_CC_EXAMPLE)
|
||||
include(tf_tutorials.cmake)
|
||||
@ -203,3 +206,6 @@ endif()
|
||||
if(tensorflow_BUILD_PYTHON_BINDINGS)
|
||||
include(tf_python.cmake)
|
||||
endif()
|
||||
if (tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS)
|
||||
include(tf_tests.cmake)
|
||||
endif()
|
||||
|
@ -60,7 +60,9 @@ Note: Windows support is in an **alpha** state, and we welcome your feedback.
|
||||
on Windows, but have not yet committed to supporting that configuration.)
|
||||
|
||||
- The following Python APIs are not currently implemented:
|
||||
* Loading custom op libraries via `tf.load_op_library()`.
|
||||
* Loading custom op libraries via `tf.load_op_library()`. In order to use your
|
||||
custom op, please put the source code under the tensorflow/core/user_ops
|
||||
directory, and a shape function is required (not optional) for each op.
|
||||
* Path manipulation functions (such as `tf.gfile.ListDirectory()`) are not
|
||||
functional.
|
||||
|
||||
@ -76,7 +78,6 @@ Note: Windows support is in an **alpha** state, and we welcome your feedback.
|
||||
* `ImmutableConst`
|
||||
* `Lgamma`
|
||||
* `Polygamma`
|
||||
* `SparseMatmul`
|
||||
* `Zeta`
|
||||
|
||||
- Google Cloud Storage support is not currently implemented. The GCS library
|
||||
@ -195,7 +196,21 @@ Step-by-step Windows build
|
||||
* `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include
|
||||
GPU support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and CUDNN 5.1.
|
||||
CMake will expect the location of CUDNN in -DCUDNN_HOME=path_you_unzipped_cudnn.
|
||||
|
||||
|
||||
* `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds cc unit tests.
|
||||
There are many of them and building will take a few hours.
|
||||
After cmake, build and execute the tests with
|
||||
```
|
||||
MSBuild /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj
|
||||
ctest -C RelWithDebInfo
|
||||
```
|
||||
|
||||
* `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python kernel tests.
|
||||
After building the python wheel, you need to install the new wheel before running the tests.
|
||||
To execute the tests, use
|
||||
```
|
||||
ctest -C RelWithDebInfo
|
||||
```
|
||||
|
||||
4. Invoke MSBuild to build TensorFlow.
|
||||
|
||||
|
29
tensorflow/contrib/cmake/external/googletest.cmake
vendored
Normal file
29
tensorflow/contrib/cmake/external/googletest.cmake
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
include (ExternalProject)
|
||||
|
||||
set(googletest_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/include)
|
||||
set(googletest_URL https://github.com/google/googletest.git)
|
||||
set(googletest_BUILD ${CMAKE_CURRENT_BINARY_DIR}/googletest/)
|
||||
set(googletest_TAG ec44c6c1675c25b9827aacd08c02433cccde7780)
|
||||
|
||||
# Location of the gtest static library produced by the in-source
# ExternalProject build of googletest.
if(WIN32)
  set(googletest_STATIC_LIBRARIES
      ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/${CMAKE_BUILD_TYPE}/gtest.lib)
else()
  # Static libraries get the "lib" prefix on non-Windows platforms, and
  # single-configuration generators do not add a per-configuration directory.
  # NOTE(review): a multi-configuration generator such as Xcode would place the
  # library under a configuration subdirectory - confirm if that is needed.
  set(googletest_STATIC_LIBRARIES
      ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/libgtest.a)
endif()
|
||||
|
||||
ExternalProject_Add(googletest
|
||||
PREFIX googletest
|
||||
GIT_REPOSITORY ${googletest_URL}
|
||||
GIT_TAG ${googletest_TAG}
|
||||
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||
BUILD_IN_SOURCE 1
|
||||
#PATCH_COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_SOURCE_DIR}/patches/grpc/CMakeLists.txt ${GRPC_BUILD}
|
||||
INSTALL_COMMAND ""
|
||||
CMAKE_CACHE_ARGS
|
||||
-DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
|
||||
-DBUILD_GMOCK:BOOL=OFF
|
||||
-DBUILD_GTEST:BOOL=ON
|
||||
-Dgtest_force_shared_crt:BOOL=ON
|
||||
)
|
2
tensorflow/contrib/cmake/external/zlib.cmake
vendored
2
tensorflow/contrib/cmake/external/zlib.cmake
vendored
@ -8,7 +8,7 @@ set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d)
|
||||
|
||||
if(WIN32)
|
||||
set(zlib_STATIC_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlib.lib)
|
||||
${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib)
|
||||
else()
|
||||
set(zlib_STATIC_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a)
|
||||
|
@ -14,6 +14,7 @@ set(tf_op_lib_names
|
||||
"no_op"
|
||||
"parsing_ops"
|
||||
"random_ops"
|
||||
"resource_variable_ops"
|
||||
"script_ops"
|
||||
"sdca_ops"
|
||||
"sendrecv_ops"
|
||||
|
@ -270,6 +270,7 @@ GENERATE_PYTHON_OP_LIB("logging_ops")
|
||||
GENERATE_PYTHON_OP_LIB("nn_ops")
|
||||
GENERATE_PYTHON_OP_LIB("parsing_ops")
|
||||
GENERATE_PYTHON_OP_LIB("random_ops")
|
||||
GENERATE_PYTHON_OP_LIB("resource_variable_ops")
|
||||
GENERATE_PYTHON_OP_LIB("script_ops")
|
||||
GENERATE_PYTHON_OP_LIB("sdca_ops")
|
||||
GENERATE_PYTHON_OP_LIB("state_ops")
|
||||
|
384
tensorflow/contrib/cmake/tf_tests.cmake
Normal file
384
tensorflow/contrib/cmake/tf_tests.cmake
Normal file
@ -0,0 +1,384 @@
|
||||
enable_testing()
|
||||
|
||||
#
# GetTestRunPath(<VAR_NAME> <OBJ_NAME>)
#
# get a temp path for test data
#
# Sets <VAR_NAME> in the caller's scope to "<tmp>/<OBJ_NAME>". <tmp> is read
# from the environment: TMP, then TEMP on Windows (backslashes converted to
# forward slashes); TMPDIR elsewhere, falling back to /tmp when TMPDIR is
# unset or does not exist. Stops with FATAL_ERROR if no existing temporary
# directory can be determined.
#
function(GetTestRunPath VAR_NAME OBJ_NAME)
  # Start empty so a TMPDIR variable inherited from the calling scope is
  # never used by accident.
  set(TMPDIR "")
  if(WIN32)
    if(DEFINED ENV{TMP})
      set(TMPDIR "$ENV{TMP}")
    elseif(DEFINED ENV{TEMP})
      set(TMPDIR "$ENV{TEMP}")
    endif()
    # Quoted: an empty value must stay an (empty) argument, otherwise
    # string(REPLACE) fails before the clear error below is reached.
    string(REPLACE "\\" "/" TMPDIR "${TMPDIR}")
  else()
    if(DEFINED ENV{TMPDIR})
      set(TMPDIR "$ENV{TMPDIR}")
    endif()
    # TMPDIR is commonly unset on Unix; use the conventional location then.
    if(NOT EXISTS "${TMPDIR}")
      set(TMPDIR "/tmp")
    endif()
  endif()
  if(NOT EXISTS "${TMPDIR}")
    message(FATAL_ERROR "Unable to determine a path to the temporary directory")
  endif()
  set(${VAR_NAME} "${TMPDIR}/${OBJ_NAME}" PARENT_SCOPE)
endfunction(GetTestRunPath)
|
||||
|
||||
#
# AddTests(SOURCES <files> [OBJECTS ...] [LIBS ...] [DATA ...] [DEPENDS ...])
#
# create test for each source
#
# Every file in SOURCES becomes its own AddTest() call. The target name is the
# source path with the "${tensorflow_source_dir}/" prefix and ".cc" removed
# and each "/" turned into "_". OBJECTS, LIBS, DATA and DEPENDS are handed to
# every AddTest() call unchanged.
#
function(AddTests)
  cmake_parse_arguments(_AT "" "" "SOURCES;OBJECTS;LIBS;DATA;DEPENDS" ${ARGN})
  foreach(test_source ${_AT_SOURCES})
    # Derive the target name from the path relative to the source tree.
    string(REPLACE "${tensorflow_source_dir}/" "" test_target ${test_source})
    string(REPLACE ".cc" "" test_target ${test_target})
    string(REPLACE "/" "_" test_target ${test_target})
    AddTest(
      TARGET  ${test_target}
      SOURCES ${test_source}
      OBJECTS ${_AT_OBJECTS}
      LIBS    ${_AT_LIBS}
      DATA    ${_AT_DATA}
      DEPENDS ${_AT_DEPENDS}
    )
  endforeach()
endfunction(AddTests)
|
||||
|
||||
#
# AddTest(TARGET <name> SOURCES <files> [OBJECTS ...] [LIBS ...]
#         [DATA ...] [DEPENDS ...])
#
# create one test
#
# Builds executable <name> from SOURCES and OBJECTS, links it against LIBS and
# registers it with CTest. The test runs in a per-target directory obtained
# from GetTestRunPath(); that directory is wiped and recreated at configure
# time. TEST_TMPDIR and TEST_SRCDIR are set in the test environment. Each DATA
# file (relative to CMAKE_CURRENT_SOURCE_DIR) is copied into the run directory
# after the build, and DEPENDS are added as target dependencies.
#
function(AddTest)
  cmake_parse_arguments(_AT "" "TARGET" "SOURCES;OBJECTS;LIBS;DATA;DEPENDS" ${ARGN})

  # De-duplicate only the lists that were actually given; REMOVE_DUPLICATES on
  # a missing list is an error in some CMake versions.
  foreach(list_name SOURCES OBJECTS LIBS DATA DEPENDS)
    if(NOT "${_AT_${list_name}}" STREQUAL "")
      list(REMOVE_DUPLICATES _AT_${list_name})
    endif()
  endforeach()

  add_executable(${_AT_TARGET} ${_AT_SOURCES} ${_AT_OBJECTS})
  target_link_libraries(${_AT_TARGET} ${_AT_LIBS})

  # Fresh run directory for every configure.
  GetTestRunPath(testdir ${_AT_TARGET})
  set(tempdir "${testdir}/tmp")
  file(REMOVE_RECURSE "${testdir}")
  file(MAKE_DIRECTORY "${testdir}")
  file(MAKE_DIRECTORY "${tempdir}")
  add_test(NAME ${_AT_TARGET} COMMAND ${_AT_TARGET} WORKING_DIRECTORY "${testdir}")
  set_tests_properties(${_AT_TARGET}
    PROPERTIES ENVIRONMENT "TEST_TMPDIR=${tempdir};TEST_SRCDIR=${testdir}"
  )

  foreach(datafile ${_AT_DATA})
    # DEPENDS is not part of the add_custom_command(TARGET ...) signature, so
    # it is not passed here; the copy simply runs after each build of the
    # target.
    add_custom_command(
      TARGET ${_AT_TARGET} POST_BUILD
      COMMAND ${CMAKE_COMMAND} -E copy
        "${CMAKE_CURRENT_SOURCE_DIR}/${datafile}"
        "${testdir}/${datafile}"
    )
  endforeach()

  if (_AT_DEPENDS)
    add_dependencies(${_AT_TARGET} ${_AT_DEPENDS})
  endif()
endfunction(AddTest)
|
||||
|
||||
#
# AddPythonTests(SOURCES <scripts> [DATA ...] [DEPENDS ...])
#
# create python test for each script
#
# Registers one CTest test per script in SOURCES; the test is named after the
# script path and runs it with ${PYTHON_EXECUTABLE}. DATA and DEPENDS are
# accepted for symmetry with AddTest() but are not acted on: these tests have
# no build target that dependencies could be attached to.
#
function(AddPythonTests)
  cmake_parse_arguments(_AT "" "" "SOURCES;DATA;DEPENDS" ${ARGN})

  # De-duplicate only the lists that were actually given; REMOVE_DUPLICATES on
  # a missing list is an error in some CMake versions.
  foreach(list_name SOURCES DATA DEPENDS)
    if(NOT "${_AT_${list_name}}" STREQUAL "")
      list(REMOVE_DUPLICATES _AT_${list_name})
    endif()
  endforeach()

  # No TARGET keyword is parsed by this function, so there is nothing valid to
  # pass to add_dependencies(); report the ignored argument instead of failing
  # with a confusing error.
  if(NOT "${_AT_DEPENDS}" STREQUAL "")
    message(WARNING "AddPythonTests: DEPENDS is ignored because python tests have no build target")
  endif()

  foreach(sourcefile ${_AT_SOURCES})
    add_test(NAME ${sourcefile} COMMAND ${PYTHON_EXECUTABLE} ${sourcefile})
  endforeach()
endfunction(AddPythonTests)
|
||||
|
||||
if (tensorflow_BUILD_PYTHON_TESTS)
|
||||
#
|
||||
# python tests. This assumes that the tensorflow wheel is
|
||||
# installed on the test system.
|
||||
# TODO: we currently don't handle tests that need to have
|
||||
# some environment setup: see AddTest how to add this
|
||||
#
|
||||
|
||||
# include all tests
|
||||
file(GLOB_RECURSE tf_test_src_py
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py"
|
||||
)
|
||||
|
||||
# exclude the ones we don't want
|
||||
set(tf_test_src_py_exclude
|
||||
# generally not working
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/resource_variable_ops_test.py"
|
||||
)
|
||||
if (WIN32)
|
||||
set(tf_test_src_py_exclude
|
||||
${tf_test_src_py_exclude}
|
||||
# generally excluded
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py"
|
||||
|
||||
# TODO: failing tests.
|
||||
# Nothing critical in here but should get this list down to []
|
||||
# The failing list is grouped by failure source
|
||||
# stl on windows handles overflows differently
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/as_string_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/cast_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/clip_ops_test.py"
|
||||
# misc
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/reshape_op_test.py"
|
||||
# int32/int64 mixup
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py"
|
||||
# issues related to windows fs
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/io_ops_test.py"
|
||||
# missing kernel
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/pooling_ops_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/conv_ops_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/depthwise_conv_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/pool_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py"
|
||||
# cuda launch failed
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/diag_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/trace_op_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/one_hot_op_test.py" # gpu, T=uint8
|
||||
)
|
||||
endif()
|
||||
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
|
||||
|
||||
AddPythonTests(
|
||||
SOURCES ${tf_test_src_py}
|
||||
)
|
||||
endif(tensorflow_BUILD_PYTHON_TESTS)
|
||||
|
||||
if (tensorflow_BUILD_CC_TESTS)
|
||||
#
|
||||
# cc unit tests. Be aware that by default we include 250+ tests which
|
||||
# will take time and space to build.
|
||||
# If you want to cut this down, for example to a specific test, modify
|
||||
# tf_test_src_simple to your needs
|
||||
#
|
||||
|
||||
include_directories(${googletest_INCLUDE_DIRS})
|
||||
|
||||
# cc tests wrapper
|
||||
set(tf_src_testlib
|
||||
"${tensorflow_source_dir}/tensorflow/cc/framework/testutil.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/gradients/grad_testutil.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/framework/function_testlib.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/framework/shape_inference_testutil.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/framework/tensor_testutil.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/graph/testlib.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/test_main.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/default/test_benchmark.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/c/c_api.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"
|
||||
)
|
||||
|
||||
# include all tests
|
||||
file(GLOB_RECURSE tf_test_src_simple
|
||||
"${tensorflow_source_dir}/tensorflow/cc/*_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/python/*_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/*_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/user_ops/*_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/*_test.cc"
|
||||
)
|
||||
|
||||
if (NOT tensorflow_ENABLE_GPU)
|
||||
# exclude gpu tests if we are not building for gpu
|
||||
set(tf_test_src_simple_exclude
|
||||
${tf_test_src_simple_exclude}
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc"
|
||||
)
|
||||
endif()
|
||||
|
||||
# exclude the ones we don't want
set(tf_test_src_simple_exclude
  # keep anything already collected (the GPU tests excluded when GPU support
  # is disabled); without this the list would be overwritten here
  ${tf_test_src_simple_exclude}
  # generally not working
  "${tensorflow_source_dir}/tensorflow/cc/client/client_session_test.cc"
  "${tensorflow_source_dir}/tensorflow/cc/framework/gradients_test.cc"
  "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/call_options_test.cc"
  "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/tensor_coding_test.cc"
)
|
||||
|
||||
if (WIN32)
|
||||
set(tf_test_src_simple_exclude
|
||||
${tf_test_src_simple_exclude}
|
||||
# generally excluded
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/framework/cc_ops_test.cc" # test_op.h missing
|
||||
|
||||
# TODO: test failing
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/simple_placer_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/executor_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantized_reshape_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/requantization_range_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/requantize_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/restore_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/restore_v2_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/save_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/sparse_reduce_sum_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/restore_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantize_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/core/status_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/strings/str_util_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/strings/numbers_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/monitoring/collection_registry_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/util/tensor_slice_reader_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/file_system_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/logging_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/env_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/ops/math_grad_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/cudnn_rnn_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops_test.cc" # status 5
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops_test.cc" # status 5
|
||||
|
||||
# TODO: not compiling
|
||||
"${tensorflow_source_dir}/tensorflow/cc/framework/gradient_checker_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/gradients/math_grad_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/gradients/array_grad_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/saved_model/loader_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/training/queue_runner_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/cc/training/coordinator_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/nn_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantization_utils_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/activation_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/batch_norm_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/bias_add_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/concat_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/conv_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/matmul_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/pooling_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantize_down_and_shrink_range_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/parameterized_truncated_normal_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/non_max_suppression_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/fused_batch_norm_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/hexagon_graph_transferer_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/adjust_contrast_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/batch_norm_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/cast_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/colorspace_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/control_flow_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/conv_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/debug_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/resize_bilinear_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/resize_nearest_neighbor_op_benchmark_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/spacetobatch_benchmark_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/sparse_add_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/summary_image_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/summary_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantized_activation_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantized_bias_add_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantized_concat_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantized_conv_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantized_matmul_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantized_pooling_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/kernels/quantized_batch_norm_op_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/cloud/gcs_file_system_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/cloud/google_auth_provider_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/cloud/http_request_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/cloud/oauth_client_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/cloud/retrying_file_system_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/cloud/time_util_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/port_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/platform/profile_utils/cpu_utils_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/master_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/remote_device_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/master_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/framework/partial_tensor_shape_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/core/notification_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/gtl/cleanup_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/gtl/edit_distance_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/lib/strings/strcat_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/ops/array_grad_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/ops/nn_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/example/example_parser_configuration_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/example/feature_util_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/util/reporter_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/util/memmapped_file_system_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/util/sparse_sparse_tensor_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/debug/debug_gateway_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/core/debug/debug_io_utils_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/bundle_shim_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/bundle_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/signature_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/training_ops_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/tree_utils_test.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/data/sparse_values_to_indices_test.cc"
|
||||
)
|
||||
endif()
|
||||
|
||||
list(REMOVE_ITEM tf_test_src_simple ${tf_test_src_simple_exclude})
|
||||
|
||||
set(tf_test_lib tf_test_lib)
|
||||
add_library(${tf_test_lib} STATIC ${tf_src_testlib})
|
||||
|
||||
# this is giving too many objects and libraries to the linker but
|
||||
# it makes this script much easier. So for now we do it this way.
|
||||
set(tf_obj_test
|
||||
$<TARGET_OBJECTS:tf_core_lib>
|
||||
$<TARGET_OBJECTS:tf_core_cpu>
|
||||
$<TARGET_OBJECTS:tf_core_framework>
|
||||
$<TARGET_OBJECTS:tf_core_kernels>
|
||||
$<TARGET_OBJECTS:tf_cc_framework>
|
||||
$<TARGET_OBJECTS:tf_cc_ops>
|
||||
$<TARGET_OBJECTS:tf_core_ops>
|
||||
$<TARGET_OBJECTS:tf_core_direct_session>
|
||||
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
|
||||
)
|
||||
|
||||
set(tf_test_libs
|
||||
tf_protos_cc
|
||||
tf_test_lib
|
||||
${tf_core_gpu_kernels_lib}
|
||||
${googletest_STATIC_LIBRARIES}
|
||||
${tensorflow_EXTERNAL_LIBRARIES}
|
||||
)
|
||||
|
||||
AddTests(
|
||||
SOURCES ${tf_test_src_simple}
|
||||
OBJECTS ${tf_obj_test}
|
||||
LIBS ${tf_test_libs}
|
||||
DEPENDS googletest
|
||||
)
|
||||
endif(tensorflow_BUILD_CC_TESTS)
|
@ -126,3 +126,5 @@ from tensorflow.contrib.distributions.python.ops.student_t import *
|
||||
from tensorflow.contrib.distributions.python.ops.transformed_distribution import *
|
||||
from tensorflow.contrib.distributions.python.ops.uniform import *
|
||||
from tensorflow.contrib.distributions.python.ops.wishart import *
|
||||
|
||||
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
|
||||
|
@ -24,3 +24,4 @@ from tensorflow.contrib.factorization.python.ops.factorization_ops import *
|
||||
from tensorflow.contrib.factorization.python.ops.gmm import *
|
||||
from tensorflow.contrib.factorization.python.ops.gmm_ops import *
|
||||
from tensorflow.contrib.factorization.python.ops.kmeans import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -27,6 +27,7 @@ import tensorflow as tf
|
||||
|
||||
# pylint: disable=wildcard-import,undefined-variable
|
||||
from tensorflow.contrib.factorization.python.ops.gen_factorization_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
|
@ -73,5 +73,6 @@ import sys
|
||||
from tensorflow.contrib.framework.python.framework import *
|
||||
from tensorflow.contrib.framework.python.ops import *
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
@ -22,7 +22,9 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.framework.python.framework.checkpoint_utils import *
|
||||
from tensorflow.contrib.framework.python.framework.experimental import experimental
|
||||
from tensorflow.contrib.framework.python.framework.tensor_util import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util import decorator_utils
|
||||
from tensorflow.python.util.deprecation import deprecated
|
||||
from tensorflow.python.util.deprecation import deprecated_arg_values
|
||||
from tensorflow.python.util.deprecation import deprecated_args
|
||||
|
||||
|
@ -25,3 +25,4 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import, line-too-long
|
||||
from tensorflow.contrib.grid_rnn.python.ops.grid_rnn_cell import *
|
||||
# pylint: enable=unused-import,wildcard-import,line-too-long
|
||||
|
@ -116,5 +116,6 @@ import sys
|
||||
from tensorflow.contrib.layers.python.layers import *
|
||||
from tensorflow.contrib.layers.python.ops import sparse_ops
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
@ -31,3 +31,4 @@ from tensorflow.contrib.layers.python.layers.summaries import *
|
||||
from tensorflow.contrib.layers.python.layers.target_column import *
|
||||
from tensorflow.contrib.layers.python.ops.bucketization_op import *
|
||||
from tensorflow.contrib.layers.python.ops.sparse_feature_cross_op import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -62,6 +62,7 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.learn.python.learn import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
@ -21,3 +21,4 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.learn.python.learn import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -24,3 +24,4 @@ from tensorflow.contrib.learn.python.learn.ops.array_ops import *
|
||||
from tensorflow.contrib.learn.python.learn.ops.embeddings_ops import *
|
||||
from tensorflow.contrib.learn.python.learn.ops.losses_ops import *
|
||||
from tensorflow.contrib.learn.python.learn.ops.seq2seq_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -22,3 +22,4 @@ from __future__ import print_function
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.learn.python.learn.preprocessing.categorical import *
|
||||
from tensorflow.contrib.learn.python.learn.preprocessing.text import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -35,3 +35,4 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.lookup.lookup_ops import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
@ -23,3 +23,4 @@ import sys
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.losses.python.losses import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
@ -133,5 +133,6 @@ from __future__ import print_function
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.losses.python.losses.loss_ops import *
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
@ -117,7 +117,7 @@ attached Android device:
|
||||
adb push ~/graphs/inception/tensorflow_inception_graph.pb /data/local/tmp/
|
||||
adb push tensorflow/contrib/makefile/gen/bin/benchmark /data/local/tmp/
|
||||
adb shell '/data/local/tmp/benchmark \
|
||||
--graph=/data/local/tmp/classify_image_graph_def.pb \
|
||||
--graph=/data/local/tmp/tensorflow_inception_graph.pb \
|
||||
--input_layer="input:0" \
|
||||
--input_layer_shape="1,224,224,3" \
|
||||
--input_layer_type="float" \
|
||||
@ -190,7 +190,7 @@ tensorflow/contrib/makefile/download_dependencies.sh
|
||||
Next, you will need to compile protobufs for iOS:
|
||||
|
||||
```bash
|
||||
compile_ios_protobuf.sh
|
||||
tensorflow/contrib/makefile/compile_ios_protobuf.sh
|
||||
```
|
||||
|
||||
Then, you can run the makefile specifying iOS as the target, along with the
|
||||
|
@ -142,6 +142,7 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import
|
||||
from tensorflow.contrib.metrics.python.metrics import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion_matrix
|
||||
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
|
||||
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
|
||||
@ -176,6 +177,6 @@ from tensorflow.contrib.metrics.python.ops.set_ops import set_intersection
|
||||
from tensorflow.contrib.metrics.python.ops.set_ops import set_size
|
||||
from tensorflow.contrib.metrics.python.ops.set_ops import set_union
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
||||
# pylint: enable=unused-import,line-too-long
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
@ -20,3 +20,4 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.metrics.python.metrics.classification import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -22,3 +22,4 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.ndlstm.python.lstm1d import *
|
||||
from tensorflow.contrib.ndlstm.python.lstm2d import *
|
||||
from tensorflow.contrib.ndlstm.python.misc import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -22,3 +22,4 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.opt.python.training.external_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -28,3 +28,4 @@ from tensorflow.python.ops import gen_array_ops as quantized_gen_array_ops
|
||||
from tensorflow.python.ops.gen_array_ops import dequantize
|
||||
from tensorflow.python.ops.gen_array_ops import quantize_v2
|
||||
from tensorflow.python.ops.gen_array_ops import quantized_concat
|
||||
# pylint: enable=unused-import,wildcard-import,g-bad-import-order
|
||||
|
@ -22,3 +22,4 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.quantization.python.array_ops import *
|
||||
from tensorflow.contrib.quantization.python.math_ops import *
|
||||
from tensorflow.contrib.quantization.python.nn_ops import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
@ -18,8 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.ops import gen_array_ops as quantized_gen_array_ops
|
||||
from tensorflow.python.ops.gen_array_ops import dequantize
|
||||
from tensorflow.python.ops.gen_array_ops import quantize_v2
|
||||
from tensorflow.python.ops.gen_array_ops import quantized_concat
|
||||
# pylint: enable=unused-import
|
||||
|
@ -23,3 +23,4 @@ from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops.gen_math_ops import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
@ -23,3 +23,4 @@ from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops.gen_nn_ops import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
@ -45,3 +45,4 @@ from tensorflow.contrib.rnn.python.ops.gru_ops import *
|
||||
from tensorflow.contrib.rnn.python.ops.lstm_ops import *
|
||||
from tensorflow.contrib.rnn.python.ops.rnn import *
|
||||
from tensorflow.contrib.rnn.python.ops.rnn_cell import *
|
||||
# pylint: enable=unused-import,wildcard-import,line-too-long
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import,line-too-long
|
||||
# pylint: disable=unused-import,line-too-long
|
||||
from tensorflow.contrib.seq2seq.python.ops import layers
|
||||
from tensorflow.contrib.seq2seq.python.ops import loss
|
||||
# pylint: enable=unused-import,line-too-long
|
||||
|
@ -24,3 +24,4 @@ from tensorflow.contrib.specs.python.specs import *
|
||||
from tensorflow.contrib.specs.python.specs_lib import *
|
||||
from tensorflow.contrib.specs.python.specs_ops import *
|
||||
from tensorflow.contrib.specs.python.summaries import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
@ -21,3 +21,4 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.tensor_forest.client import *
|
||||
from tensorflow.contrib.tensor_forest.data import *
|
||||
from tensorflow.contrib.tensor_forest.python import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
@ -17,5 +17,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.contrib.tensor_forest.client import eval_metrics
|
||||
# pylint: enable=unused-import
|
||||
|
@ -17,5 +17,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.contrib.tensor_forest.data import data_ops
|
||||
# pylint: enable=unused-import
|
||||
|
@ -19,3 +19,4 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.tensor_forest.hybrid.python import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
@ -21,3 +21,4 @@ from __future__ import print_function
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.testing.python.framework.fake_summary_writer import *
|
||||
from tensorflow.contrib.testing.python.framework.util_test import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
@ -77,5 +77,6 @@ from tensorflow.contrib.training.python.training.training import create_train_op
|
||||
from tensorflow.contrib.training.python.training.training import multiply_gradients
|
||||
from tensorflow.contrib.training.python.training.training import train
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
@ -512,6 +512,7 @@ cc_library(
|
||||
deps = [
|
||||
":core_cpu",
|
||||
":gpu_runtime",
|
||||
":sycl_runtime",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1389,6 +1390,33 @@ tf_cuda_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# SYCL runtime library: bundles the SYCL device, device context and device
# factory sources under common_runtime/sycl/.
# Both srcs and hdrs are wrapped in if_not_windows(); presumably the lists are
# empty on Windows builds — confirm against the if_not_windows() definition.
cc_library(
|
||||
name = "sycl_runtime",
|
||||
srcs = if_not_windows([
|
||||
"common_runtime/sycl/sycl_device.cc",
|
||||
"common_runtime/sycl/sycl_device_context.cc",
|
||||
"common_runtime/sycl/sycl_device_factory.cc",
|
||||
]),
|
||||
hdrs = if_not_windows([
|
||||
"common_runtime/sycl/sycl_device.h",
|
||||
"common_runtime/sycl/sycl_device_context.h",
|
||||
]),
|
||||
copts = tf_copts(),
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":core_cpu",
|
||||
":core_cpu_internal",
|
||||
":framework",
|
||||
":framework_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":protos_all_cc",
|
||||
"//third_party/eigen3",
|
||||
# NOTE(review): the SYCL toolchain dependency below is disabled; presumably
# it is meant to be enabled once @local_config_sycl is available — confirm.
#"@local_config_sycl//sycl:sycl",
|
||||
],
|
||||
# The sources register a device factory via a static registration macro
# (see sycl_device_factory.cc), so the objects are force-linked.
alwayslink = 1,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
|
||||
|
@ -68,12 +68,17 @@ TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
|
||||
(std::vector<DeviceType>{DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
|
||||
types());
|
||||
|
||||
AddDevice("SYCL", "/job:a/replica:0/task:0/device:sycl:0");
|
||||
EXPECT_EQ(
|
||||
(std::vector<DeviceType>{DeviceType(DEVICE_SYCL), DeviceType(DEVICE_GPU),
|
||||
DeviceType(DEVICE_CPU)}), types());
|
||||
|
||||
AddDevice("T1", "/job:a/replica:0/task:0/device:T1:0");
|
||||
AddDevice("T1", "/job:a/replica:0/task:0/device:T1:1");
|
||||
AddDevice("T2", "/job:a/replica:0/task:0/device:T2:0");
|
||||
EXPECT_EQ(
|
||||
(std::vector<DeviceType>{DeviceType("T1"), DeviceType("T2"),
|
||||
DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
|
||||
(std::vector<DeviceType>{DeviceType(DEVICE_SYCL), DeviceType("T1"),
|
||||
DeviceType("T2"), DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
|
||||
types());
|
||||
}
|
||||
|
||||
|
@ -818,6 +818,8 @@ class BlockingOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
|
||||
REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc("");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_SYCL), BlockingOp);
|
||||
|
||||
static void TestSessionInterOpThreadsImpl(bool use_function_lib) {
|
||||
FunctionDefLibrary library_graph_def;
|
||||
if (use_function_lib) {
|
||||
|
88
tensorflow/core/common_runtime/sycl/sycl_device.cc
Normal file
88
tensorflow/core/common_runtime/sycl/sycl_device.cc
Normal file
@ -0,0 +1,88 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
|
||||
#include "tensorflow/core/common_runtime/sycl/sycl_device.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/framework/tensor.pb_text.h"
|
||||
#include "tensorflow/core/platform/tracing.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// File-level SYCL queue built from a GPU selector. SYCLDevice's constructor
// passes `q` to its Eigen::SyclDevice member.
// NOTE(review): `s` and `q` are declared at namespace scope without `static`
// or an anonymous namespace, and are constructed during static
// initialization — confirm that external linkage and load-time construction
// are intended.
cl::sycl::gpu_selector s;
|
||||
cl::sycl::queue q(s);
|
||||
|
||||
// Constructs a SYCL device.
//
// Builds the device attributes with device type DEVICE_SYCL, stores
// `allocator` (also handed to LocalDevice), creates a new SYCLDeviceContext
// (released with Unref() in the destructor), constructs the Eigen device from
// the file-level queue `q`, and registers that Eigen device through
// set_eigen_sycl_device().
SYCLDevice::SYCLDevice(const SessionOptions& options, const string& name,
|
||||
Bytes memory_limit, const DeviceLocality& locality,
|
||||
const string& physical_device_desc, Allocator* allocator)
|
||||
: LocalDevice(options,
|
||||
Device::BuildDeviceAttributes(name, DEVICE_SYCL, memory_limit,
|
||||
locality, physical_device_desc),
|
||||
allocator),
|
||||
allocator_(allocator),
|
||||
device_context_(new SYCLDeviceContext()),
|
||||
device_(q) {
|
||||
// Registers a pointer to the member device_ as this device's Eigen SYCL
// device.
set_eigen_sycl_device(&device_);
|
||||
}
|
||||
|
||||
// Drops this device's reference to the SYCLDeviceContext created in the
// constructor.
SYCLDevice::~SYCLDevice() {
|
||||
device_context_->Unref();
|
||||
}
|
||||
|
||||
// Runs `op_kernel` synchronously with the given `context`.
//
// When tracing is active, the kernel invocation happens inside the lifetime
// of a ScopedActivity so that the recorded region covers the computation.
// The region object is scoped to the `if` block, so Compute() must be called
// inside that block; calling it after the block would close the region
// before any work is done and produce an empty trace activity.
void SYCLDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  assert(context);
  if (port::Tracing::IsActive()) {
    // TODO(pbar) We really need a useful identifier of the graph node.
    const uint64 id = Hash64(op_kernel->name());
    port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
                                         id);
    op_kernel->Compute(context);
  } else {
    op_kernel->Compute(context);
  }
}
|
||||
|
||||
// Returns the allocator supplied at construction time; `attr` is not
// consulted.
Allocator* SYCLDevice::GetAllocator(AllocatorAttributes attr) {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
// Materializes `tensor_proto` into `*tensor`.
//
// The proto is parsed with the CPU allocator; `alloc_attrs` is not consulted.
// On success `*tensor` receives the parsed tensor and OK is returned. If the
// proto cannot be parsed, InvalidArgument is returned and `*tensor` is left
// unmodified.
Status SYCLDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                       const AllocatorAttributes alloc_attrs,
                                       Tensor* tensor) {
  Tensor result(tensor_proto.dtype());
  const bool parsed_ok = result.FromProto(cpu_allocator(), tensor_proto);
  if (parsed_ok) {
    *tensor = std::move(result);
    return Status::OK();
  }
  return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                 ProtoDebugString(tensor_proto));
}
|
||||
|
||||
// Points every node of `graph` at this device's single SYCLDeviceContext.
//
// The map is resized to graph->num_node_ids() and indexed by node id. The
// same context pointer is stored for every node; one reference is taken per
// stored entry, so duplicates in the map are fine.
Status SYCLDevice::FillContextMap(const Graph* graph,
                                  DeviceContextMap* device_context_map) {
  DeviceContextMap& contexts = *device_context_map;
  contexts.resize(graph->num_node_ids());
  for (Node* node : graph->nodes()) {
    // One Ref() per map slot that holds the shared context.
    device_context_->Ref();
    contexts[node->id()] = device_context_;
  }
  return Status::OK();
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
62
tensorflow/core/common_runtime/sycl/sycl_device.h
Normal file
62
tensorflow/core/common_runtime/sycl/sycl_device.h
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if !TENSORFLOW_USE_SYCL
|
||||
#error This file must only be included when building TensorFlow with SYCL support
|
||||
#endif
|
||||
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
|
||||
|
||||
#define EIGEN_USE_SYCL
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/local_device.h"
|
||||
#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class SYCLDevice : public LocalDevice {
|
||||
public:
|
||||
SYCLDevice(const SessionOptions& options, const string& name,
|
||||
Bytes memory_limit, const DeviceLocality& locality,
|
||||
const string& physical_device_desc, Allocator* allocator);
|
||||
~SYCLDevice() override;
|
||||
|
||||
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
|
||||
Allocator* GetAllocator(AllocatorAttributes attr) override;
|
||||
Status MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
const AllocatorAttributes alloc_attrs,
|
||||
Tensor* tensor) override;
|
||||
|
||||
Status FillContextMap(const Graph* graph,
|
||||
DeviceContextMap* device_context_map) override;
|
||||
|
||||
Status Sync() override { return Status::OK(); }
|
||||
static string GetShortDeviceDescription(/*int device_id,
|
||||
const DeviceDescription& desc*/) {
|
||||
return strings::StrCat("device: 0, name SYCL, pci bus id: 0");
|
||||
}
|
||||
|
||||
private:
|
||||
Allocator* allocator_; // Not owned
|
||||
SYCLDeviceContext* device_context_;
|
||||
Eigen::SyclDevice device_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_DEVICE_H_
|
46
tensorflow/core/common_runtime/sycl/sycl_device_context.cc
Normal file
46
tensorflow/core/common_runtime/sycl/sycl_device_context.cc
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void SYCLDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor,
|
||||
StatusCallback done) const {
|
||||
const int64 total_bytes = cpu_tensor->TotalBytes();
|
||||
if (total_bytes > 0) {
|
||||
const void* src_ptr = DMAHelper::base(cpu_tensor);
|
||||
void* dst_ptr = DMAHelper::base(device_tensor);
|
||||
::memcpy(dst_ptr, src_ptr, total_bytes);
|
||||
}
|
||||
done(Status::OK());
|
||||
}
|
||||
|
||||
void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name,
|
||||
Device* device, Tensor* cpu_tensor,
|
||||
StatusCallback done) {
|
||||
const int64 total_bytes = device_tensor->TotalBytes();
|
||||
if (total_bytes > 0) {
|
||||
const void* src_ptr = DMAHelper::base(device_tensor);
|
||||
void* dst_ptr = DMAHelper::base(cpu_tensor);
|
||||
::memcpy(dst_ptr, src_ptr, total_bytes);
|
||||
}
|
||||
done(Status::OK());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
42
tensorflow/core/common_runtime/sycl/sycl_device_context.h
Normal file
42
tensorflow/core/common_runtime/sycl/sycl_device_context.h
Normal file
@ -0,0 +1,42 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_
|
||||
#define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class SYCLDeviceContext : public DeviceContext {
|
||||
public:
|
||||
SYCLDeviceContext() {}
|
||||
|
||||
~SYCLDeviceContext() override {}
|
||||
|
||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor,
|
||||
StatusCallback done) const override;
|
||||
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name,
|
||||
Device* device, Tensor* cpu_tensor,
|
||||
StatusCallback done) override;
|
||||
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_
|
44
tensorflow/core/common_runtime/sycl/sycl_device_factory.cc
Normal file
44
tensorflow/core/common_runtime/sycl/sycl_device_factory.cc
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
|
||||
#include "tensorflow/core/common_runtime/sycl/sycl_device.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class SYCLDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<Device*>* devices) override {
|
||||
int n = 1;
|
||||
auto iter = options.config.device_count().find("SYCL");
|
||||
if (iter != options.config.device_count().end()) {
|
||||
n = iter->second;
|
||||
}
|
||||
for (int i = 0; i < n; i++) {
|
||||
string name = strings::StrCat(name_prefix, "/device:SYCL:", i);
|
||||
devices->push_back(new SYCLDevice(
|
||||
options, name, Bytes(256 << 20), DeviceLocality(),
|
||||
SYCLDevice::GetShortDeviceDescription(), cpu_allocator()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory);
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
@ -30,6 +30,9 @@ limitations under the License.
|
||||
|
||||
namespace Eigen {
|
||||
struct ThreadPoolDevice;
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
struct SyclDevice;
|
||||
#endif
|
||||
} // end namespace Eigen
|
||||
|
||||
namespace perftools {
|
||||
@ -145,6 +148,10 @@ class DeviceBase {
|
||||
eigen_cpu_device_ = d;
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
// Stores `d` as this device's Eigen SYCL device pointer.
// NOTE(review): the pointer is only stored, not copied; presumably the caller
// keeps ownership and must keep `d` alive — confirm.
void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; }
|
||||
#endif
|
||||
|
||||
// Return the Allocator implementation to use based on the allocator
|
||||
// attributes requested. See allocator.h for more details.
|
||||
virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
|
||||
@ -167,6 +174,13 @@ class DeviceBase {
|
||||
return eigen_cpu_device_;
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
// Returns the Eigen SYCL device pointer previously stored with
// set_eigen_sycl_device(). CHECK-fails if none has been set.
const Eigen::SyclDevice* eigen_sycl_device() const {
|
||||
CHECK(eigen_sycl_device_ != nullptr);
|
||||
return eigen_sycl_device_;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Caller owns the return value. The OpKernelContext calls this even
|
||||
// for devices that do not implement an eigen_gpu_device. Overridden
|
||||
// by GPU devices to return a derived type.
|
||||
@ -203,6 +217,9 @@ class DeviceBase {
|
||||
CpuWorkerThreads* cpu_worker_threads_ = nullptr;
|
||||
GpuDeviceInfo* gpu_device_info_ = nullptr;
|
||||
Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr;
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
Eigen::SyclDevice* eigen_sycl_device_ = nullptr;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -949,6 +949,13 @@ const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
|
||||
return eigen_gpu_device();
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
// Specialization of eigen_device<>() for Eigen::SyclDevice: forwards to
// eigen_sycl_device().
template <>
|
||||
const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
|
||||
return eigen_sycl_device();
|
||||
}
|
||||
#endif
|
||||
|
||||
void OpKernelConstruction::CtxFailure(Status s) {
|
||||
VLOG(1) << s;
|
||||
SetStatus(s);
|
||||
|
@ -53,6 +53,7 @@ limitations under the License.
|
||||
namespace Eigen {
|
||||
struct ThreadPoolDevice;
|
||||
struct GpuDevice;
|
||||
struct SyclDevice;
|
||||
} // end namespace Eigen
|
||||
|
||||
namespace tensorflow {
|
||||
@ -891,6 +892,11 @@ class OpKernelContext {
|
||||
const Eigen::GpuDevice& eigen_gpu_device() const {
|
||||
return params_->eigen_gpu_device->device();
|
||||
}
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
// Returns a reference to the Eigen SYCL device registered on device(), by
// dereferencing the pointer returned from device()->eigen_sycl_device().
const Eigen::SyclDevice& eigen_sycl_device() const {
|
||||
return *device()->eigen_sycl_device();
|
||||
}
|
||||
#endif
|
||||
template <typename EigenDeviceType>
|
||||
const EigenDeviceType& eigen_device() const;
|
||||
|
||||
|
@ -721,14 +721,58 @@ bool Tensor::CanUseDMA() const {
|
||||
#undef CASE
|
||||
|
||||
namespace {
|
||||
// Print from left dim to right dim recursively.
|
||||
template <typename T>
|
||||
string SummarizeArray(int64 limit, int64 num_elts, const char* data) {
|
||||
void PrintOneDim(int dim_index, gtl::InlinedVector<int64, 4> shape, int64 limit,
|
||||
int shape_size, T* data, int64* data_index, string* result) {
|
||||
if (*data_index >= limit) return;
|
||||
int64 element_count = shape[dim_index];
|
||||
// We have reached the right-most dimension of the tensor.
|
||||
if (dim_index == shape_size - 1) {
|
||||
for (int64 i = 0; i < element_count; i++) {
|
||||
if (*data_index >= limit) return;
|
||||
if (i > 0) strings::StrAppend(result, " ");
|
||||
strings::StrAppend(result, data[(*data_index)++]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Loop every element of one dim.
|
||||
for (int64 i = 0; i < element_count; i++) {
|
||||
bool flag = false;
|
||||
if (*data_index < limit) {
|
||||
strings::StrAppend(result, "[");
|
||||
flag = true;
|
||||
}
|
||||
// As for each element, print the sub-dim.
|
||||
PrintOneDim(dim_index + 1, shape, limit, shape_size,
|
||||
data, data_index, result);
|
||||
if (*data_index < limit || flag) {
|
||||
strings::StrAppend(result, "]");
|
||||
flag = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
string SummarizeArray(int64 limit, int64 num_elts, const TensorShape& tensor_shape,
|
||||
const char* data) {
|
||||
string ret;
|
||||
const T* array = reinterpret_cast<const T*>(data);
|
||||
for (int64 i = 0; i < limit; ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, " ");
|
||||
strings::StrAppend(&ret, array[i]);
|
||||
|
||||
const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
|
||||
if(shape.empty()) {
|
||||
for (int64 i = 0; i < limit; ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, " ");
|
||||
strings::StrAppend(&ret, array[i]);
|
||||
}
|
||||
if (num_elts > limit) strings::StrAppend(&ret, "...");
|
||||
return ret;
|
||||
}
|
||||
int64 data_index = 0;
|
||||
const int shape_size = tensor_shape.dims();
|
||||
PrintOneDim(0, shape, limit, shape_size,
|
||||
array, &data_index, &ret);
|
||||
|
||||
if (num_elts > limit) strings::StrAppend(&ret, "...");
|
||||
return ret;
|
||||
}
|
||||
@ -744,40 +788,40 @@ string Tensor::SummarizeValue(int64 max_entries) const {
|
||||
const char* data = limit > 0 ? tensor_data().data() : nullptr;
|
||||
switch (dtype()) {
|
||||
case DT_HALF:
|
||||
return SummarizeArray<Eigen::half>(limit, num_elts, data);
|
||||
return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
case DT_FLOAT:
|
||||
return SummarizeArray<float>(limit, num_elts, data);
|
||||
return SummarizeArray<float>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
case DT_DOUBLE:
|
||||
return SummarizeArray<double>(limit, num_elts, data);
|
||||
return SummarizeArray<double>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
case DT_INT32:
|
||||
return SummarizeArray<int32>(limit, num_elts, data);
|
||||
return SummarizeArray<int32>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
case DT_UINT8:
|
||||
case DT_QUINT8:
|
||||
return SummarizeArray<uint8>(limit, num_elts, data);
|
||||
return SummarizeArray<uint8>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
case DT_UINT16:
|
||||
case DT_QUINT16:
|
||||
return SummarizeArray<uint16>(limit, num_elts, data);
|
||||
return SummarizeArray<uint16>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
case DT_INT16:
|
||||
case DT_QINT16:
|
||||
return SummarizeArray<int16>(limit, num_elts, data);
|
||||
return SummarizeArray<int16>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
case DT_INT8:
|
||||
case DT_QINT8:
|
||||
return SummarizeArray<int8>(limit, num_elts, data);
|
||||
return SummarizeArray<int8>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
case DT_INT64:
|
||||
return SummarizeArray<int64>(limit, num_elts, data);
|
||||
return SummarizeArray<int64>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
case DT_BOOL:
|
||||
// TODO(tucker): Is it better to emit "True False..."? This
|
||||
// will emit "1 0..." which is more compact.
|
||||
return SummarizeArray<bool>(limit, num_elts, data);
|
||||
return SummarizeArray<bool>(limit, num_elts, shape_, data);
|
||||
break;
|
||||
default: {
|
||||
// All irregular cases
|
||||
|
@ -834,20 +834,24 @@ TEST(SummarizeValue, INT32) {
|
||||
Tensor x = MkTensor<int>(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0});
|
||||
EXPECT_EQ("1 2 3 4 0", x.SummarizeValue(16));
|
||||
x = MkTensor<int>(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0});
|
||||
EXPECT_EQ("1 2 3 4", x.SummarizeValue(16));
|
||||
EXPECT_EQ("[1 2][3 4]", x.SummarizeValue(16));
|
||||
x = MkTensor<int>(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
|
||||
EXPECT_EQ("1 2 3 4", x.SummarizeValue(16));
|
||||
EXPECT_EQ("1 2 3...", x.SummarizeValue(3));
|
||||
EXPECT_EQ("[[[1]][[2]]][[[3]][[4]]]", x.SummarizeValue(16));
|
||||
EXPECT_EQ("[[[1]][[2]]][[[3]]]...", x.SummarizeValue(3));
|
||||
x = MkTensor<int>(DT_INT32, TensorShape({0}), {});
|
||||
EXPECT_EQ("", x.SummarizeValue(16));
|
||||
}
|
||||
|
||||
TEST(SummarizeValue, FLOAT) {
|
||||
Tensor x = MkTensor<float>(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0});
|
||||
EXPECT_EQ("1 2 3 4 0", x.SummarizeValue(16));
|
||||
x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0});
|
||||
EXPECT_EQ("1 2 3 4", x.SummarizeValue(16));
|
||||
EXPECT_EQ("[1 2][3 4]", x.SummarizeValue(16));
|
||||
x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
|
||||
EXPECT_EQ("1 2 3 4", x.SummarizeValue(16));
|
||||
EXPECT_EQ("1 2 3...", x.SummarizeValue(3));
|
||||
EXPECT_EQ("[[[1]][[2]]][[[3]][[4]]]", x.SummarizeValue(16));
|
||||
EXPECT_EQ("[[[1]][[2]]][[[3]]]...", x.SummarizeValue(3));
|
||||
x = MkTensor<float>(DT_FLOAT, TensorShape({0}), {});
|
||||
EXPECT_EQ("", x.SummarizeValue(16));
|
||||
}
|
||||
|
||||
TEST(SummarizeValue, BOOL) {
|
||||
|
@ -37,6 +37,7 @@ std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
|
||||
|
||||
const char* const DEVICE_CPU = "CPU";
|
||||
const char* const DEVICE_GPU = "GPU";
|
||||
const char* const DEVICE_SYCL = "SYCL";
|
||||
|
||||
string DataTypeString(DataType dtype) {
|
||||
if (IsRefType(dtype)) {
|
||||
|
@ -68,8 +68,9 @@ class DeviceType {
|
||||
std::ostream& operator<<(std::ostream& os, const DeviceType& d);
|
||||
|
||||
// Convenient constants that can be passed to a DeviceType constructor
|
||||
extern const char* const DEVICE_CPU; // "CPU"
|
||||
extern const char* const DEVICE_GPU; // "GPU"
|
||||
extern const char* const DEVICE_CPU; // "CPU"
|
||||
extern const char* const DEVICE_GPU; // "GPU"
|
||||
extern const char* const DEVICE_SYCL; // "SYCL"
|
||||
|
||||
typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
|
||||
typedef gtl::ArraySlice<MemoryType> MemoryTypeSlice;
|
||||
|
@ -25,6 +25,7 @@ namespace {
|
||||
TEST(TypesTest, DeviceTypeName) {
|
||||
EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU)));
|
||||
EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU)));
|
||||
EXPECT_EQ("SYCL", DeviceTypeString(DeviceType(DEVICE_SYCL)));
|
||||
}
|
||||
|
||||
TEST(TypesTest, kDataTypeRefOffset) {
|
||||
|
@ -2164,14 +2164,17 @@ tf_kernel_libraries(
|
||||
"reduce_join_op",
|
||||
"string_join_op",
|
||||
"string_split_op",
|
||||
"substr_op",
|
||||
"as_string_op",
|
||||
"base64_ops",
|
||||
],
|
||||
deps = [
|
||||
":bounds_check",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:string_ops_op_lib",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -51,6 +51,17 @@ ConstantOp::~ConstantOp() {}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp);
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_SYCL_KERNEL(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Const") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
ConstantOp);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
|
||||
#undef REGISTER_SYCL_KERNEL
|
||||
#endif
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_KERNEL(D, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
|
@ -18,6 +18,14 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
REGISTER5(UnaryOp, CPU, "Round", functor::round, Eigen::half, float, double,
|
||||
int32, int64);
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER(UnaryOp, SYCL, "Round", functor::round, float);
|
||||
namespace functor {
|
||||
DEFINE_UNARY1(round, float);
|
||||
} // namespace functor
|
||||
#endif
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER5(UnaryOp, GPU, "Round", functor::round, Eigen::half, float, double,
|
||||
int32, int64);
|
||||
|
@ -20,6 +20,10 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
|
||||
#endif
|
||||
|
||||
#include "tensorflow/core/kernels/cwise_ops.h"
|
||||
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
|
||||
|
||||
@ -33,6 +37,9 @@ namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
#endif
|
||||
|
||||
class BinaryOpShared : public OpKernel {
|
||||
public:
|
||||
@ -96,45 +103,45 @@ class BinaryOp : public BinaryOpShared {
|
||||
if (state.in1_num_elements == 1) {
|
||||
// tensor op scalar
|
||||
functor::BinaryFunctor<Device, Functor, 1>().Right(
|
||||
eigen_device, out_flat, in0.flat<Tin>(), in1.scalar<Tin>(),
|
||||
error_ptr);
|
||||
eigen_device, out_flat, in0.template flat<Tin>(),
|
||||
in1.template scalar<Tin>(), error_ptr);
|
||||
} else if (state.in0_num_elements == 1) {
|
||||
// scalar op tensor
|
||||
functor::BinaryFunctor<Device, Functor, 1>().Left(
|
||||
eigen_device, out_flat, in0.scalar<Tin>(), in1.flat<Tin>(),
|
||||
error_ptr);
|
||||
eigen_device, out_flat, in0.template scalar<Tin>(),
|
||||
in1.template flat<Tin>(), error_ptr);
|
||||
} else {
|
||||
functor::BinaryFunctor<Device, Functor, 1>()(
|
||||
eigen_device, out_flat, in0.flat<Tin>(), in1.flat<Tin>(),
|
||||
error_ptr);
|
||||
eigen_device, out_flat, in0.template flat<Tin>(),
|
||||
in1.template flat<Tin>(), error_ptr);
|
||||
}
|
||||
} else if (ndims == 2) {
|
||||
functor::BinaryFunctor<Device, Functor, 2>().BCast(
|
||||
eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
|
||||
in0.shaped<Tin, 2>(bcast->x_reshape()),
|
||||
in0.template shaped<Tin, 2>(bcast->x_reshape()),
|
||||
BCast::ToIndexArray<2>(bcast->x_bcast()),
|
||||
in1.shaped<Tin, 2>(bcast->y_reshape()),
|
||||
in1.template shaped<Tin, 2>(bcast->y_reshape()),
|
||||
BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
|
||||
} else if (ndims == 3) {
|
||||
functor::BinaryFunctor<Device, Functor, 3>().BCast(
|
||||
eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
|
||||
in0.shaped<Tin, 3>(bcast->x_reshape()),
|
||||
in0.template shaped<Tin, 3>(bcast->x_reshape()),
|
||||
BCast::ToIndexArray<3>(bcast->x_bcast()),
|
||||
in1.shaped<Tin, 3>(bcast->y_reshape()),
|
||||
in1.template shaped<Tin, 3>(bcast->y_reshape()),
|
||||
BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
|
||||
} else if (ndims == 4) {
|
||||
functor::BinaryFunctor<Device, Functor, 4>().BCast(
|
||||
eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
|
||||
in0.shaped<Tin, 4>(bcast->x_reshape()),
|
||||
in0.template shaped<Tin, 4>(bcast->x_reshape()),
|
||||
BCast::ToIndexArray<4>(bcast->x_bcast()),
|
||||
in1.shaped<Tin, 4>(bcast->y_reshape()),
|
||||
in1.template shaped<Tin, 4>(bcast->y_reshape()),
|
||||
BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
|
||||
} else if (ndims == 5) {
|
||||
functor::BinaryFunctor<Device, Functor, 5>().BCast(
|
||||
eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
|
||||
in0.shaped<Tin, 5>(bcast->x_reshape()),
|
||||
in0.template shaped<Tin, 5>(bcast->x_reshape()),
|
||||
BCast::ToIndexArray<5>(bcast->x_bcast()),
|
||||
in1.shaped<Tin, 5>(bcast->y_reshape()),
|
||||
in1.template shaped<Tin, 5>(bcast->y_reshape()),
|
||||
BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
|
||||
} else {
|
||||
SetUnimplementedError(ctx);
|
||||
|
138
tensorflow/core/kernels/cwise_ops_sycl_common.h
Normal file
138
tensorflow/core/kernels/cwise_ops_sycl_common.h
Normal file
@ -0,0 +1,138 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if !TENSORFLOW_USE_SYCL
|
||||
#error This file must only be included when building TensorFlow with SYCL support
|
||||
#endif
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_
|
||||
|
||||
#define EIGEN_USE_SYCL
|
||||
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/cwise_ops.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
|
||||
template <typename OUT, typename RHS>
|
||||
void Assign(const SYCLDevice& d, OUT out, RHS rhs) {
|
||||
out.device(d) = rhs;
|
||||
}
|
||||
|
||||
// Partial specialization of UnaryFunctor<Device=SYCLDevice, Functor>.
|
||||
template <typename Functor>
|
||||
struct UnaryFunctor<SYCLDevice, Functor> {
|
||||
void operator()(const SYCLDevice& d, typename Functor::tout_type out,
|
||||
typename Functor::tin_type in) {
|
||||
To32Bit(out).device(d) = To32Bit(in).unaryExpr(typename Functor::func());
|
||||
}
|
||||
};
|
||||
|
||||
// Partial specialization of BinaryFunctor<Device=SYCLDevice, Functor>.
|
||||
template <typename Functor, int NDIMS, bool has_errors>
|
||||
struct BinaryFunctor<SYCLDevice, Functor, NDIMS, has_errors> {
|
||||
void operator()(const SYCLDevice& d, typename Functor::tout_type out,
|
||||
typename Functor::tin_type in0,
|
||||
typename Functor::tin_type in1, bool* error) {
|
||||
Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
|
||||
}
|
||||
|
||||
void Left(const SYCLDevice& d, typename Functor::tout_type out,
|
||||
typename Functor::tscalar_type scalar,
|
||||
typename Functor::tin_type in, bool* error) {
|
||||
LOG(FATAL) << "BinaryFunctor::Left NOT IMPLEMENTED ! ";
|
||||
}
|
||||
|
||||
void Right(const SYCLDevice& d, typename Functor::tout_type out,
|
||||
typename Functor::tin_type in,
|
||||
typename Functor::tscalar_type scalar, bool* error) {
|
||||
typedef typename Functor::out_type Tout;
|
||||
typedef typename Functor::in_type Tin;
|
||||
typedef typename Functor::func Binary;
|
||||
typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
|
||||
Assign(d, out, in.unaryExpr(Unary(scalar.data())));
|
||||
}
|
||||
|
||||
void BCast(const SYCLDevice& d,
|
||||
typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
|
||||
typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
|
||||
typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
|
||||
bool* error) {
|
||||
LOG(FATAL) << "BinaryFunctor::BCast NOT IMPLEMENTED ";
|
||||
}
|
||||
};
|
||||
|
||||
// Macros to explicitly instantiate kernels on SYCL devices for multiple types
|
||||
// (T0, T1, etc.) for UnaryFunctor (e.g., functor::sqrt).
|
||||
#define DEFINE_UNARY1(F, T) template struct UnaryFunctor<SYCLDevice, F<T> >
|
||||
#define DEFINE_UNARY2(F, T0, T1) \
|
||||
DEFINE_UNARY1(F, T0); \
|
||||
DEFINE_UNARY1(F, T1)
|
||||
#define DEFINE_UNARY3(F, T0, T1, T2) \
|
||||
DEFINE_UNARY2(F, T0, T1); \
|
||||
DEFINE_UNARY1(F, T2)
|
||||
#define DEFINE_UNARY4(F, T0, T1, T2, T3) \
|
||||
DEFINE_UNARY2(F, T0, T1); \
|
||||
DEFINE_UNARY2(F, T2, T3)
|
||||
#define DEFINE_UNARY5(F, T0, T1, T2, T3, T4) \
|
||||
DEFINE_UNARY2(F, T0, T1); \
|
||||
DEFINE_UNARY3(F, T2, T3, T4)
|
||||
|
||||
// Macros to explicitly instantiate kernels on SYCL devices for multiple types
|
||||
// (T0, T1, etc.) for BinaryFunctor.
|
||||
#define DEFINE_BINARY1(F, T) \
|
||||
template struct BinaryFunctor<SYCLDevice, F<T>, 1>; \
|
||||
template struct BinaryFunctor<SYCLDevice, F<T>, 2>; \
|
||||
template struct BinaryFunctor<SYCLDevice, F<T>, 3>
|
||||
#define DEFINE_BINARY2(F, T0, T1) \
|
||||
DEFINE_BINARY1(F, T0); \
|
||||
DEFINE_BINARY1(F, T1)
|
||||
#define DEFINE_BINARY3(F, T0, T1, T2) \
|
||||
DEFINE_BINARY2(F, T0, T1); \
|
||||
DEFINE_BINARY1(F, T2)
|
||||
#define DEFINE_BINARY4(F, T0, T1, T2, T3) \
|
||||
DEFINE_BINARY2(F, T0, T1); \
|
||||
DEFINE_BINARY2(F, T2, T3)
|
||||
#define DEFINE_BINARY5(F, T0, T1, T2, T3, T4) \
|
||||
DEFINE_BINARY2(F, T0, T1); \
|
||||
DEFINE_BINARY3(F, T2, T3, T4)
|
||||
#define DEFINE_BINARY6(F, T0, T1, T2, T3, T4, T5) \
|
||||
DEFINE_BINARY3(F, T0, T1, T2); \
|
||||
DEFINE_BINARY3(F, T3, T4, T5)
|
||||
#define DEFINE_BINARY7(F, T0, T1, T2, T3, T4, T5, T6) \
|
||||
DEFINE_BINARY3(F, T0, T1, T2); \
|
||||
DEFINE_BINARY4(F, T3, T4, T5, T6)
|
||||
#define DEFINE_BINARY8(F, T0, T1, T2, T3, T4, T5, T6, T7) \
|
||||
DEFINE_BINARY4(F, T0, T1, T2, T3); \
|
||||
DEFINE_BINARY4(F, T4, T5, T6, T7)
|
||||
#define DEFINE_BINARY9(F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
|
||||
DEFINE_BINARY4(F, T0, T1, T2, T3); \
|
||||
DEFINE_BINARY5(F, T4, T5, T6, T7, T8)
|
||||
#define DEFINE_BINARY10(F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) \
|
||||
DEFINE_BINARY5(F, T0, T1, T2, T3, T4); \
|
||||
DEFINE_BINARY5(F, T5, T6, T7, T8, T9)
|
||||
|
||||
} // end namespace functor
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_SYCL_COMMON_H_
|
@ -87,6 +87,29 @@ class RetvalOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
// SYCL registrations for the function argument / return-value kernels.
//
// The REGISTER macro covers bool plus the types listed by
// TF_CALL_NUMBER_TYPES_NO_INT32. int32 is registered separately with its
// tensor pinned to host memory. Those int32 kernels must target DEVICE_SYCL:
// this section is only compiled for SYCL builds, and registering them for
// DEVICE_GPU here would leave the SYCL device without an int32 kernel.
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER)
REGISTER_KERNEL_BUILDER(Name("_Arg")
                            .Device(DEVICE_SYCL)
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ArgOp);
#undef REGISTER

#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER)
REGISTER_KERNEL_BUILDER(Name("_Retval")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .TypeConstraint<int32>("T"),
                        RetvalOp);
#undef REGISTER
#endif
|
||||
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_Arg").Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
|
||||
|
@ -34,6 +34,24 @@ REGISTER_KERNEL_BUILDER(Name("PlaceholderWithDefault").Device(DEVICE_CPU),
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RefIdentity").Device(DEVICE_CPU), IdentityOp);
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_SYCL_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Identity").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||
IdentityOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("RefIdentity").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||
IdentityOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("StopGradient").Device(DEVICE_SYCL).TypeConstraint<type>("T"),\
|
||||
IdentityOp)
|
||||
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
|
||||
REGISTER_SYCL_KERNEL(bfloat16);
|
||||
|
||||
#undef REGISTER_SYCL_KERNEL
|
||||
#endif
|
||||
|
||||
#define REGISTER_GPU_KERNEL(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Identity").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
@ -50,6 +68,7 @@ REGISTER_GPU_KERNEL(bfloat16);
|
||||
|
||||
#undef REGISTER_GPU_KERNEL
|
||||
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// A special GPU kernel for int32 and bool.
|
||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
||||
|
@ -20,4 +20,8 @@ namespace tensorflow {
|
||||
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_CPU), NoOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_GPU), NoOp);
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_SYCL), NoOp);
|
||||
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -29,7 +29,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
|
||||
#ifdef COMPILER_MSVC
|
||||
// msvc does not support unroll. One could try the loop pragma but we need to
|
||||
// take a closer look if this generates better code in this case. For now let
|
||||
// the compiler take care of it.
|
||||
#define UNROLL
|
||||
#else
|
||||
#define UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -99,6 +106,7 @@ __global__ void __launch_bounds__(1024)
|
||||
Eigen::numext::exp(T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) /
|
||||
(normMin + sqrtFactor);
|
||||
const T diff = normMax - normMin;
|
||||
const T two = T(2.0);
|
||||
|
||||
// Validate the normalized min and max, because the originals may have been
|
||||
// flipped already.
|
||||
@ -124,7 +132,7 @@ __global__ void __launch_bounds__(1024)
|
||||
z[i] = rand[i] * diff + normMin;
|
||||
}
|
||||
UNROLL for (int i = 0; i < kDistSize; i++) {
|
||||
g[i] = (plusFactor - z[i] * z[i]) / 2.0;
|
||||
g[i] = (plusFactor - z[i] * z[i]) / two;
|
||||
}
|
||||
|
||||
const auto u = dist(&gen);
|
||||
@ -161,7 +169,7 @@ __global__ void __launch_bounds__(1024)
|
||||
UNROLL for (int i = 0; i < kDistSize; i += 2) {
|
||||
const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
|
||||
const T x = normMin < alpha ? alpha - z : normMin - alpha;
|
||||
const T g = Eigen::numext::exp(-x * x / 2.0);
|
||||
const T g = Eigen::numext::exp(-x * x / two);
|
||||
const T u = rand[i + 1];
|
||||
if ((u <= g && z < normMax) || numIterations + 1 >= kMaxIterations) {
|
||||
data[offset] = z * stddev + mean;
|
||||
|
@ -78,6 +78,10 @@ void SendOp::Compute(OpKernelContext* ctx) {
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp);
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_SYCL), SendOp);
|
||||
#endif
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp);
|
||||
@ -136,6 +140,10 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp);
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_SYCL), RecvOp);
|
||||
#endif
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp);
|
||||
|
233
tensorflow/core/kernels/substr_op.cc
Normal file
233
tensorflow/core/kernels/substr_op.cc
Normal file
@ -0,0 +1,233 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/bcast.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Position/length can be 32 or 64-bit integers
|
||||
template <typename T>
|
||||
class SubstrOp : public OpKernel {
|
||||
public:
|
||||
using OpKernel::OpKernel;
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Get inputs
|
||||
const Tensor& input_tensor = context->input(0);
|
||||
const Tensor& pos_tensor = context->input(1);
|
||||
const Tensor& len_tensor = context->input(2);
|
||||
const TensorShape input_shape = input_tensor.shape();
|
||||
const TensorShape pos_shape = pos_tensor.shape();
|
||||
const TensorShape len_shape = len_tensor.shape();
|
||||
|
||||
bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
|
||||
|
||||
if (is_scalar || input_shape == pos_shape) {
|
||||
// pos/len are either scalar or match the shape of input_tensor
|
||||
// Do not need to do broadcasting
|
||||
|
||||
// Reshape input
|
||||
auto input = input_tensor.flat<string>();
|
||||
// Allocate output
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("output", input_tensor.shape(),
|
||||
&output_tensor));
|
||||
auto output = output_tensor->flat<string>();
|
||||
if (is_scalar) {
|
||||
// Perform Op with scalar pos/len
|
||||
const T pos = tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()());
|
||||
const T len = tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
|
||||
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
|
||||
string in = input(i);
|
||||
OP_REQUIRES(context, FastBoundsCheck(pos, in.size()),
|
||||
errors::InvalidArgument("pos ", pos, " out of range for string",
|
||||
"b'", in, "' at index ", i));
|
||||
output(i) = in.substr(pos, len);
|
||||
}
|
||||
} else {
|
||||
// Perform Op element-wise with tensor pos/len
|
||||
auto pos_flat = pos_tensor.flat<T>();
|
||||
auto len_flat = len_tensor.flat<T>();
|
||||
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
|
||||
string in = input(i);
|
||||
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
|
||||
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
|
||||
OP_REQUIRES(context, FastBoundsCheck(pos, in.size()),
|
||||
errors::InvalidArgument("pos ", pos, " out of range for string",
|
||||
"b'", in, "' at index ", i));
|
||||
output(i) = in.substr(pos, len);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Perform op with broadcasting
|
||||
// TODO: Use ternary broadcasting for once available in Eigen. Current
|
||||
// implementation iterates through broadcasted ops element-wise;
|
||||
// this should be parallelized.
|
||||
|
||||
// Create BCast helper with shape of input and pos/len
|
||||
BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape));
|
||||
OP_REQUIRES(context, bcast.IsValid(),
|
||||
errors::InvalidArgument("Incompatible shapes: ",
|
||||
input_shape.DebugString(), " vs. ",
|
||||
pos_shape.DebugString()));
|
||||
TensorShape output_shape = BCast::ToShape(bcast.result_shape());
|
||||
int ndims = output_shape.dims();
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("output", output_shape,
|
||||
&output_tensor));
|
||||
switch (ndims) {
|
||||
case 1: {
|
||||
// Reshape tensors according to BCast results
|
||||
auto input = input_tensor.shaped<string,1>(bcast.x_reshape());
|
||||
auto output = output_tensor->shaped<string,1>(bcast.result_shape());
|
||||
auto pos_shaped = pos_tensor.shaped<T,1>(bcast.y_reshape());
|
||||
auto len_shaped = len_tensor.shaped<T,1>(bcast.y_reshape());
|
||||
|
||||
// Allocate temporary buffer for broadcasted input tensor
|
||||
Tensor input_buffer;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(DT_STRING,
|
||||
output_shape,
|
||||
&input_buffer));
|
||||
typename TTypes<string,1>::Tensor input_bcast =
|
||||
input_buffer.shaped<string,1>(bcast.result_shape());
|
||||
input_bcast = input.broadcast(
|
||||
BCast::ToIndexArray<1>(bcast.x_bcast()));
|
||||
|
||||
// Allocate temporary buffer for broadcasted position tensor
|
||||
Tensor pos_buffer;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(DataTypeToEnum<T>::v(),
|
||||
output_shape,
|
||||
&pos_buffer));
|
||||
typename TTypes<T,1>::Tensor pos_bcast = pos_buffer.shaped<T,1>(
|
||||
bcast.result_shape());
|
||||
pos_bcast = pos_shaped.broadcast(
|
||||
BCast::ToIndexArray<1>(bcast.y_bcast()));
|
||||
|
||||
// Allocate temporary buffer for broadcasted length tensor
|
||||
Tensor len_buffer;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(DataTypeToEnum<T>::v(),
|
||||
output_shape,
|
||||
&len_buffer));
|
||||
typename TTypes<T,1>::Tensor len_bcast = len_buffer.shaped<T,1>(
|
||||
bcast.result_shape());
|
||||
len_bcast = len_shaped.broadcast(
|
||||
BCast::ToIndexArray<1>(bcast.y_bcast()));
|
||||
|
||||
// Iterate through broadcasted tensors and perform substr
|
||||
for (int i = 0; i < output_shape.dim_size(0); ++i) {
|
||||
string in = input_bcast(i);
|
||||
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
|
||||
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
|
||||
OP_REQUIRES(context, FastBoundsCheck(pos, input_bcast(i).size()),
|
||||
errors::InvalidArgument("pos ", pos, " out of range for string",
|
||||
"b'", in, "' at index ", i));
|
||||
output(i) = in.substr(pos, len);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
// Reshape tensors according to BCast results
|
||||
auto input = input_tensor.shaped<string,2>(bcast.x_reshape());
|
||||
auto output = output_tensor->shaped<string,2>(bcast.result_shape());
|
||||
auto pos_shaped = pos_tensor.shaped<T,2>(bcast.y_reshape());
|
||||
auto len_shaped = len_tensor.shaped<T,2>(bcast.y_reshape());
|
||||
|
||||
// Allocate temporary buffer for broadcasted input tensor
|
||||
Tensor input_buffer;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(DT_STRING,
|
||||
output_shape,
|
||||
&input_buffer));
|
||||
typename TTypes<string,2>::Tensor input_bcast =
|
||||
input_buffer.shaped<string,2>(bcast.result_shape());
|
||||
input_bcast = input.broadcast(
|
||||
BCast::ToIndexArray<2>(bcast.x_bcast()));
|
||||
|
||||
// Allocate temporary buffer for broadcasted position tensor
|
||||
Tensor pos_buffer;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(DataTypeToEnum<T>::v(),
|
||||
output_shape,
|
||||
&pos_buffer));
|
||||
typename TTypes<T,2>::Tensor pos_bcast = pos_buffer.shaped<T,2>(
|
||||
bcast.result_shape());
|
||||
pos_bcast = pos_shaped.broadcast(
|
||||
BCast::ToIndexArray<2>(bcast.y_bcast()));
|
||||
|
||||
// Allocate temporary buffer for broadcasted length tensor
|
||||
Tensor len_buffer;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(DataTypeToEnum<T>::v(),
|
||||
output_shape,
|
||||
&len_buffer));
|
||||
typename TTypes<T,2>::Tensor len_bcast = len_buffer.shaped<T,2>(
|
||||
bcast.result_shape());
|
||||
len_bcast = len_shaped.broadcast(
|
||||
BCast::ToIndexArray<2>(bcast.y_bcast()));
|
||||
|
||||
// Iterate through broadcasted tensors and perform substr
|
||||
for (int i = 0; i < output_shape.dim_size(0); ++i) {
|
||||
for (int j = 0; j < output_shape.dim_size(1); ++j) {
|
||||
string in = input_bcast(i, j);
|
||||
const T pos = tensorflow::internal::SubtleMustCopy(
|
||||
pos_bcast(i, j));
|
||||
const T len = tensorflow::internal::SubtleMustCopy(
|
||||
len_bcast(i, j));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(pos, in.size()),
|
||||
errors::InvalidArgument("pos ", pos, " out of range for ",
|
||||
"string b'", in, "' at index ("
|
||||
, i, ", ", j, ")"));
|
||||
output(i, j) = in.substr(pos, len);
|
||||
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
context->SetStatus(errors::Unimplemented(
|
||||
"Substr broadcast not implemented for ", ndims, " dimensions"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_SUBSTR(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Substr") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T"), \
|
||||
SubstrOp<type>);
|
||||
REGISTER_SUBSTR(int32);
|
||||
REGISTER_SUBSTR(int64);
|
||||
} // namespace tensorflow
|
@ -196,6 +196,7 @@ void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host,
|
||||
// 0. Parse scheme
|
||||
// Make sure scheme matches [a-zA-Z][0-9a-zA-Z.]*
|
||||
// TODO(keveman): Allow "+" and "-" in the scheme.
|
||||
// Keep URI pattern in tensorboard/backend/server.py updated accordingly
|
||||
if (!strings::Scanner(remaining)
|
||||
.One(strings::Scanner::LETTER)
|
||||
.Many(strings::Scanner::LETTER_DIGIT_DOT)
|
||||
|
@ -281,4 +281,113 @@ input: Base64 strings to decode.
|
||||
output: Decoded strings.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("Substr")
|
||||
.Input("input: string")
|
||||
.Input("pos: T")
|
||||
.Input("len: T")
|
||||
.Output("output: string")
|
||||
.Attr("T: {int32, int64}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle pos_shape = c->input(1);
|
||||
ShapeHandle len_shape = c->input(2);
|
||||
ShapeHandle unused;
|
||||
// Check that pos/len have same rank
|
||||
TF_RETURN_IF_ERROR(c->WithRank(pos_shape, c->Rank(len_shape), &unused));
|
||||
// Check that dimensions are equal
|
||||
for (int32 i = 0; i < c->Rank(pos_shape); ++i) {
|
||||
DimensionHandle pos_dim = c->Dim(pos_shape, i);
|
||||
DimensionHandle len_dim = c->Dim(len_shape, i);
|
||||
if (c->Value(pos_dim) != c->Value(len_dim)) {
|
||||
return errors::InvalidArgument("pos and len shapes must match: ",
|
||||
c->DebugString(pos_shape), " vs. ",
|
||||
c->DebugString(len_shape));
|
||||
}
|
||||
}
|
||||
// c->input(0) is the ShapeHandle to input strings
|
||||
// BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1).
|
||||
return shape_inference::BroadcastBinaryOpShapeFn(c);
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Return substrings from `Tensor` of strings.
|
||||
|
||||
For each string in the input `Tensor`, creates a substring starting at index
|
||||
`pos` with a total length of `len`.
|
||||
|
||||
If `len` defines a substring that would extend beyond the length of the input
|
||||
string, then as many characters as possible are used.
|
||||
|
||||
If `pos` is negative or specifies a character index larger than any of the input
|
||||
strings, then an `InvalidArgumentError` is thrown.
|
||||
|
||||
`pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
|
||||
Op creation.
|
||||
|
||||
*NOTE*: `Substr` supports broadcasting up to two dimensions. More about
|
||||
broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
|
||||
---
|
||||
|
||||
Examples
|
||||
|
||||
Using scalar `pos` and `len`:
|
||||
|
||||
```
|
||||
input = [b'Hello', b'World']
|
||||
position = 1
|
||||
length = 3
|
||||
|
||||
output = [b'ell', b'orl']
|
||||
```
|
||||
|
||||
Using `pos` and `len` with same shape as `input`:
|
||||
|
||||
```
|
||||
input = [[b'ten', b'eleven', b'twelve'],
|
||||
[b'thirteen', b'fourteen', b'fifteen'],
|
||||
[b'sixteen', b'seventeen', b'eighteen']]
|
||||
position = [[1, 2, 3],
|
||||
[1, 2, 3],
|
||||
[1, 2, 3]]
|
||||
length = [[2, 3, 4],
|
||||
[4, 3, 2],
|
||||
[5, 5, 5]]
|
||||
|
||||
output = [[b'en', b'eve', b'lve'],
|
||||
[b'hirt', b'urt', b'te'],
|
||||
[b'ixtee', b'vente', b'hteen']]
|
||||
```
|
||||
|
||||
Broadcasting `pos` and `len` onto `input`:
|
||||
|
||||
```
|
||||
input = [[b'ten', b'eleven', b'twelve'],
|
||||
[b'thirteen', b'fourteen', b'fifteen'],
|
||||
[b'sixteen', b'seventeen', b'eighteen'],
|
||||
[b'nineteen', b'twenty', b'twentyone']]
|
||||
position = [1, 2, 3]
|
||||
length = [1, 2, 3]
|
||||
|
||||
output = [[b'e', b'ev', b'lve'],
|
||||
[b'h', b'ur', b'tee'],
|
||||
[b'i', b've', b'hte'],
|
||||
[b'i', b'en', b'nty']]
|
||||
```
|
||||
|
||||
Broadcasting `input` onto `pos` and `len`:
|
||||
|
||||
```
|
||||
input = b'thirteen'
|
||||
position = [1, 5, 7]
|
||||
length = [3, 2, 1]
|
||||
|
||||
output = [b'hir', b'ee', b'n']
|
||||
```
|
||||
|
||||
input: Tensor of strings
|
||||
pos: Scalar or tensor defining the position of first character in each substring
|
||||
len: Scalar or tensor defining the number of characters to include in each substring
|
||||
output: Tensor of substrings
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -10,6 +10,7 @@ exports_files(["LICENSE"])
|
||||
load("//tensorflow:tensorflow.bzl", "tf_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
|
||||
load("@local_config_cuda//cuda:platform.bzl", "cuda_library_path")
|
||||
load("@local_config_sycl//sycl:platform.bzl", "sycl_library_path")
|
||||
|
||||
cc_library(
|
||||
name = "gtest",
|
||||
@ -143,6 +144,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sycl",
|
||||
data = [
|
||||
"@local_config_sycl//sycl:{}".format(sycl_library_path("ComputeCpp")),
|
||||
],
|
||||
linkopts = select({
|
||||
"//conditions:default": [
|
||||
"-Wl,-rpath,../local_config_sycl/sycl/lib",
|
||||
],
|
||||
}),
|
||||
deps = [
|
||||
"@local_config_sycl//sycl:syclrt",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "android_srcs",
|
||||
srcs = glob(["*.h"]),
|
||||
|
@ -43,12 +43,18 @@ cc_library(
|
||||
# http://hadoop.apache.org/releases.html
|
||||
# 3. Extract the Hadoop distribution and run:
|
||||
# source libexec/hadoop-config.sh
|
||||
# 4. bazel test \
|
||||
# 4. Optionally set up HDFS cluster configurations (optionally Kerberos) within
|
||||
# $HADOOP_HDFS_HOME/etc/hadoop if you want to test against a real
|
||||
# distributed HDFS cluster
|
||||
# 5. bazel test \
|
||||
# --test_env=LD_LIBRARY_PATH=$JAVA_HOME/jre/lib/amd64/server \
|
||||
# --test_env=HADOOP_HDFS_HOME=$HADOOP_HDFS_HOME \
|
||||
# --test_env=CLASSPATH=$($HADOOP_HDFS_HOME/bin/hadoop classpath --glob) \
|
||||
# --test_strategy=local \
|
||||
# :hadoop_file_system_test
|
||||
# To test against the real distributed cluster, add the following option for
|
||||
# bazel test:
|
||||
# --test_env=HADOOP_TEST_TMPDIR=hdfs://cluster/test/tmp/dir
|
||||
tf_cc_test(
|
||||
name = "hadoop_file_system_test",
|
||||
size = "small",
|
||||
|
@ -56,6 +56,8 @@ class LibHDFS {
|
||||
std::function<hdfsFS(hdfsBuilder*)> hdfsBuilderConnect;
|
||||
std::function<hdfsBuilder*()> hdfsNewBuilder;
|
||||
std::function<void(hdfsBuilder*, const char*)> hdfsBuilderSetNameNode;
|
||||
std::function<void(hdfsBuilder*, const char *kerbTicketCachePath)>
|
||||
hdfsBuilderSetKerbTicketCachePath;
|
||||
std::function<int(hdfsFS, hdfsFile)> hdfsCloseFile;
|
||||
std::function<tSize(hdfsFS, hdfsFile, tOffset, void*, tSize)> hdfsPread;
|
||||
std::function<tSize(hdfsFS, hdfsFile, const void*, tSize)> hdfsWrite;
|
||||
@ -81,6 +83,7 @@ class LibHDFS {
|
||||
BIND_HDFS_FUNC(hdfsBuilderConnect);
|
||||
BIND_HDFS_FUNC(hdfsNewBuilder);
|
||||
BIND_HDFS_FUNC(hdfsBuilderSetNameNode);
|
||||
BIND_HDFS_FUNC(hdfsBuilderSetKerbTicketCachePath);
|
||||
BIND_HDFS_FUNC(hdfsCloseFile);
|
||||
BIND_HDFS_FUNC(hdfsPread);
|
||||
BIND_HDFS_FUNC(hdfsWrite);
|
||||
@ -135,6 +138,10 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
|
||||
} else {
|
||||
hdfs_->hdfsBuilderSetNameNode(builder, nn.c_str());
|
||||
}
|
||||
char* ticket_cache_path = getenv("KERB_TICKET_CACHE_PATH");
|
||||
if (ticket_cache_path != nullptr) {
|
||||
hdfs_->hdfsBuilderSetKerbTicketCachePath(builder, ticket_cache_path);
|
||||
}
|
||||
*fs = hdfs_->hdfsBuilderConnect(builder);
|
||||
if (*fs == nullptr) {
|
||||
return errors::NotFound(strerror(errno));
|
||||
@ -360,9 +367,15 @@ Status HadoopFileSystem::DeleteDir(const string& dir) {
|
||||
hdfsFileInfo* info =
|
||||
hdfs_->hdfsListDirectory(fs, TranslateName(dir).c_str(), &entries);
|
||||
if (info != nullptr) {
|
||||
return IOError(dir, errno);
|
||||
hdfs_->hdfsFreeFileInfo(info, entries);
|
||||
}
|
||||
// Due to HDFS bug HDFS-8407, we can't distinguish between an error and empty
|
||||
// folder. Especially in Kerberos-enabled setups, EAGAIN is quite common when
|
||||
// the call is actually successful. Check again by Stat.
|
||||
if (info == nullptr && errno != 0) {
|
||||
FileStatistics stat;
|
||||
TF_RETURN_IF_ERROR(Stat(dir, &stat));
|
||||
}
|
||||
hdfs_->hdfsFreeFileInfo(info, entries);
|
||||
|
||||
if (entries > 0) {
|
||||
return errors::FailedPrecondition("Cannot delete a non-empty directory.");
|
||||
|
@ -28,6 +28,15 @@ class HadoopFileSystemTest : public ::testing::Test {
|
||||
protected:
|
||||
HadoopFileSystemTest() {}
|
||||
|
||||
string TmpDir(const string& path) {
|
||||
char* test_dir = getenv("HADOOP_TEST_TMPDIR");
|
||||
if (test_dir != nullptr) {
|
||||
return io::JoinPath(string(test_dir), path);
|
||||
} else {
|
||||
return "file://" + io::JoinPath(testing::TmpDir(), path);
|
||||
}
|
||||
}
|
||||
|
||||
Status WriteString(const string& fname, const string& content) {
|
||||
std::unique_ptr<WritableFile> writer;
|
||||
TF_RETURN_IF_ERROR(hdfs.NewWritableFile(fname, &writer));
|
||||
@ -58,8 +67,7 @@ class HadoopFileSystemTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
TEST_F(HadoopFileSystemTest, RandomAccessFile) {
|
||||
const string fname =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "RandomAccessFile");
|
||||
const string fname = TmpDir("RandomAccessFile");
|
||||
const string content = "abcdefghijklmn";
|
||||
TF_ASSERT_OK(WriteString(fname, content));
|
||||
|
||||
@ -83,8 +91,7 @@ TEST_F(HadoopFileSystemTest, RandomAccessFile) {
|
||||
|
||||
TEST_F(HadoopFileSystemTest, WritableFile) {
|
||||
std::unique_ptr<WritableFile> writer;
|
||||
const string fname =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "WritableFile");
|
||||
const string fname = TmpDir("WritableFile");
|
||||
TF_EXPECT_OK(hdfs.NewWritableFile(fname, &writer));
|
||||
TF_EXPECT_OK(writer->Append("content1,"));
|
||||
TF_EXPECT_OK(writer->Append("content2"));
|
||||
@ -98,16 +105,14 @@ TEST_F(HadoopFileSystemTest, WritableFile) {
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, FileExists) {
|
||||
const string fname =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "FileExists");
|
||||
const string fname = TmpDir("FileExists");
|
||||
EXPECT_EQ(error::Code::NOT_FOUND, hdfs.FileExists(fname).code());
|
||||
TF_ASSERT_OK(WriteString(fname, "test"));
|
||||
TF_EXPECT_OK(hdfs.FileExists(fname));
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, GetChildren) {
|
||||
const string base =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "GetChildren");
|
||||
const string base = TmpDir("GetChildren");
|
||||
TF_EXPECT_OK(hdfs.CreateDir(base));
|
||||
|
||||
const string file = io::JoinPath(base, "testfile.csv");
|
||||
@ -122,16 +127,14 @@ TEST_F(HadoopFileSystemTest, GetChildren) {
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, DeleteFile) {
|
||||
const string fname =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "DeleteFile");
|
||||
const string fname = TmpDir("DeleteFile");
|
||||
EXPECT_FALSE(hdfs.DeleteFile(fname).ok());
|
||||
TF_ASSERT_OK(WriteString(fname, "test"));
|
||||
TF_EXPECT_OK(hdfs.DeleteFile(fname));
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, GetFileSize) {
|
||||
const string fname =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "GetFileSize");
|
||||
const string fname = TmpDir("GetFileSize");
|
||||
TF_ASSERT_OK(WriteString(fname, "test"));
|
||||
uint64 file_size = 0;
|
||||
TF_EXPECT_OK(hdfs.GetFileSize(fname, &file_size));
|
||||
@ -139,8 +142,7 @@ TEST_F(HadoopFileSystemTest, GetFileSize) {
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, CreateDirStat) {
|
||||
const string dir =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "CreateDirStat");
|
||||
const string dir = TmpDir("CreateDirStat");
|
||||
TF_EXPECT_OK(hdfs.CreateDir(dir));
|
||||
FileStatistics stat;
|
||||
TF_EXPECT_OK(hdfs.Stat(dir, &stat));
|
||||
@ -148,7 +150,7 @@ TEST_F(HadoopFileSystemTest, CreateDirStat) {
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, DeleteDir) {
|
||||
const string dir = "file://" + io::JoinPath(testing::TmpDir(), "DeleteDir");
|
||||
const string dir = TmpDir("DeleteDir");
|
||||
EXPECT_FALSE(hdfs.DeleteDir(dir).ok());
|
||||
TF_EXPECT_OK(hdfs.CreateDir(dir));
|
||||
TF_EXPECT_OK(hdfs.DeleteDir(dir));
|
||||
@ -157,10 +159,8 @@ TEST_F(HadoopFileSystemTest, DeleteDir) {
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, RenameFile) {
|
||||
const string fname1 =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "RenameFile1");
|
||||
const string fname2 =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "RenameFile2");
|
||||
const string fname1 = TmpDir("RenameFile1");
|
||||
const string fname2 = TmpDir("RenameFile2");
|
||||
TF_ASSERT_OK(WriteString(fname1, "test"));
|
||||
TF_EXPECT_OK(hdfs.RenameFile(fname1, fname2));
|
||||
string content;
|
||||
@ -169,10 +169,8 @@ TEST_F(HadoopFileSystemTest, RenameFile) {
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, RenameFile_Overwrite) {
|
||||
const string fname1 =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "RenameFile1");
|
||||
const string fname2 =
|
||||
"file://" + io::JoinPath(testing::TmpDir(), "RenameFile2");
|
||||
const string fname1 = TmpDir("RenameFile1");
|
||||
const string fname2 = TmpDir("RenameFile2");
|
||||
|
||||
TF_ASSERT_OK(WriteString(fname2, "test"));
|
||||
TF_EXPECT_OK(hdfs.FileExists(fname2));
|
||||
@ -185,7 +183,7 @@ TEST_F(HadoopFileSystemTest, RenameFile_Overwrite) {
|
||||
}
|
||||
|
||||
TEST_F(HadoopFileSystemTest, StatFile) {
|
||||
const string fname = "file://" + io::JoinPath(testing::TmpDir(), "StatFile");
|
||||
const string fname = TmpDir("StatFile");
|
||||
TF_ASSERT_OK(WriteString(fname, "test"));
|
||||
FileStatistics stat;
|
||||
TF_EXPECT_OK(hdfs.Stat(fname, &stat));
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#ifdef _MSC_VER
|
||||
// the following avx intrinsics are not defined on windows
|
||||
// in immintrin.h so we define them here.
|
||||
//
|
||||
//
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#define _mm_load_pd1 _mm_load1_pd
|
||||
|
@ -142,6 +142,7 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
|
||||
progress = true;
|
||||
}
|
||||
|
||||
// Handle legacy naming convention for cpu and gpu.
|
||||
if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/CPU:")) {
|
||||
p->has_type = true;
|
||||
|
@ -180,6 +180,10 @@ def train_and_eval():
|
||||
skiprows=1,
|
||||
engine="python")
|
||||
|
||||
# remove NaN elements
|
||||
df_train = df_train.dropna(how='any', axis=0)
|
||||
df_test = df_test.dropna(how='any', axis=0)
|
||||
|
||||
df_train[LABEL_COLUMN] = (
|
||||
df_train["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)
|
||||
df_test[LABEL_COLUMN] = (
|
||||
|
@ -3,7 +3,7 @@
|
||||
Generates values in an interval.
|
||||
|
||||
A sequence of `num` evenly-spaced values are generated beginning at `start`.
|
||||
If `num > 1`, the values in the sequence increase by `stop - start / num - 1`,
|
||||
If `num > 1`, the values in the sequence increase by `(stop - start) / (num - 1)`,
|
||||
so that the last one is exactly `stop`.
|
||||
|
||||
For example:
|
||||
|
@ -11,8 +11,8 @@ the full softmax loss.
|
||||
At inference time, you can compute full softmax probabilities with the
|
||||
expression `tf.nn.softmax(tf.matmul(inputs, tf.transpose(weights)) + biases)`.
|
||||
|
||||
See our [Candidate Sampling Algorithms Reference]
|
||||
(../../extras/candidate_sampling.pdf)
|
||||
See our
|
||||
[Candidate Sampling Algorithms Reference](../../extras/candidate_sampling.pdf)
|
||||
|
||||
Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
|
||||
([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
|
||||
|
@ -17,7 +17,7 @@ for k in 0..in_channels-1
|
||||
filter[di, dj, k, q]
|
||||
|
||||
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
|
||||
horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
|
||||
horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
@ -42,8 +42,7 @@ with an otherwise unused class.
|
||||
where a sampled class equals one of the target classes. If set to
|
||||
`True`, this is a "Sampled Logistic" loss instead of NCE, and we are
|
||||
learning to generate log-odds instead of log probabilities. See
|
||||
our [Candidate Sampling Algorithms Reference]
|
||||
(../../extras/candidate_sampling.pdf).
|
||||
our [Candidate Sampling Algorithms Reference](../../extras/candidate_sampling.pdf).
|
||||
Default is False.
|
||||
* <b>`partition_strategy`</b>: A string specifying the partitioning strategy, relevant
|
||||
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
|
||||
|
@ -11,8 +11,8 @@ each component is divided by the weighted, squared sum of inputs within
|
||||
sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
|
||||
output = input / (bias + alpha * sqr_sum) ** beta
|
||||
|
||||
For details, see [Krizhevsky et al., ImageNet classification with deep
|
||||
convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks).
|
||||
For details, see
|
||||
[Krizhevsky et al., ImageNet classification with deep convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks).
|
||||
|
||||
##### Args:
|
||||
|
||||
|
@ -22,7 +22,7 @@ In detail, with the default NHWC format,
|
||||
filter[di, dj, q, k]
|
||||
|
||||
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
|
||||
horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
|
||||
horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
@ -207,7 +207,7 @@ Here are some of the typical usage models:
|
||||
sess.run(logits)
|
||||
# Creates a saver.
|
||||
saver0 = tf.train.Saver()
|
||||
saver0.save(sess, saver0_ckpt)
|
||||
saver0.save(sess, 'my-save-dir/my-model-10000')
|
||||
# Generates MetaGraphDef.
|
||||
saver0.export_meta_graph('my-save-dir/my-model-10000.meta')
|
||||
```
|
||||
|
@ -39,6 +39,7 @@ The TensorFlow community has created many great projects around TensorFlow, incl
|
||||
* [Caffe to TensorFlow model converter](https://github.com/ethereon/caffe-tensorflow)
|
||||
* [Bitfusion's GPU-enabled AWS EC2 TensorFlow AMI](https://github.com/bitfusionio/amis/tree/master/awsmrkt-bfboost-ubuntu14-cuda75-tensorflow) ([Launch AMI](https://aws.amazon.com/marketplace/pp/B01EYKBEQ0))
|
||||
* [Rust language bindings](https://github.com/google/tensorflow-rust)
|
||||
* [Operator Vectorization Library](https://github.com/opveclib/opveclib)
|
||||
|
||||
### Development
|
||||
|
||||
|
@ -246,7 +246,7 @@ Filling queue with 20000 CIFAR images before starting to train. This will take a
|
||||
...
|
||||
```
|
||||
|
||||
The script reports the total loss every 10 steps as well the speed at which
|
||||
The script reports the total loss every 10 steps as well as the speed at which
|
||||
the last batch of data was processed. A few comments:
|
||||
|
||||
* The first batch of data can be inordinately slow (e.g. several minutes) as the
|
||||
|
@ -79,6 +79,8 @@ from tensorflow.python.client.client_lib import *
|
||||
# Ops
|
||||
from tensorflow.python.ops.standard_ops import *
|
||||
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
# Bring in subpackages.
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import resources
|
||||
|
@ -357,6 +357,13 @@ tf_py_test(
|
||||
additional_deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "substr_op_test",
|
||||
size = "small",
|
||||
srcs = ["substr_op_test.py"],
|
||||
additional_deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "summary_ops_test",
|
||||
size = "small",
|
||||
|
235
tensorflow/python/kernel_tests/substr_op_test.py
Normal file
235
tensorflow/python/kernel_tests/substr_op_test.py
Normal file
@ -0,0 +1,235 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for Substr op from string_ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class SubstrOpTest(tf.test.TestCase):
|
||||
|
||||
def _testScalarString(self, dtype):
|
||||
test_string = b"Hello"
|
||||
position = np.array(1, dtype)
|
||||
length = np.array(3, dtype)
|
||||
expected_value = b"ell"
|
||||
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
substr = substr_op.eval()
|
||||
self.assertAllEqual(substr, expected_value)
|
||||
|
||||
def _testVectorStrings(self, dtype):
|
||||
test_string = [b"Hello", b"World"]
|
||||
position = np.array(1, dtype)
|
||||
length = np.array(3, dtype)
|
||||
expected_value = [b"ell", b"orl"]
|
||||
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
substr = substr_op.eval()
|
||||
self.assertAllEqual(substr, expected_value)
|
||||
|
||||
def _testMatrixStrings(self, dtype):
|
||||
test_string = [[b"ten", b"eleven", b"twelve"],
|
||||
[b"thirteen", b"fourteen", b"fifteen"],
|
||||
[b"sixteen", b"seventeen", b"eighteen"]]
|
||||
position = np.array(1, dtype)
|
||||
length = np.array(4, dtype)
|
||||
expected_value = [[b"en", b"leve", b"welv"],
|
||||
[b"hirt", b"ourt", b"ifte"],
|
||||
[b"ixte", b"even", b"ight"]]
|
||||
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
substr = substr_op.eval()
|
||||
self.assertAllEqual(substr, expected_value)
|
||||
|
||||
def _testElementWisePosLen(self, dtype):
|
||||
test_string = [[b"ten", b"eleven", b"twelve"],
|
||||
[b"thirteen", b"fourteen", b"fifteen"],
|
||||
[b"sixteen", b"seventeen", b"eighteen"]]
|
||||
position = np.array([[1, 2, 3],
|
||||
[1, 2, 3],
|
||||
[1, 2, 3]], dtype)
|
||||
length = np.array([[2, 3, 4],
|
||||
[4, 3, 2],
|
||||
[5, 5, 5]], dtype)
|
||||
expected_value = [[b"en", b"eve", b"lve"],
|
||||
[b"hirt", b"urt", b"te"],
|
||||
[b"ixtee", b"vente", b"hteen"]]
|
||||
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
substr = substr_op.eval()
|
||||
self.assertAllEqual(substr, expected_value)
|
||||
|
||||
def _testBroadcast(self, dtype):
|
||||
# Broadcast pos/len onto input string
|
||||
test_string = [[b"ten", b"eleven", b"twelve"],
|
||||
[b"thirteen", b"fourteen", b"fifteen"],
|
||||
[b"sixteen", b"seventeen", b"eighteen"],
|
||||
[b"nineteen", b"twenty", b"twentyone"]]
|
||||
position = np.array([1, 2, 3], dtype)
|
||||
length = np.array([1, 2, 3], dtype)
|
||||
expected_value = [[b"e", b"ev", b"lve"],
|
||||
[b"h", b"ur", b"tee"],
|
||||
[b"i", b"ve", b"hte"],
|
||||
[b"i", b"en", b"nty"]]
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
substr = substr_op.eval()
|
||||
self.assertAllEqual(substr, expected_value)
|
||||
|
||||
# Broadcast input string onto pos/len
|
||||
test_string = [b"thirteen", b"fourteen", b"fifteen"]
|
||||
position = np.array([[1, 2, 3],
|
||||
[3, 2, 1],
|
||||
[5, 5, 5]], dtype)
|
||||
length = np.array([[3, 2, 1],
|
||||
[1, 2, 3],
|
||||
[2, 2, 2]], dtype)
|
||||
expected_value = [[b"hir", b"ur", b"t"],
|
||||
[b"r", b"ur", b"ift"],
|
||||
[b"ee", b"ee", b"en"]]
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
substr = substr_op.eval()
|
||||
self.assertAllEqual(substr, expected_value)
|
||||
|
||||
# Test 1D broadcast
|
||||
test_string = b"thirteen"
|
||||
position = np.array([1, 5, 7], dtype)
|
||||
length = np.array([3, 2, 1], dtype)
|
||||
expected_value = [b"hir", b"ee", b"n"]
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
substr = substr_op.eval()
|
||||
self.assertAllEqual(substr, expected_value)
|
||||
|
||||
def _testBadBroadcast(self, dtype):
|
||||
test_string = [[b"ten", b"eleven", b"twelve"],
|
||||
[b"thirteen", b"fourteen", b"fifteen"],
|
||||
[b"sixteen", b"seventeen", b"eighteen"]]
|
||||
position = np.array([1, 2, 3, 4], dtype)
|
||||
length = np.array([1, 2, 3, 4], dtype)
|
||||
expected_value = [[b"e", b"ev", b"lve"],
|
||||
[b"h", b"ur", b"tee"],
|
||||
[b"i", b"ve", b"hte"]]
|
||||
with self.assertRaises(ValueError):
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
|
||||
def _testOutOfRangeError(self, dtype):
|
||||
# Scalar/Scalar
|
||||
test_string = b"Hello"
|
||||
position = np.array(7, dtype)
|
||||
length = np.array(3, dtype)
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
substr = substr_op.eval()
|
||||
|
||||
# Vector/Scalar
|
||||
test_string = [b"good", b"good", b"bad", b"good"]
|
||||
position = np.array(3, dtype)
|
||||
length = np.array(1, dtype)
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
substr = substr_op.eval()
|
||||
|
||||
# Negative pos
|
||||
test_string = b"Hello"
|
||||
position = np.array(-1, dtype)
|
||||
length = np.array(3, dtype)
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
substr = substr_op.eval()
|
||||
|
||||
# Matrix/Matrix
|
||||
test_string = [[b"good", b"good", b"good"],
|
||||
[b"good", b"good", b"bad"],
|
||||
[b"good", b"good", b"good"]]
|
||||
position = np.array([[1, 2, 3],
|
||||
[1, 2, 3],
|
||||
[1, 2, 3]], dtype)
|
||||
length = np.array([[3, 2, 1],
|
||||
[1, 2, 3],
|
||||
[2, 2, 2]], dtype)
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
substr = substr_op.eval()
|
||||
|
||||
# Broadcast
|
||||
test_string = [[b"good", b"good", b"good"],
|
||||
[b"good", b"good", b"bad"]]
|
||||
position = np.array([1, 2, 3], dtype)
|
||||
length = np.array([1, 2, 3], dtype)
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
with self.test_session():
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
substr = substr_op.eval()
|
||||
|
||||
def _testMismatchPosLenShapes(self, dtype):
|
||||
test_string = [[b"ten", b"eleven", b"twelve"],
|
||||
[b"thirteen", b"fourteen", b"fifteen"],
|
||||
[b"sixteen", b"seventeen", b"eighteen"]]
|
||||
position = np.array([[1, 2, 3]], dtype)
|
||||
length = np.array([2, 3, 4], dtype)
|
||||
# Should fail: position/length have different rank
|
||||
with self.assertRaises(ValueError):
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
|
||||
position = np.array([[1, 2, 3],
|
||||
[1, 2, 3],
|
||||
[1, 2, 3]], dtype)
|
||||
length = np.array([[2, 3, 4]], dtype)
|
||||
# Should fail: position/length have different dimensionality
|
||||
with self.assertRaises(ValueError):
|
||||
substr_op = tf.substr(test_string, position, length)
|
||||
|
||||
def _testAll(self, dtype):
|
||||
self._testScalarString(dtype)
|
||||
self._testVectorStrings(dtype)
|
||||
self._testMatrixStrings(dtype)
|
||||
self._testElementWisePosLen(dtype)
|
||||
self._testBroadcast(dtype)
|
||||
self._testBadBroadcast(dtype)
|
||||
self._testOutOfRangeError(dtype)
|
||||
self._testMismatchPosLenShapes(dtype)
|
||||
|
||||
def testInt32(self):
|
||||
self._testAll(np.int32)
|
||||
|
||||
def testInt64(self):
|
||||
self._testAll(np.int64)
|
||||
|
||||
def testWrongDtype(self):
|
||||
with self.test_session():
|
||||
with self.assertRaises(TypeError):
|
||||
tf.substr(b"test", 3.0, 1)
|
||||
with self.assertRaises(TypeError):
|
||||
tf.substr(b"test", 3, 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -1407,6 +1407,8 @@ def reduce_logsumexp(input_tensor, reduction_indices=None, keep_dims=False,
|
||||
reduction_indices,
|
||||
keep_dims=True)) + my_max
|
||||
if not keep_dims:
|
||||
if isinstance(reduction_indices, int):
|
||||
reduction_indices = [reduction_indices]
|
||||
result = array_ops.squeeze(result, reduction_indices)
|
||||
return result
|
||||
|
||||
|
@ -67,6 +67,16 @@ class LogSumExpTest(test_util.TensorFlowTestCase):
|
||||
self.assertShapeEqual(y_np, y_tf)
|
||||
y_tf_np = y_tf.eval()
|
||||
self.assertAllClose(y_tf_np, y_np)
|
||||
|
||||
def testReductionIndices2(self):
|
||||
for dtype in [np.float16, np.float32, np.double]:
|
||||
x_np = np.random.rand(5, 5).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
y_tf = math_ops.reduce_logsumexp(x_np, reduction_indices=0)
|
||||
y_np = log(np.sum(exp(x_np), axis=0))
|
||||
self.assertShapeEqual(y_np, y_tf)
|
||||
y_tf_np = y_tf.eval()
|
||||
self.assertAllClose(y_tf_np, y_np)
|
||||
|
||||
def testKeepDims(self):
|
||||
for dtype in [np.float16, np.float32, np.double]:
|
||||
|
@ -454,7 +454,7 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
|
||||
relu_logits = math_ops.select(cond, logits, zeros)
|
||||
neg_abs_logits = math_ops.select(cond, -logits, logits)
|
||||
return math_ops.add(relu_logits - logits * targets,
|
||||
math_ops.log(1 + math_ops.exp(neg_abs_logits)),
|
||||
math_ops.log1p(math_ops.exp(neg_abs_logits)),
|
||||
name=name)
|
||||
|
||||
|
||||
@ -522,7 +522,7 @@ def weighted_cross_entropy_with_logits(logits, targets, pos_weight, name=None):
|
||||
log_weight = 1 + (pos_weight - 1) * targets
|
||||
return math_ops.add(
|
||||
(1 - targets) * logits,
|
||||
log_weight * (math_ops.log(1 + math_ops.exp(-math_ops.abs(logits))) +
|
||||
log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
|
||||
nn_ops.relu(-logits)),
|
||||
name=name)
|
||||
|
||||
|
@ -33,6 +33,7 @@ string tensor.
|
||||
## Splitting
|
||||
|
||||
@@string_split
|
||||
@@substr
|
||||
|
||||
## Conversion
|
||||
|
||||
@ -138,3 +139,4 @@ def _ReduceJoinShape(op):
|
||||
|
||||
ops.RegisterShape("StringJoin")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("StringSplit")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("Substr")(common_shapes.call_cpp_shape_fn)
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_state_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.util.deprecation import deprecated
|
||||
@ -316,9 +317,14 @@ class Variable(object):
|
||||
if init_from_fn:
|
||||
expected_shape_list = full_shape_to_list(expected_shape)
|
||||
set_shape = validate_shape and expected_shape.is_fully_defined()
|
||||
self._variable = state_ops.variable_op(
|
||||
expected_shape_list, dtype.base_dtype, set_shape=set_shape,
|
||||
name=name)
|
||||
self._variable = gen_state_ops._variable(
|
||||
shape=expected_shape_list,
|
||||
dtype=dtype.base_dtype,
|
||||
name=name,
|
||||
container="",
|
||||
shared_name="")
|
||||
if set_shape:
|
||||
self._variable.set_shape(expected_shape_list)
|
||||
with ops.colocate_with(self._variable.op):
|
||||
with ops.name_scope("Initializer"):
|
||||
# Colocate the tensors created by the initial_value() function
|
||||
@ -336,12 +342,15 @@ class Variable(object):
|
||||
and self._initial_value.get_shape().is_fully_defined())
|
||||
# In this case, the variable op can't be created until after the
|
||||
# initial_value has been converted to a Tensor with a known type.
|
||||
self._variable = state_ops.variable_op(
|
||||
full_shape_to_list(self._initial_value.get_shape()),
|
||||
self._initial_value.dtype.base_dtype,
|
||||
set_shape=set_shape,
|
||||
name=name)
|
||||
|
||||
self._variable = gen_state_ops._variable(
|
||||
shape=full_shape_to_list(self._initial_value.get_shape()),
|
||||
dtype=self._initial_value.dtype.base_dtype,
|
||||
name=name,
|
||||
container="",
|
||||
shared_name="")
|
||||
if set_shape:
|
||||
self._variable.set_shape(
|
||||
full_shape_to_list(self._initial_value.get_shape()))
|
||||
# Manually overrides the variable's shape with the initial value's.
|
||||
if validate_shape:
|
||||
initial_value_shape = self._initial_value.get_shape()
|
||||
|
@ -205,6 +205,7 @@ from tensorflow.python.training.queue_runner import *
|
||||
# For the module level doc.
|
||||
from tensorflow.python.training import input as _input
|
||||
from tensorflow.python.training.input import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
from tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook
|
||||
from tensorflow.python.training.basic_session_run_hooks import StopAtStepHook
|
||||
@ -246,7 +247,7 @@ from tensorflow.python.training.training_util import assert_global_step
|
||||
from tensorflow.python.pywrap_tensorflow import do_quantize_training_on_graphdef
|
||||
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
|
||||
|
||||
|
||||
# pylint: disable=wildcard-import
|
||||
# Training data protos.
|
||||
from tensorflow.core.example.example_pb2 import *
|
||||
from tensorflow.core.example.feature_pb2 import *
|
||||
@ -254,7 +255,7 @@ from tensorflow.core.protobuf.saver_pb2 import *
|
||||
|
||||
# Utility op. Open Source. TODO(touts): move to nn?
|
||||
from tensorflow.python.training.learning_rate_decay import *
|
||||
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
# Distributed computing support.
|
||||
from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef
|
||||
@ -263,7 +264,6 @@ from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
from tensorflow.python.training.server_lib import Server
|
||||
|
||||
|
||||
# Symbols whitelisted for export without documentation.
|
||||
_allowed_symbols = [
|
||||
# TODO(cwhipkey): review these and move to contrib or expose through
|
||||
|
@ -25,6 +25,7 @@ import functools
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import re
|
||||
|
||||
import six
|
||||
from six.moves import BaseHTTPServer
|
||||
@ -67,21 +68,20 @@ def ParseEventFilesSpec(logdir):
|
||||
files = {}
|
||||
if logdir is None:
|
||||
return files
|
||||
# Keep this pattern consistent with ParseURI in core/lib/io/path.cc.
|
||||
uri_pattern = re.compile("[a-zA-Z][0-9a-zA-Z.]://.*")
|
||||
for specification in logdir.split(','):
|
||||
# If it's a gcs or hdfs path, don't split on colon
|
||||
if (io_wrapper.IsGCSPath(specification) or
|
||||
specification.startswith('hdfs://')):
|
||||
run_name = None
|
||||
path = specification
|
||||
# If the spec looks like /foo:bar/baz, then we assume it's a path with a
|
||||
# colon.
|
||||
elif ':' in specification and specification[0] != '/':
|
||||
# Check for a group in the spec. A spec starting with xyz:// is regarded as
|
||||
# URI path spec instead of group spec. If the spec looks like /foo:bar/baz,
|
||||
# then we assume it's a path with a colon.
|
||||
if uri_pattern.match(specification) is None and \
|
||||
':' in specification and specification[0] != '/':
|
||||
# We split at most once so run_name:/path:with/a/colon will work.
|
||||
run_name, _, path = specification.partition(':')
|
||||
else:
|
||||
run_name = None
|
||||
path = specification
|
||||
if not (io_wrapper.IsGCSPath(path) or path.startswith('hdfs://')):
|
||||
if uri_pattern.match(path) is None:
|
||||
path = os.path.realpath(path)
|
||||
files[path] = run_name
|
||||
return files
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user