cleans up warning/errors tensorflow/stream_executor (#2555)

Fabrizio Milo 2016-06-03 10:29:58 -07:00 committed by Vijay Vasudevan
parent d42facc3cc
commit 4c789e39be
14 changed files with 39 additions and 26 deletions

View File

@@ -7,7 +7,7 @@ The contrib directory contains project directories, each of which has designated
owners. It is meant to contain features and contributions that eventually should
get merged into core TensorFlow, but whose interfaces may still change, or which
require some testing to see whether they can find broader acceptance. We are
-trying to keep dupliction within contrib to a minimum, so you may be asked to
+trying to keep duplication within contrib to a minimum, so you may be asked to
refactor code in contrib to use some feature inside core or in another project
in contrib rather than reimplementing the feature.

View File

@@ -16,6 +16,8 @@ limitations under the License.
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS
+#include <array>
#include "tensorflow/core/kernels/cudnn_pooling_gpu.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_3d.h"

View File

@@ -126,7 +126,8 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
state = initial_state
else:
if not dtype:
-raise ValueError("If no initial_state is provided, dtype must be.")
+raise ValueError("If no initial_state is provided, "
+"dtype must be specified")
state = cell.zero_state(batch_size, dtype)
if sequence_length is not None: # Prepare variables

View File

@@ -27,8 +27,7 @@ CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec);
CUDAExecutor *ExtractCudaExecutor(StreamExecutor *stream_exec);
ScopedActivateExecutorContext::ScopedActivateExecutorContext(
-CUDAExecutor *cuda_exec)
-: cuda_exec_(cuda_exec),
+CUDAExecutor *cuda_exec):
driver_scoped_activate_context_(
new ScopedActivateContext{ExtractCudaContext(cuda_exec)}) { }

View File

@@ -51,8 +51,6 @@ class ScopedActivateExecutorContext {
~ScopedActivateExecutorContext();
private:
-// The CUDA executor implementation whose context is activated.
-CUDAExecutor* cuda_exec_;
// The cuda.h-using datatype that we wrap.
ScopedActivateContext* driver_scoped_activate_context_;
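The two cuda_activation hunks above delete a member that was only ever written, never read. As a rough, hypothetical illustration (none of these names are the real StreamExecutor types), this is the pattern Clang reports with -Wunused-private-field, which the removal presumably addresses:

class ScopedActivator {
 public:
  // The field is initialized here but never read anywhere else, so Clang
  // emits: warning: private field 'resource_' is not used
  explicit ScopedActivator(int *resource) : resource_(resource) {}

 private:
  int *resource_;
};

int main() {
  int dummy = 0;
  ScopedActivator activator(&dummy);
  (void)activator;  // the object is used, but its private field still is not
  return 0;
}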

View File

@@ -457,6 +457,7 @@ class ScopedFilterDescriptor {
<< ToString(status);
}
+#if CUDNN_VERSION >= 5000
// TODO(b/23032134): Even if the filter layout is not supported,
// cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because it
// does not take layout as an input. Maybe force cuDNN by giving wrong
@@ -471,6 +472,7 @@ class ScopedFilterDescriptor {
<< FilterLayoutString(filter_descriptor.layout());
break;
}
+#endif
std::vector<int> dims(2 + filter_descriptor.ndims());
dims[0] = filter_descriptor.output_feature_map_count();
@@ -666,7 +668,7 @@ class ScopedActivationDescriptor {
mode = CUDNN_ACTIVATION_TANH;
break;
default:
-LOG(ERROR) << "unrecognized activation mode: "
+LOG(FATAL) << "unrecognized activation mode: "
<< static_cast<int>(activation_mode);
}
@@ -1916,6 +1918,7 @@ bool CudnnSupport::DoNormalize(
Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
LOG(FATAL) << "not yet implemented"; // TODO(leary)
+return false;
}
bool CudnnSupport::DoDepthConcatenate(
@@ -1977,6 +1980,7 @@ bool CudnnSupport::DoElementwiseOperate(
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) {
LOG(FATAL) << "not yet implemented"; // TODO(leary)
+return false;
}
bool CudnnSupport::DoXYPad(Stream* stream,
@@ -1985,6 +1989,7 @@ bool CudnnSupport::DoXYPad(Stream* stream,
int64 left_pad, int64 right_pad, int64 top_pad,
int64 bottom_pad, DeviceMemory<float>* output_data) {
LOG(FATAL) << "not yet implemented"; // TODO(leary)
+return false;
}
bool CudnnSupport::DoXYSlice(Stream* stream,
@@ -1994,6 +1999,7 @@ bool CudnnSupport::DoXYSlice(Stream* stream,
int64 bottom_trim,
DeviceMemory<float>* output_data) {
LOG(FATAL) << "not yet implemented"; // TODO(leary)
+return false;
}
bool CudnnSupport::DoMemcpyD2HQuantized(
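Each return false; added after a LOG(FATAL) above follows a common cleanup pattern: the compiler cannot see that the fatal-logging call never returns, so a value-returning function that ends with it draws a missing-return warning. A minimal, hypothetical sketch (FatalLog is a stand-in, not the real macro):

#include <cstdio>
#include <cstdlib>

// Stand-in for LOG(FATAL): terminates the process, but nothing in its
// signature tells the compiler that it never returns.
void FatalLog(const char *msg) {
  std::fprintf(stderr, "%s\n", msg);
  std::exit(1);
}

bool DoNotYetImplemented() {
  FatalLog("not yet implemented");
  // Unreachable at runtime, but without it compilers typically warn:
  // "control reaches end of non-void function" (-Wreturn-type).
  return false;
}

int main() { return DoNotYetImplemented() ? 1 : 0; }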

View File

@@ -32,7 +32,7 @@ namespace cuda {
class CUDAExecutor;
-// Opaque and unique identifer for the cuDNN plugin.
+// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;
// cudnn-library based DNN support. For details on overridden interface

View File

@@ -235,6 +235,8 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
}
if (on_disk_spec != nullptr) {
+LOG(WARNING) << "loading CUDA kernel from disk is not supported";
+return false;
} else if (spec.has_cuda_ptx_in_memory()) {
kernelname = &spec.cuda_ptx_in_memory().kernelname();
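The branch added above reports the unsupported on-disk case and returns early instead of silently doing nothing. A hypothetical sketch of the same guard pattern (the names are illustrative, not StreamExecutor's API):

#include <cstdio>

bool LoadKernel(bool from_disk) {
  if (from_disk) {
    // Unsupported path: say so and bail out rather than fall through.
    std::fprintf(stderr, "warning: loading a kernel from disk is not supported\n");
    return false;
  }
  // ... handle the supported in-memory case ...
  return true;
}

int main() { return LoadKernel(/*from_disk=*/false) ? 0 : 1; }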

View File

@@ -49,6 +49,7 @@ string QuantizedActivationModeString(QuantizedActivationMode mode) {
LOG(FATAL) << "Unknown quantized_activation_mode "
<< static_cast<int32>(mode);
}
+return "unknown quantized_activation_mode";
}
string ActivationModeString(ActivationMode mode) {
@@ -66,6 +67,7 @@ string ActivationModeString(ActivationMode mode) {
default:
LOG(FATAL) << "Unknown activation_mode " << static_cast<int32>(mode);
}
+return "unknown activation_mode";
}
string ElementwiseOperationString(ElementwiseOperation op) {
@@ -77,6 +79,7 @@ string ElementwiseOperationString(ElementwiseOperation op) {
default:
LOG(FATAL) << "Unknown elementwise op " << static_cast<int32>(op);
}
+return "unknown element wise op";
}
string DataLayoutString(DataLayout layout) {
@@ -92,6 +95,7 @@ string DataLayoutString(DataLayout layout) {
default:
LOG(FATAL) << "Unknown data layout " << static_cast<int32>(layout);
}
+return "unknown data layout";
}
string FilterLayoutString(FilterLayout layout) {
@@ -105,6 +109,7 @@ string FilterLayoutString(FilterLayout layout) {
default:
LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(layout);
}
+return "unknown filter layout";
}
string ShortPoolingModeString(PoolingMode mode) {
@@ -116,6 +121,7 @@ string ShortPoolingModeString(PoolingMode mode) {
default:
LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(mode);
}
+return "unknown filter layout";
}
std::tuple<int, int, int> GetDimIndices(const DataLayout& layout,
@@ -166,7 +172,7 @@ std::vector<int64> ReorderDims(const std::vector<int64>& input,
reordered[b_idx_to] = input[b_idx_from];
reordered[d_idx_to] = input[d_idx_from];
-for (int i = 0; i < input.size() - 2;
+for (size_t i = 0; i < input.size() - 2;
i++, spatial_idx_from++, spatial_idx_to++) {
reordered[spatial_idx_to] = input[spatial_idx_from];
}
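The string literals returned after each LOG(FATAL) above serve the same missing-return purpose noted earlier, and the final hunk switches the loop index to size_t because input.size() is unsigned; comparing it against a signed int is what -Wsign-compare reports. A small standalone sketch of that loop fix:

#include <cstdio>
#include <vector>

int main() {
  std::vector<long long> input = {7, 8, 9, 10};
  // 'i' matches the unsigned type of input.size(); declaring it as a plain
  // int instead makes -Wsign-compare flag the comparison below.
  for (size_t i = 0; i < input.size() - 2; ++i) {
    std::printf("input[%zu] = %lld\n", i, input[i]);
  }
  return 0;
}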

View File

@@ -354,7 +354,7 @@ class FilterDescriptor {
// Arguments:
// - zero_padding_height: padding of the "y dimension" of the input data. Note
// that this is different from the height of the filter.
-// - zero_padding_width: analogouus to the height above, but in the "x
+// - zero_padding_width: analogous to the height above, but in the "x
// dimension".
// - vertical_filter_stride: the convolution slides a 2-dimensional window of
// filter-height-by-filter-width over the input layer -- the center of that
@@ -767,7 +767,7 @@ class DnnSupport {
// filter_descriptor: dimensions of the convolution filter.
// filter_data: coefficients for the convolution filter.
// output_descriptor: dimensions of the output gradients, which is the same
-// as the dimensions of the ouput.
+// as the dimensions of the output.
// backward_output_data: un-owned device memory region which contains the
// backprop of the output.
// convolution_descriptor: stride of the convolution filter.
@@ -813,7 +813,7 @@ class DnnSupport {
// input_data: un-owned device memory region which contains the
// convolution input.
// output_descriptor: dimensions of the output gradients, which is the same
-// as the dimensions of the ouput.
+// as the dimensions of the output.
// backward_output_data: un-owned device memory region which contains the
// backprop of the output.
// convolution_descriptor: stride of the convolution filter.

View File

@@ -63,10 +63,13 @@ class DeviceMemory;
class Timer;
namespace dnn {
-struct BatchDescriptor;
-struct FilterDescriptor;
-struct ConvolutionDescriptor;
-struct ProfileResult;
+class BatchDescriptor;
+class FilterDescriptor;
+class ConvolutionDescriptor;
+class BatchDescriptor;
+class FilterDescriptor;
+class ConvolutionDescriptor;
+class ProfileResult;
typedef int64 AlgorithmType;
} // namespace dnn
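Switching the forward declarations above from struct to class matches the tag used at each definition; mismatched tags are legal C++, but Clang (-Wmismatched-tags) and MSVC (C4099) warn about them. A minimal, hypothetical example:

// The forward declaration and the definition disagree on the class-key.
// The code is valid, but Clang's -Wmismatched-tags and MSVC's C4099 flag it.
struct Widget;

class Widget {
 public:
  int value = 0;
};

int main() {
  Widget w;
  return w.value;
}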
@@ -1257,7 +1260,7 @@ class Stream {
// back-end implementation will be appropriately seeded by default.
// At a minimum 16 bytes of data are required in the seed buffer.
//
-// To seed with good (non-reproducable) data:
+// To seed with good (non-reproducible) data:
// File* f = File::Open("/dev/random", "r");
// int64 bytes_read = f->Read(seed_data, bytes_to_read);
// < error checking >
@@ -1297,7 +1300,7 @@ class Stream {
uint64 size);
// Alternative interface for memcpying from device to host that takes an
-// array slice. Checks that the destination size can accomodate the host
+// array slice. Checks that the destination size can accommodate the host
// slice size.
template <typename T>
Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
@@ -1308,7 +1311,7 @@ class Stream {
}
// Alternative interface for memcpying from host to device that takes an
-// array slice. Checks that the destination size can accomodate the host
+// array slice. Checks that the destination size can accommodate the host
// slice size.
template <typename T>
Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
@@ -1339,7 +1342,7 @@ class Stream {
// Entrain onto the stream: a memset of a 32-bit pattern at a GPU location
// of
-// size bytes, where bytes must be evenly 32-bit sized (i.e. evently
+// size bytes, where bytes must be evenly 32-bit sized (i.e. evenly
// divisible
// by 4). The location must not be null.
Stream &ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,

View File

@@ -50,10 +50,6 @@ string StackTraceIfVLOG10() {
}
}
-// Maximum stack depth to report when generating backtrace on mem allocation
-// (for GPU memory leak checker)
-static const int kMaxStackDepth = 256;
// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {

View File

@@ -119,7 +119,7 @@ DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | sed -e 's/=/_/g' -e 's/,/-/g')
DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]')
# Print arguments.
-echo "WORKSAPCE: ${WORKSPACE}"
+echo "WORKSPACE: ${WORKSPACE}"
echo "CI_DOCKER_EXTRA_PARAMS: ${CI_DOCKER_EXTRA_PARAMS[@]}"
echo "COMMAND: ${COMMAND[@]}"
echo "CI_COMMAND_PREFIX: ${CI_COMMAND_PREFIX[@]}"

View File

@@ -157,7 +157,7 @@ cc_library(
# This rule checks if Cuda libraries in the source tree has been properly configured.
# The output list makes bazel runs this rule first if the Cuda files are missing.
# This gives us an opportunity to check and print a meaningful error message.
-# But we will need to create the output file list to make bazel happy in a successfull run.
+# But we will need to create the output file list to make bazel happy in a successful run.
genrule(
name = "cuda_check",
srcs = [