TensorFlow: Minor updates to docs, BUILD, GPU config / perf, etc.
Changes:
- Updates to op documentation and index by Josh
- More changes to BUILD files for Python 3 support by @girving
- Fix to Eigen to use DenseIndex everywhere by @jiayq
- Enable configuration for CUDA compute capability by @zheng-xq, including updates to docs
- Route aggregation method through optimizer by schuster
- Updates to install instructions for Bazel 0.1.1

Base CL: 107702099
Parent: f2102f4e2c
Commit: 4dffee7f62
Changed paths:
  configure
  six.BUILD
  tensorflow/core/common_runtime
  tensorflow/core/framework
  tensorflow/core/kernels
  tensorflow/g3doc
  tensorflow/models
  tensorflow/python
  tensorflow/tensorboard
  tensorflow/tensorflow.bzl
  tensorflow/tools
  third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor
  third_party/gpus/crosstool/clang/bin
@@ -76,6 +76,70 @@ CUDA_TOOLKIT_PATH="$CUDA_TOOLKIT_PATH"
 CUDNN_INSTALL_PATH="$CUDNN_INSTALL_PATH"
 EOF
 
+function UnofficialSetting() {
+  echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n"
+
+  # Configure the compute capabilities that TensorFlow builds for.
+  # Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
+  while true; do
+    fromuser=""
+    if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
+      cat << EOF
+Please specify a list of comma-separated Cuda compute capabilities you want to build with.
+You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
+Please note that each additional compute capability significantly increases your build time and binary size.
+EOF
+      read -p "[Default is: \"3.5,5.2\"]: " TF_CUDA_COMPUTE_CAPABILITIES
+      fromuser=1
+    fi
+    # Check whether all capabilities from the input is valid
+    COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES//,/ }
+    ALL_VALID=1
+    for CAPABILITY in $COMPUTE_CAPABILITIES; do
+      if [[ ! "$CAPABILITY" =~ [0-9]+.[0-9]+ ]]; then
+        echo "Invalid compute capability: " $CAPABILITY
+        ALL_VALID=0
+        break
+      fi
+    done
+    if [ "$ALL_VALID" == "0" ]; then
+      if [ -z "$fromuser" ]; then
+        exit 1
+      fi
+    else
+      break
+    fi
+    TF_CUDA_COMPUTE_CAPABILITIES=""
+  done
+
+  if [ ! -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
+    export WARNING="Unofficial setting. DO NOT"" SUBMIT!!!"
+    function CudaGenCodeOpts() {
+      OUTPUT=""
+      for CAPABILITY in $@; do
+        OUTPUT=${OUTPUT}" \"${CAPABILITY}\", "
+      done
+      echo $OUTPUT
+    }
+    export CUDA_GEN_CODES_OPTS=$(CudaGenCodeOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
+    perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\[).*?(\]),\n\1# $ENV{WARNING}\n\1\2$ENV{CUDA_GEN_CODES_OPTS}\3,s' third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
+    function CudaVersionOpts() {
+      OUTPUT=""
+      for CAPABILITY in $@; do
+        OUTPUT=$OUTPUT"CudaVersion(\"${CAPABILITY}\"), "
+      done
+      echo $OUTPUT
+    }
+    export CUDA_VERSION_OPTS=$(CudaVersionOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
+    perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\{).*?(\}),\n\1// $ENV{WARNING}\n\1\2$ENV{CUDA_VERSION_OPTS}\3,s' tensorflow/core/common_runtime/gpu/gpu_device.cc
+  fi
+}
+
+# Only run the unofficial settings when users explicitly choose to.
+if [ "$TF_UNOFFICIAL_SETTING" == "1" ]; then
+  UnofficialSetting
+fi
+
 # Invoke the cuda_config.sh and set up the TensorFlow's canonical view of the Cuda libraries
 (cd third_party/gpus/cuda; ./cuda_config.sh;) || exit -1
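For reference, a minimal sketch of how this new knob might be exercised; the variable names come straight from the hunk above, but the exact capability values are illustrative:

  # Hypothetical invocation: opt in to the unofficial settings and pre-seed
  # the compute capability list so configure does not prompt interactively.
  export TF_UNOFFICIAL_SETTING=1
  export TF_CUDA_COMPUTE_CAPABILITIES="3.5,5.2"   # comma-separated, validated by the loop above
  ./configure

If the list fails validation and was supplied non-interactively, configure exits with status 1 rather than re-prompting.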
@@ -9,4 +9,5 @@ py_library(
     name = "six",
     srcs = ["six.py"],
     visibility = ["//visibility:public"],
+    srcs_version = "PY2AND3",
 )
@@ -294,6 +294,31 @@ Status ExecutorImpl::InferAllocAttr(
     const DeviceNameUtils::ParsedName& local_dev_name,
     AllocatorAttributes* attr) {
   Status s;
+  // Note that it's possible for *n to be a Recv and *dst to be a Send,
+  // so these two cases are not mutually exclusive.
+  if (IsRecv(n)) {
+    string src_name;
+    s = GetNodeAttr(n->def(), "send_device", &src_name);
+    if (!s.ok()) return s;
+    DeviceNameUtils::ParsedName parsed_src_name;
+    if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
+      s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
+                           n->name());
+      return s;
+    }
+    if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
+      // Value is going to be the sink of an RPC.
+      attr->set_nic_compatible(true);
+      VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
+    } else if (local_dev_name.type == "CPU" && parsed_src_name.type == "GPU") {
+      // Value is going to be the sink of a local DMA from GPU to CPU.
+      attr->set_gpu_compatible(true);
+      VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
+    } else {
+      VLOG(2) << "default alloc case local type " << local_dev_name.type
+              << " remote type " << parsed_src_name.type;
+    }
+  }
   if (IsSend(dst)) {
     string dst_name;
     s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
@@ -8,6 +8,7 @@
 
 #include <stdlib.h>
 #include <string.h>
+#include <algorithm>
 
 //#include "base/commandlineflags.h"
 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
@@ -590,10 +591,50 @@ static int GetMinGPUMultiprocessorCount() {
   return kDefaultMinGPUMultiprocessorCount;
 }
 
+namespace {
+
+struct CudaVersion {
+  // Initialize from version_name in the form of "3.5"
+  explicit CudaVersion(const std::string& version_name) {
+    size_t dot_pos = version_name.find('.');
+    CHECK(dot_pos != string::npos);
+    string major_str = version_name.substr(0, dot_pos);
+    CHECK(strings::safe_strto32(major_str.c_str(), &major_part));
+    string minor_str = version_name.substr(dot_pos + 1);
+    CHECK(strings::safe_strto32(minor_str.c_str(), &minor_part));
+  }
+  CudaVersion() {}
+  bool operator<(const CudaVersion& other) const {
+    if (this->major_part != other.major_part) {
+      return this->major_part < other.major_part;
+    }
+    return this->minor_part < other.minor_part;
+  }
+  friend std::ostream& operator<<(std::ostream& os,
+                                  const CudaVersion& version) {
+    os << version.major_part << "." << version.minor_part;
+    return os;
+  }
+  int major_part = -1;
+  int minor_part = -1;
+};
+
+// "configure" uses the specific name to substitute the following string.
+// If you change it, make sure you modify "configure" as well.
+std::vector<CudaVersion> supported_cuda_compute_capabilities = {
+    CudaVersion("3.5"), CudaVersion("5.2")};
+
+}  // namespace
+
 void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
   auto gpu_manager = GPUMachineManager();
   int min_gpu_core_count = GetMinGPUMultiprocessorCount();
   if (gpu_manager) {
+    CHECK(!supported_cuda_compute_capabilities.empty());
+    CudaVersion min_supported_capability =
+        *std::min_element(supported_cuda_compute_capabilities.begin(),
+                          supported_cuda_compute_capabilities.end());
+
+    auto visible_device_count = gpu_manager->VisibleDeviceCount();
     for (int i = 0; i < gpu_manager->VisibleDeviceCount(); ++i) {
       auto exec_status = gpu_manager->ExecutorForDevice(i);
@@ -602,17 +643,19 @@ void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
       }
       gpu::StreamExecutor* se = exec_status.ValueOrDie();
       const gpu::DeviceDescription& desc = se->GetDeviceDescription();
-      int major, minor;
-      if (!desc.cuda_compute_capability(&major, &minor)) {
+      CudaVersion device_capability;
+      if (!desc.cuda_compute_capability(&device_capability.major_part,
+                                        &device_capability.minor_part)) {
         continue;
       }
-      // Only consider GPUs with compute capability >= 3.5 (Kepler or
-      // higher)
-      if (major < 3 || (major == 3 && minor < 5)) {
+      // Only GPUs with no less than the minimum supported compute capability is
+      // accepted.
+      if (device_capability < min_supported_capability) {
        LOG(INFO) << "Ignoring gpu device "
                  << "(" << GetShortDeviceDescription(i, desc) << ") "
-                 << "with Cuda compute capability " << major << "." << minor
-                 << ". The minimum required Cuda capability is 3.5.";
+                 << "with Cuda compute capability " << device_capability
+                 << ". The minimum required Cuda capability is "
+                 << min_supported_capability << ".";
         continue;
       }

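As a sanity check on the ordering used above, here is a small standalone sketch of the same major/minor comparison in plain C++; the TF CHECK and strings::safe_strto32 helpers are swapped for assert and std::stoi, so this is illustrative rather than the committed code:

  #include <cassert>
  #include <iostream>
  #include <string>

  // Minimal re-creation of the CudaVersion ordering from the hunk above.
  struct Version {
    explicit Version(const std::string& name) {
      size_t dot = name.find('.');
      assert(dot != std::string::npos);
      major_part = std::stoi(name.substr(0, dot));
      minor_part = std::stoi(name.substr(dot + 1));
    }
    bool operator<(const Version& o) const {
      if (major_part != o.major_part) return major_part < o.major_part;
      return minor_part < o.minor_part;
    }
    int major_part = -1;
    int minor_part = -1;
  };

  int main() {
    // "3.5" sorts below "5.2", so 3.5 is the minimum supported capability.
    assert(Version("3.5") < Version("5.2"));
    // A 3.0 device compares below the 3.5 minimum and would be ignored.
    assert(Version("3.0") < Version("3.5"));
    std::cout << "ordering ok\n";
  }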
@@ -188,9 +188,9 @@ class LocalRendezvousImpl : public Rendezvous {
     // message arrives.
     Item* item = new Item;
     item->waiter = done;
+    item->recv_alloc_attrs = recv_args.alloc_attrs;
     if (recv_args.device_context) {
       item->recv_dev_context = recv_args.device_context;
-      item->recv_alloc_attrs = recv_args.alloc_attrs;
       item->recv_dev_context->Ref();
     }
     CHECK(table_.insert({key, item}).second);
@@ -98,9 +98,10 @@ class TensorSlice {
   // We allow NDIMS to be greater than dims(), in which case we will pad the
   // higher dimensions with trivial dimensions.
   template <int NDIMS>
-  void FillIndicesAndSizes(const TensorShape& shape,
-                           Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
-                           Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const;
+  void FillIndicesAndSizes(
+      const TensorShape& shape,
+      Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
+      Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const;
 
   // Interaction with other TensorSlices.
 
@@ -162,8 +163,8 @@ class TensorSlice {
 
 template <int NDIMS>
 void TensorSlice::FillIndicesAndSizes(
-    const TensorShape& shape, Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
-    Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const {
+    const TensorShape& shape, Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
+    Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const {
   CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
                                  << "slices: shape = " << shape.DebugString()
                                  << ", slice = " << DebugString();
@@ -18,9 +18,9 @@ void ConcatGPU(const GPUDevice& d,
     const std::vector<
         std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
     typename TTypes<T, 2>::Matrix* output) {
-  Eigen::array<ptrdiff_t, 2> offset(0, 0);
+  Eigen::array<Eigen::DenseIndex, 2> offset(0, 0);
   for (int i = 0; i < inputs.size(); ++i) {
-    Eigen::array<ptrdiff_t, 2> size = inputs[i]->dimensions();
+    Eigen::array<Eigen::DenseIndex, 2> size = inputs[i]->dimensions();
     output->slice(offset, size).device(d) = *inputs[i];
     offset[1] += size[1];
   }
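The ptrdiff_t to Eigen::DenseIndex substitution repeats across the kernels below. A minimal standalone Eigen example of the slice-by-offset pattern that ConcatGPU uses (CPU-only and illustrative, not TF code):

  #include <unsupported/Eigen/CXX11/Tensor>
  #include <iostream>

  int main() {
    // Two 2x2 inputs concatenated along the column dimension, mirroring
    // the offset/size bookkeeping in ConcatGPU above.
    Eigen::Tensor<float, 2> a(2, 2), b(2, 2), out(2, 4);
    a.setConstant(1.0f);
    b.setConstant(2.0f);

    Eigen::array<Eigen::DenseIndex, 2> offset{0, 0};
    Eigen::array<Eigen::DenseIndex, 2> size{2, 2};
    out.slice(offset, size) = a;   // columns [0, 2)
    offset[1] += size[1];
    out.slice(offset, size) = b;   // columns [2, 4)

    std::cout << out << "\n";
  }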
@@ -17,74 +17,7 @@ namespace tensorflow {
 FIFOQueue::FIFOQueue(int capacity, const DataTypeVector& component_dtypes,
                      const std::vector<TensorShape>& component_shapes,
                      const string& name)
-    : QueueBase(component_dtypes, component_shapes, name),
-      capacity_(capacity),
-      closed_(false) {}
-
-Status FIFOQueue::Initialize() {
-  if (component_dtypes_.empty()) {
-    return errors::InvalidArgument("Empty component types for queue ", name_);
-  }
-  if (!component_shapes_.empty() &&
-      component_dtypes_.size() != component_shapes_.size()) {
-    return errors::InvalidArgument("Different number of component types (",
-                                   component_dtypes_.size(), ") vs. shapes (",
-                                   component_shapes_.size(), ").");
-  }
-
-  mutex_lock lock(mu_);
-  queues_.reserve(num_components());
-  for (int i = 0; i < num_components(); ++i) {
-    queues_.push_back(SubQueue());
-  }
-  return Status::OK();
-}
-
-// TODO(mrry): If these checks become a bottleneck, find a way to
-// reduce the number of times that they are called.
-Status FIFOQueue::ValidateTuple(const Tuple& tuple) {
-  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
-  if (specified_shapes()) {
-    for (size_t i = 0; i < tuple.size(); ++i) {
-      if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
-        return errors::InvalidArgument(
-            "Shape mismatch in tuple component ", i, ". Expected ",
-            component_shapes_[i].ShortDebugString(), ", got ",
-            tuple[i].shape().ShortDebugString());
-      }
-    }
-  }
-  return Status::OK();
-}
-
-// TODO(mrry): If these checks become a bottleneck, find a way to
-// reduce the number of times that they are called.
-Status FIFOQueue::ValidateManyTuple(const Tuple& tuple) {
-  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
-  const int64 batch_size = tuple[0].dim_size(0);
-  if (specified_shapes()) {
-    for (size_t i = 0; i < tuple.size(); ++i) {
-      // Expected shape is [batch_size] + component_shapes_[i]
-      const TensorShape expected_shape = ManyOutShape(i, batch_size);
-      if (!tuple[i].shape().IsSameSize(expected_shape)) {
-        return errors::InvalidArgument(
-            "Shape mismatch in tuple component ", i, ". Expected ",
-            expected_shape.ShortDebugString(), ", got ",
-            tuple[i].shape().ShortDebugString());
-      }
-    }
-  } else {
-    for (size_t i = 1; i < tuple.size(); ++i) {
-      if (tuple[i].dim_size(0) != batch_size) {
-        return errors::InvalidArgument(
-            "All input tensors must have the same size in the 0th ",
-            "dimension. Component ", i, " has ", tuple[i].dim_size(0),
-            ", and should have ", batch_size);
-      }
-    }
-  }
-  return Status::OK();
-}
+    : TypedQueue(capacity, component_dtypes, component_shapes, name) {}
 
 void FIFOQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
   DCHECK_GT(queues_[0].size(), 0);
@@ -95,113 +28,6 @@ void FIFOQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
   }
 }
 
-void FIFOQueue::Cancel(Action action, CancellationToken token) {
-  DoneCallback callback = nullptr;
-  {
-    mutex_lock lock(mu_);
-    std::deque<Attempt>* attempts =
-        action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
-    for (Attempt& attempt : *attempts) {
-      if (attempt.cancellation_token == token) {
-        attempt.is_cancelled = true;
-        if (action == kEnqueue) {
-          attempt.context->SetStatus(
-              errors::Cancelled("Enqueue operation was cancelled"));
-        } else {
-          attempt.context->SetStatus(
-              errors::Cancelled("Dequeue operation was cancelled"));
-        }
-        std::swap(callback, attempt.done_callback);
-        break;
-      }
-    }
-  }
-  if (callback) {
-    callback();
-    FlushUnlocked();
-  }
-}
-
-void FIFOQueue::CloseAndCancel() {
-  std::vector<DoneCallback> callbacks;
-  {
-    mutex_lock lock(mu_);
-    closed_ = true;
-    for (Attempt& attempt : enqueue_attempts_) {
-      attempt.is_cancelled = true;
-      attempt.context->SetStatus(
-          errors::Cancelled("Enqueue operation was cancelled"));
-      callbacks.emplace_back(std::move(attempt.done_callback));
-    }
-  }
-  for (const DoneCallback& callback : callbacks) {
-    callback();
-  }
-  FlushUnlocked();
-}
-
-bool FIFOQueue::TryAttemptLocked(Action action,
-                                 std::vector<CleanUp>* clean_up) {
-  std::deque<Attempt>* attempts =
-      action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
-  bool progress = false;
-  bool done = false;
-  while (!done && !attempts->empty()) {
-    if (attempts->front().is_cancelled) {
-      if (action == kEnqueue) {
-        LOG(INFO) << "Skipping cancelled enqueue attempt";
-      } else {
-        LOG(INFO) << "Skipping cancelled dequeue attempt";
-      }
-      attempts->pop_front();
-    } else {
-      Attempt* cur_attempt = &attempts->front();
-      switch (cur_attempt->run_callback(cur_attempt)) {
-        case kNoProgress:
-          done = true;
-          break;
-        case kProgress:
-          done = true;
-          progress = true;
-          break;
-        case kComplete:
-          progress = true;
-          clean_up->emplace_back(std::move(cur_attempt->done_callback),
-                                 cur_attempt->cancellation_token,
-                                 cur_attempt->context->cancellation_manager());
-          attempts->pop_front();
-          break;
-      }
-    }
-  }
-  return progress;
-}
-
-void FIFOQueue::FlushUnlocked() {
-  std::vector<CleanUp> clean_up;
-  Ref();
-  {
-    mutex_lock lock(mu_);
-    bool changed;
-    do {
-      changed = TryAttemptLocked(kEnqueue, &clean_up);
-      changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
-    } while (changed);
-  }
-  Unref();
-  for (const auto& to_clean : clean_up) {
-    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
-      // NOTE(mrry): We can safely ignore the return value of
-      // DeregisterCallback because the mutex mu_ ensures that the
-      // cleanup action only executes once.
-      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
-    }
-    to_clean.finished();
-  }
-}
-
 void FIFOQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                            DoneCallback callback) {
   CancellationManager* cm = ctx->cancellation_manager();
@@ -484,30 +310,6 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
   }
 }
 
-void FIFOQueue::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
-                      DoneCallback callback) {
-  if (cancel_pending_enqueues) {
-    CloseAndCancel();
-    callback();
-  } else {
-    {
-      mutex_lock lock(mu_);
-      enqueue_attempts_.emplace_back(
-          0, callback, ctx, CancellationManager::kInvalidToken,
-          [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-            if (closed_) {
-              attempt->context->SetStatus(errors::Aborted(
-                  "FIFOQueue '", name_, "' is already closed."));
-            } else {
-              closed_ = true;
-            }
-            return kComplete;
-          });
-    }
-    FlushUnlocked();
-  }
-}
-
 Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
   TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "FIFOQueue"));
   TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
@@ -6,24 +6,21 @@
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/kernels/typed_queue.h"
 #include "tensorflow/core/platform/port.h"
 #include "tensorflow/core/public/tensor.h"
 #include "tensorflow/core/public/tensor_shape.h"
 
 namespace tensorflow {
 
-class FIFOQueue : public QueueBase {
+class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
  public:
   FIFOQueue(int32 capacity, const DataTypeVector& component_dtypes,
             const std::vector<TensorShape>& component_shapes,
             const string& name);
-  Status Initialize();  // Must be called before any other method.
 
   // Implementations of QueueInterface methods --------------------------------
 
-  Status ValidateTuple(const Tuple& tuple) override;
-  Status ValidateManyTuple(const Tuple& tuple) override;
   void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                   DoneCallback callback) override;
   void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
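TypedQueue itself is not shown in this diff. Based purely on how it is used here — it absorbs capacity_, mu_, queues_, closed_, and the dtype/shape validation that FIFOQueue::Initialize used to do inline — a rough sketch of its likely shape, offered strictly as an assumption rather than the committed header:

  // Hypothetical sketch of the TypedQueue template this diff builds on;
  // the real header lives elsewhere in the tree and may differ in detail.
  template <typename SubQueue>
  class TypedQueue : public QueueBase {
   public:
    TypedQueue(int32 capacity, const DataTypeVector& component_dtypes,
               const std::vector<TensorShape>& component_shapes,
               const string& name)
        : QueueBase(capacity, component_dtypes, component_shapes, name) {}

    // Validates the component dtypes/shapes and sets up one SubQueue per
    // tuple component, as FIFOQueue::Initialize used to do inline.
    virtual Status Initialize() {
      if (component_dtypes_.empty()) {
        return errors::InvalidArgument("Empty component types for queue ", name_);
      }
      if (!component_shapes_.empty() &&
          component_dtypes_.size() != component_shapes_.size()) {
        return errors::InvalidArgument("Mismatched component types vs. shapes");
      }
      mutex_lock lock(mu_);
      queues_.reserve(num_components());
      for (int i = 0; i < num_components(); ++i) {
        queues_.push_back(SubQueue());
      }
      return Status::OK();
    }

   protected:
    std::vector<SubQueue> queues_ GUARDED_BY(mu_);
  };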
@@ -31,8 +28,6 @@ class FIFOQueue : public QueueBase {
   void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
   void TryDequeueMany(int num_elements, OpKernelContext* ctx,
                       CallbackWithTuple callback) override;
-  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
-             DoneCallback callback) override;
   Status MatchesNodeDef(const NodeDef& node_def) override;
 
   int32 size() override {
@@ -40,80 +35,13 @@ class FIFOQueue : public QueueBase {
     return queues_[0].size();
   }
 
-  int32 capacity() const { return capacity_; }
-
  private:
-  enum Action { kEnqueue, kDequeue };
-
   ~FIFOQueue() override {}
 
-  TensorShape ManyOutShape(int i, int64 batch_size) {
-    TensorShape shape({batch_size});
-    shape.AppendShape(component_shapes_[i]);
-    return shape;
-  }
-
   // Helper for dequeuing a single element from queues_.
   void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  void Cancel(Action action, CancellationToken token);
-
-  // Helper for cancelling all pending Enqueue(Many) operations when
-  // Close is called with cancel_pending_enqueues.
-  void CloseAndCancel();
-
-  // Tries to enqueue/dequeue (or close) based on whatever is at the
-  // front of enqueue_attempts_/dequeue_attempts_. Appends to
-  // *finished the callback for any finished attempt (so it may be
-  // called once mu_ is released). Returns true if any progress was
-  // made.
-  struct CleanUp {
-    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
-        : finished(f), to_deregister(ct), cm(cm) {}
-    DoneCallback finished;
-    CancellationToken to_deregister;
-    CancellationManager* cm;
-  };
-  bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
-      EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Tries to make progress on the enqueues or dequeues at the front
-  // of the *_attempts_ queues.
-  void FlushUnlocked();
-
-  const int32 capacity_;
-
-  mutex mu_;
-  typedef std::deque<PersistentTensor> SubQueue;
-  std::vector<SubQueue> queues_ GUARDED_BY(mu_);
-  bool closed_ GUARDED_BY(mu_);
-
-  enum RunResult { kNoProgress, kProgress, kComplete };
-  struct Attempt;
-  typedef std::function<RunResult(Attempt*)> RunCallback;
-  struct Attempt {
-    int32 elements_requested;
-    DoneCallback done_callback;  // must be run outside mu_
-    OpKernelContext* context;
-    CancellationToken cancellation_token;
-    RunCallback run_callback;  // must be run while holding mu_
-    bool is_cancelled;
-    Tuple tuple;
-
-    Attempt(int32 elements_requested, DoneCallback done_callback,
-            OpKernelContext* context, CancellationToken cancellation_token,
-            RunCallback run_callback)
-        : elements_requested(elements_requested),
-          done_callback(done_callback),
-          context(context),
-          cancellation_token(cancellation_token),
-          run_callback(run_callback),
-          is_cancelled(false) {}
-  };
-  std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
-  std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
-
   static Status GetElementComponentFromBatch(const Tuple& tuple, int index,
                                              int component,
                                              OpKernelContext* ctx,
@@ -23,8 +23,8 @@ static void GetBandMatrix(int depth, int64 depth_radius,
   for (int row = 0; row < depth; ++row) {
     const int begin = std::max<int>(0, row - depth_radius);
     const int end = std::min<int64>(depth, row + depth_radius + 1);
-    Eigen::DSizes<ptrdiff_t, 2> start(row, begin);
-    Eigen::DSizes<ptrdiff_t, 2> sizes(1, end - begin);
+    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
+    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
     result->slice(start, sizes).setConstant(1.0f);
   }
 }
@@ -243,7 +243,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
             std::min(wpad / params.col_stride + 1, params.out_width);
         const int in_offset =
             (b * params.tensor_in_rows + h) * params.tensor_in_cols + w;
-        Eigen::DSizes<ptrdiff_t, 2> in_indices(0, in_offset);
+        Eigen::DSizes<Eigen::DenseIndex, 2> in_indices(0, in_offset);
         for (int ph = h_start; ph < h_end; ++ph) {
           for (int pw = w_start; pw < w_end; ++pw) {
             const int out_offset =
@@ -46,52 +46,14 @@ Status HandleElementToSlice(const Tensor& element, Tensor* parent, int index) {
 
 }  // namespace
 
-// static
-Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
-                                     int index) {
-#define HANDLE_TYPE(DT)                                                   \
-  if (parent.dtype() == DT) {                                             \
-    TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \
-    return Status::OK();                                                  \
-  }
-  HANDLE_TYPE(DT_FLOAT);
-  HANDLE_TYPE(DT_DOUBLE);
-  HANDLE_TYPE(DT_INT32);
-  HANDLE_TYPE(DT_UINT8);
-  HANDLE_TYPE(DT_INT16);
-  HANDLE_TYPE(DT_INT8);
-  HANDLE_TYPE(DT_STRING);
-  HANDLE_TYPE(DT_INT64);
-#undef HANDLE_TYPE
-  return errors::Unimplemented("Unhandled data type: ", parent.dtype());
-}
-
-// static
-Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
-                                     int index) {
-#define HANDLE_TYPE(DT)                                                   \
-  if (element.dtype() == DT) {                                            \
-    TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \
-    return Status::OK();                                                  \
-  }
-  HANDLE_TYPE(DT_FLOAT);
-  HANDLE_TYPE(DT_DOUBLE);
-  HANDLE_TYPE(DT_INT32);
-  HANDLE_TYPE(DT_UINT8);
-  HANDLE_TYPE(DT_INT16);
-  HANDLE_TYPE(DT_INT8);
-  HANDLE_TYPE(DT_STRING);
-  HANDLE_TYPE(DT_INT64);
-#undef HANDLE_TYPE
-  return errors::Unimplemented("Unhandled data type: ", element.dtype());
-}
-
-QueueBase::QueueBase(const DataTypeVector& component_dtypes,
+QueueBase::QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
                      const std::vector<TensorShape>& component_shapes,
                      const string& name)
-    : component_dtypes_(component_dtypes),
+    : capacity_(capacity),
+      component_dtypes_(component_dtypes),
       component_shapes_(component_shapes),
-      name_(name) {}
+      name_(name),
+      closed_(false) {}
 
 Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
   if (tuple.size() != static_cast<size_t>(num_components())) {
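The HANDLE_TYPE macro in the functions above (and re-added at the bottom of this file) is a small type-dispatch ladder: one templated helper per dtype, selected by a runtime comparison. A toy standalone version of the same pattern, with all names invented for illustration:

  #include <cstdio>

  // Toy stand-ins for the DataType enum and per-type helpers.
  enum DataType { DT_FLOAT, DT_INT32 };

  template <DataType DT> void HandleValue(const void* p);
  template <> void HandleValue<DT_FLOAT>(const void* p) {
    std::printf("%f\n", *static_cast<const float*>(p));
  }
  template <> void HandleValue<DT_INT32>(const void* p) {
    std::printf("%d\n", *static_cast<const int*>(p));
  }

  // Same shape as QueueBase::CopySliceToElement: compare the runtime dtype
  // against each supported case and forward to the templated helper.
  void Dispatch(DataType dtype, const void* p) {
  #define HANDLE_TYPE(DT) \
    if (dtype == DT) {    \
      HandleValue<DT>(p); \
      return;             \
    }
    HANDLE_TYPE(DT_FLOAT);
    HANDLE_TYPE(DT_INT32);
  #undef HANDLE_TYPE
    std::printf("unhandled dtype\n");
  }

  int main() {
    float f = 1.5f;
    int i = 7;
    Dispatch(DT_FLOAT, &f);  // prints 1.500000
    Dispatch(DT_INT32, &i);  // prints 7
  }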
@@ -172,4 +134,221 @@ Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const {
   return Status::OK();
 }
 
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status QueueBase::ValidateTuple(const Tuple& tuple) {
+  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+  if (specified_shapes()) {
+    for (size_t i = 0; i < tuple.size(); ++i) {
+      if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
+        return errors::InvalidArgument(
+            "Shape mismatch in tuple component ", i, ". Expected ",
+            component_shapes_[i].ShortDebugString(), ", got ",
+            tuple[i].shape().ShortDebugString());
+      }
+    }
+  }
+  return Status::OK();
+}
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status QueueBase::ValidateManyTuple(const Tuple& tuple) {
+  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+  const int64 batch_size = tuple[0].dim_size(0);
+  if (specified_shapes()) {
+    for (size_t i = 0; i < tuple.size(); ++i) {
+      // Expected shape is [batch_size] + component_shapes_[i]
+      const TensorShape expected_shape = ManyOutShape(i, batch_size);
+      if (!tuple[i].shape().IsSameSize(expected_shape)) {
+        return errors::InvalidArgument(
+            "Shape mismatch in tuple component ", i, ". Expected ",
+            expected_shape.ShortDebugString(), ", got ",
+            tuple[i].shape().ShortDebugString());
+      }
+    }
+  } else {
+    for (size_t i = 1; i < tuple.size(); ++i) {
+      if (tuple[i].dim_size(0) != batch_size) {
+        return errors::InvalidArgument(
+            "All input tensors must have the same size in the 0th ",
+            "dimension. Component ", i, " has ", tuple[i].dim_size(0),
+            ", and should have ", batch_size);
+      }
+    }
+  }
+  return Status::OK();
+}
+
+void QueueBase::Cancel(Action action, CancellationToken token) {
+  DoneCallback callback = nullptr;
+  {
+    mutex_lock lock(mu_);
+    std::deque<Attempt>* attempts =
+        action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+    for (Attempt& attempt : *attempts) {
+      if (attempt.cancellation_token == token) {
+        attempt.is_cancelled = true;
+        if (action == kEnqueue) {
+          attempt.context->SetStatus(
+              errors::Cancelled("Enqueue operation was cancelled"));
+        } else {
+          attempt.context->SetStatus(
+              errors::Cancelled("Dequeue operation was cancelled"));
+        }
+        std::swap(callback, attempt.done_callback);
+        break;
+      }
+    }
+  }
+  if (callback) {
+    callback();
+    FlushUnlocked();
+  }
+}
+
+void QueueBase::CloseAndCancel() {
+  std::vector<DoneCallback> callbacks;
+  {
+    mutex_lock lock(mu_);
+    closed_ = true;
+    for (Attempt& attempt : enqueue_attempts_) {
+      attempt.is_cancelled = true;
+      attempt.context->SetStatus(
+          errors::Cancelled("Enqueue operation was cancelled"));
+      callbacks.emplace_back(std::move(attempt.done_callback));
+    }
+  }
+  for (const DoneCallback& callback : callbacks) {
+    callback();
+  }
+  FlushUnlocked();
+}
+
+void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+                      DoneCallback callback) {
+  if (cancel_pending_enqueues) {
+    CloseAndCancel();
+    callback();
+  } else {
+    {
+      mutex_lock lock(mu_);
+      enqueue_attempts_.emplace_back(
+          0, callback, ctx, CancellationManager::kInvalidToken,
+          [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+            if (closed_) {
+              attempt->context->SetStatus(
+                  errors::Aborted("Queue '", name_, "' is already closed."));
+            } else {
+              closed_ = true;
+            }
+            return kComplete;
+          });
+    }
+    FlushUnlocked();
+  }
+}
+
+bool QueueBase::TryAttemptLocked(Action action,
+                                 std::vector<CleanUp>* clean_up) {
+  std::deque<Attempt>* attempts =
+      action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+  bool progress = false;
+  bool done = false;
+  while (!done && !attempts->empty()) {
+    if (attempts->front().is_cancelled) {
+      if (action == kEnqueue) {
+        LOG(INFO) << "Skipping cancelled enqueue attempt";
+      } else {
+        LOG(INFO) << "Skipping cancelled dequeue attempt";
+      }
+      attempts->pop_front();
+    } else {
+      Attempt* cur_attempt = &attempts->front();
+      switch (cur_attempt->run_callback(cur_attempt)) {
+        case kNoProgress:
+          done = true;
+          break;
+        case kProgress:
+          done = true;
+          progress = true;
+          break;
+        case kComplete:
+          progress = true;
+          clean_up->emplace_back(std::move(cur_attempt->done_callback),
+                                 cur_attempt->cancellation_token,
+                                 cur_attempt->context->cancellation_manager());
+          attempts->pop_front();
+          break;
+      }
+    }
+  }
+  return progress;
+}
+
+void QueueBase::FlushUnlocked() {
+  std::vector<CleanUp> clean_up;
+  Ref();
+  {
+    mutex_lock lock(mu_);
+    bool changed;
+    do {
+      changed = TryAttemptLocked(kEnqueue, &clean_up);
+      changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
+    } while (changed);
+  }
+  Unref();
+  for (const auto& to_clean : clean_up) {
+    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
+      // NOTE(mrry): We can safely ignore the return value of
+      // DeregisterCallback because the mutex mu_ ensures that the
+      // cleanup action only executes once.
+      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
+    }
+    to_clean.finished();
+  }
+}
+
+// Static method
+Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
+                                     int index) {
+#define HANDLE_TYPE(DT)                                                   \
+  if (parent.dtype() == DT) {                                             \
+    TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \
+    return Status::OK();                                                  \
+  }
+  HANDLE_TYPE(DT_FLOAT);
+  HANDLE_TYPE(DT_DOUBLE);
+  HANDLE_TYPE(DT_INT32);
+  HANDLE_TYPE(DT_UINT8);
+  HANDLE_TYPE(DT_INT16);
+  HANDLE_TYPE(DT_INT8);
+  HANDLE_TYPE(DT_STRING);
+  HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+  return errors::Unimplemented("Unhandled data type: ", parent.dtype());
+}
+
+// Static method
+Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
+                                     int index) {
+#define HANDLE_TYPE(DT)                                                   \
+  if (element.dtype() == DT) {                                            \
+    TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \
+    return Status::OK();                                                  \
+  }
+  HANDLE_TYPE(DT_FLOAT);
+  HANDLE_TYPE(DT_DOUBLE);
+  HANDLE_TYPE(DT_INT32);
+  HANDLE_TYPE(DT_UINT8);
+  HANDLE_TYPE(DT_INT16);
+  HANDLE_TYPE(DT_INT8);
+  HANDLE_TYPE(DT_STRING);
+  HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+  return errors::Unimplemented("Unhandled data type: ", element.dtype());
+}
+
 }  // namespace tensorflow
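TryAttemptLocked above is a small state machine over the attempt deque: kNoProgress parks the queue, kProgress records work but stops scanning, and kComplete retires the attempt and keeps going. A stripped-down illustration of just that control flow, with invented standalone types:

  #include <deque>
  #include <functional>

  enum RunResult { kNoProgress, kProgress, kComplete };

  // Same loop shape as QueueBase::TryAttemptLocked: run the front attempt,
  // stop on kNoProgress/kProgress, pop and continue on kComplete.
  bool DrainFront(std::deque<std::function<RunResult()>>* attempts) {
    bool progress = false;
    bool done = false;
    while (!done && !attempts->empty()) {
      switch (attempts->front()()) {
        case kNoProgress:
          done = true;
          break;
        case kProgress:
          done = true;
          progress = true;
          break;
        case kComplete:
          progress = true;
          attempts->pop_front();
          break;
      }
    }
    return progress;
  }

  int main() {
    std::deque<std::function<RunResult()>> q;
    q.push_back([] { return kComplete; });    // retired; the loop continues
    q.push_back([] { return kNoProgress; });  // parks the queue
    return DrainFront(&q) ? 0 : 1;            // progress was made -> 0
  }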
@@ -1,6 +1,9 @@
 #ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
 #define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
 
+#include <deque>
+#include <vector>
+
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/queue_interface.h"
 #include "tensorflow/core/framework/types.h"
@@ -11,7 +14,7 @@
 
 namespace tensorflow {
 
-// Functionality common to QueueInterface implementations.
+// Functionality common to asynchronous QueueInterface implementations.
 class QueueBase : public QueueInterface {
  public:
   // As a possible value of 'capacity'.
@@ -23,7 +26,7 @@ class QueueBase : public QueueInterface {
   // which must either be empty (if the shapes are not specified) or
   // or have the same size as component_dtypes.
   // name: A name to use for the queue.
-  QueueBase(const DataTypeVector& component_dtypes,
+  QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
             const std::vector<TensorShape>& component_shapes,
             const string& name);
 
@@ -32,12 +35,36 @@ class QueueBase : public QueueInterface {
     return component_dtypes_;
   }
 
+  Status ValidateTuple(const Tuple& tuple) override;
+  Status ValidateManyTuple(const Tuple& tuple) override;
+
+  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+             DoneCallback callback) override;
+
   // Other public methods -----------------------------------------------------
   const std::vector<TensorShape>& component_shapes() const {
     return component_shapes_;
   }
 
+  int32 capacity() const { return capacity_; }
+
  protected:
+  enum Action { kEnqueue, kDequeue };
+  enum RunResult { kNoProgress, kProgress, kComplete };
+
+  // Tries to enqueue/dequeue (or close) based on whatever is at the
+  // front of enqueue_attempts_/dequeue_attempts_. Appends to
+  // *finished the callback for any finished attempt (so it may be
+  // called once mu_ is released). Returns true if any progress was
+  // made.
+  struct CleanUp {
+    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
+        : finished(f), to_deregister(ct), cm(cm) {}
+    DoneCallback finished;
+    CancellationToken to_deregister;
+    CancellationManager* cm;
+  };
+
   // Returns the number of components in a queue-element tuple.
   int32 num_components() const { return component_dtypes_.size(); }
 
@@ -48,6 +75,12 @@ class QueueBase : public QueueInterface {
   // Code common to Validate*Tuple().
   Status ValidateTupleCommon(const Tuple& tuple) const;
 
+  TensorShape ManyOutShape(int i, int64 batch_size) {
+    TensorShape shape({batch_size});
+    shape.AppendShape(component_shapes_[i]);
+    return shape;
+  }
+
   // Copies the index^th slice (in the first dimension) of parent into element.
   static Status CopySliceToElement(const Tensor& parent, Tensor* element,
                                    int index);
@@ -56,6 +89,19 @@ class QueueBase : public QueueInterface {
   static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
                                    int index);
 
+  void Cancel(Action action, CancellationToken token);
+
+  // Helper for cancelling all pending Enqueue(Many) operations when
+  // Close is called with cancel_pending_enqueues.
+  void CloseAndCancel();
+
+  bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Tries to make progress on the enqueues or dequeues at the front
+  // of the *_attempts_ queues.
+  void FlushUnlocked();
+
   ~QueueBase() override {}
 
   // Helpers for implementing MatchesNodeDef().
@@ -65,9 +111,37 @@ class QueueBase : public QueueInterface {
   Status MatchesNodeDefTypes(const NodeDef& node_def) const;
   Status MatchesNodeDefShapes(const NodeDef& node_def) const;
 
  protected:
+  const int32 capacity_;
   const DataTypeVector component_dtypes_;
   const std::vector<TensorShape> component_shapes_;
   const string name_;
+  mutex mu_;
+  bool closed_ GUARDED_BY(mu_);
+
+  struct Attempt;
+  typedef std::function<RunResult(Attempt*)> RunCallback;
+  struct Attempt {
+    int32 elements_requested;
+    DoneCallback done_callback;  // must be run outside mu_
+    OpKernelContext* context;
+    CancellationToken cancellation_token;
+    RunCallback run_callback;  // must be run while holding mu_
+    bool is_cancelled;
+    Tuple tuple;
+
+    Attempt(int32 elements_requested, DoneCallback done_callback,
+            OpKernelContext* context, CancellationToken cancellation_token,
+            RunCallback run_callback)
+        : elements_requested(elements_requested),
+          done_callback(done_callback),
+          context(context),
+          cancellation_token(cancellation_token),
+          run_callback(run_callback),
+          is_cancelled(false) {}
+  };
+  std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
+  std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
+
   TF_DISALLOW_COPY_AND_ASSIGN(QueueBase);
 };
@@ -6,7 +6,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/kernels/typed_queue.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/random/philox_random.h"
 #include "tensorflow/core/lib/random/random.h"
@@ -19,18 +19,16 @@
 
 namespace tensorflow {
 
-class RandomShuffleQueue : public QueueBase {
+class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > {
  public:
   RandomShuffleQueue(int32 capacity, int32 min_after_dequeue, int64 seed,
                      int64 seed2, const DataTypeVector& component_dtypes,
                      const std::vector<TensorShape>& component_shapes,
                      const string& name);
-  Status Initialize();  // Must be called before any other method.
+
+  Status Initialize() override;  // Must be called before any other method.
 
   // Implementations of QueueInterface methods --------------------------------
 
-  Status ValidateTuple(const Tuple& tuple) override;
-  Status ValidateManyTuple(const Tuple& tuple) override;
   void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                   DoneCallback callback) override;
   void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
@@ -38,8 +36,6 @@ class RandomShuffleQueue : public QueueBase {
   void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
   void TryDequeueMany(int num_elements, OpKernelContext* ctx,
                       CallbackWithTuple callback) override;
-  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
-             DoneCallback callback) override;
   Status MatchesNodeDef(const NodeDef& node_def) override;
 
   int32 size() override {
@@ -48,95 +44,30 @@ class RandomShuffleQueue : public QueueBase {
   }
 
  private:
-  enum Action { kEnqueue, kDequeue };
-
   ~RandomShuffleQueue() override {}
 
-  TensorShape ManyOutShape(int i, int batch_size) {
-    TensorShape shape({batch_size});
-    shape.AppendShape(component_shapes_[i]);
-    return shape;
-  }
-
   // Helper for dequeuing a single random element from queues_.
   void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  void Cancel(Action action, CancellationToken token);
-
-  // Helper for cancelling all pending Enqueue(Many) operations when
-  // Close is called with cancel_pending_enqueues.
-  void CloseAndCancel();
-
-  // Tries to enqueue/dequeue (or close) based on whatever is at the
-  // front of enqueue_attempts_/dequeue_attempts_. Appends to
-  // *finished the callback for any finished attempt (so it may be
-  // called once mu_ is released). Returns true if any progress was
-  // made.
-  struct CleanUp {
-    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
-        : finished(f), to_deregister(ct), cm(cm) {}
-    DoneCallback finished;
-    CancellationToken to_deregister;
-    CancellationManager* cm;
-  };
-  bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
-      EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Tries to make progress on the enqueues or dequeues at the front
-  // of the *_attempts_ queues.
-  void FlushUnlocked();
-
-  const int32 capacity_;
   const int32 min_after_dequeue_;
   const int64 original_seed_;
   const int64 original_seed2_;
 
-  mutex mu_;
-  typedef std::vector<PersistentTensor> SubQueue;
-  std::vector<SubQueue> queues_ GUARDED_BY(mu_);
-  bool closed_ GUARDED_BY(mu_);
   random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
   random::SingleSampleAdapter<random::PhiloxRandom> generator_ GUARDED_BY(mu_);
 
-  enum RunResult { kNoProgress, kProgress, kComplete };
-  struct Attempt;
-  typedef std::function<RunResult(Attempt*)> RunCallback;
-  struct Attempt {
-    int32 elements_requested;
-    DoneCallback done_callback;  // must be run outside mu_
-    OpKernelContext* context;
-    CancellationToken cancellation_token;
-    RunCallback run_callback;  // must be run while holding mu_
-    bool is_cancelled;
-    Tuple tuple;
-
-    Attempt(int32 elements_requested, DoneCallback done_callback,
-            OpKernelContext* context, CancellationToken cancellation_token,
-            RunCallback run_callback)
-        : elements_requested(elements_requested),
-          done_callback(done_callback),
-          context(context),
-          cancellation_token(cancellation_token),
-          run_callback(run_callback),
-          is_cancelled(false) {}
-  };
-  std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
-  std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
-
   TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue);
 };
 
 RandomShuffleQueue::RandomShuffleQueue(
-    int capacity, int min_after_dequeue, int64 seed, int64 seed2,
+    int32 capacity, int32 min_after_dequeue, int64 seed, int64 seed2,
     const DataTypeVector& component_dtypes,
     const std::vector<TensorShape>& component_shapes, const string& name)
-    : QueueBase(component_dtypes, component_shapes, name),
-      capacity_(capacity),
+    : TypedQueue(capacity, component_dtypes, component_shapes, name),
       min_after_dequeue_(min_after_dequeue),
       original_seed_(seed),
       original_seed2_(seed2),
-      closed_(false),
       generator_(&parent_generator_) {
   if (seed == 0 && seed2 == 0) {
     // If both seeds are unspecified, use completely random seeds.
@@ -147,71 +78,16 @@ RandomShuffleQueue::RandomShuffleQueue(
 }
 
 Status RandomShuffleQueue::Initialize() {
-  if (component_dtypes_.empty()) {
-    return errors::InvalidArgument("Empty component types for queue ", name_);
-  }
-  if (!component_shapes_.empty() &&
-      component_dtypes_.size() != component_shapes_.size()) {
-    return errors::InvalidArgument("Different number of component types (",
-                                   component_dtypes_.size(), ") vs. shapes (",
-                                   component_shapes_.size(), ").");
-  }
+  Status s = TypedQueue::Initialize();
+  if (!s.ok()) return s;
 
   mutex_lock lock(mu_);
   queues_.reserve(num_components());
   for (int i = 0; i < num_components(); ++i) {
     queues_.push_back(SubQueue());
     queues_.back().reserve(min_after_dequeue_);
   }
   return Status::OK();
 }
 
-// TODO(mrry): If these checks become a bottleneck, find a way to
-// reduce the number of times that they are called.
-Status RandomShuffleQueue::ValidateTuple(const Tuple& tuple) {
-  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
-  if (specified_shapes()) {
-    for (size_t i = 0; i < tuple.size(); ++i) {
-      if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
-        return errors::InvalidArgument(
-            "Shape mismatch in tuple component ", i, ". Expected ",
-            component_shapes_[i].ShortDebugString(), ", got ",
-            tuple[i].shape().ShortDebugString());
-      }
-    }
-  }
-  return Status::OK();
-}
-
-// TODO(mrry): If these checks become a bottleneck, find a way to
-// reduce the number of times that they are called.
-Status RandomShuffleQueue::ValidateManyTuple(const Tuple& tuple) {
-  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
-  const int64 batch_size = tuple[0].dim_size(0);
-  if (specified_shapes()) {
-    for (size_t i = 0; i < tuple.size(); ++i) {
-      // Expected shape is [batch_size] + component_shapes_[i]
-      const TensorShape expected_shape = ManyOutShape(i, batch_size);
-      if (!tuple[i].shape().IsSameSize(expected_shape)) {
-        return errors::InvalidArgument(
-            "Shape mismatch in tuple component ", i, ". Expected ",
-            expected_shape.ShortDebugString(), ", got ",
-            tuple[i].shape().ShortDebugString());
-      }
-    }
-  } else {
-    for (size_t i = 1; i < tuple.size(); ++i) {
-      if (tuple[i].dim_size(0) != batch_size) {
-        return errors::InvalidArgument(
-            "All input tensors must have the same size in the 0th ",
-            "dimension. Component ", i, " has ", tuple[i].dim_size(0),
-            ", and should have ", batch_size);
-      }
-    }
-  }
-  return Status::OK();
-}
-
 void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
   DCHECK_GT(queues_[0].size(), 0);
   int64 index = generator_() % queues_[0].size();
@@ -223,113 +99,6 @@ void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
   }
 }
 
-void RandomShuffleQueue::Cancel(Action action, CancellationToken token) {
-  DoneCallback callback = nullptr;
-  {
-    mutex_lock lock(mu_);
-    std::deque<Attempt>* attempts =
-        action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
-    for (Attempt& attempt : *attempts) {
-      if (attempt.cancellation_token == token) {
-        attempt.is_cancelled = true;
-        if (action == kEnqueue) {
-          attempt.context->SetStatus(
-              errors::Cancelled("Enqueue operation was cancelled"));
-        } else {
-          attempt.context->SetStatus(
-              errors::Cancelled("Dequeue operation was cancelled"));
-        }
-        std::swap(callback, attempt.done_callback);
-        break;
-      }
-    }
-  }
-  if (callback) {
-    callback();
-    FlushUnlocked();
-  }
-}
-
-void RandomShuffleQueue::CloseAndCancel() {
-  std::vector<DoneCallback> callbacks;
-  {
-    mutex_lock lock(mu_);
-    closed_ = true;
-    for (Attempt& attempt : enqueue_attempts_) {
-      attempt.is_cancelled = true;
-      attempt.context->SetStatus(
-          errors::Cancelled("Enqueue operation was cancelled"));
-      callbacks.emplace_back(std::move(attempt.done_callback));
-    }
-  }
-  for (const DoneCallback& callback : callbacks) {
-    callback();
-  }
-  FlushUnlocked();
-}
-
-bool RandomShuffleQueue::TryAttemptLocked(
-    Action action, std::vector<CleanUp>* clean_up) {
-  std::deque<Attempt>* attempts =
-      action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
-  bool progress = false;
-  bool done = false;
-  while (!done && !attempts->empty()) {
-    if (attempts->front().is_cancelled) {
-      if (action == kEnqueue) {
-        LOG(INFO) << "Skipping cancelled enqueue attempt";
-      } else {
-        LOG(INFO) << "Skipping cancelled dequeue attempt";
-      }
-      attempts->pop_front();
-    } else {
-      Attempt* cur_attempt = &attempts->front();
-      switch (cur_attempt->run_callback(cur_attempt)) {
-        case kNoProgress:
-          done = true;
-          break;
-        case kProgress:
-          done = true;
-          progress = true;
-          break;
-        case kComplete:
-          progress = true;
-          clean_up->emplace_back(std::move(cur_attempt->done_callback),
-                                 cur_attempt->cancellation_token,
-                                 cur_attempt->context->cancellation_manager());
-          attempts->pop_front();
-          break;
-      }
-    }
-  }
-  return progress;
-}
-
-void RandomShuffleQueue::FlushUnlocked() {
-  std::vector<CleanUp> clean_up;
-  Ref();
-  {
-    mutex_lock lock(mu_);
-    bool changed;
-    do {
-      changed = TryAttemptLocked(kEnqueue, &clean_up);
-      changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
-    } while (changed);
-  }
-  Unref();
-  for (const auto& to_clean : clean_up) {
-    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
-      // NOTE(mrry): We can safely ignore the return value of
-      // DeregisterCallback because the mutex mu_ ensures that the
-      // cleanup action only executes once.
-      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
-    }
-    to_clean.finished();
-  }
-}
-
 void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                                     DoneCallback callback) {
   CancellationManager* cm = ctx->cancellation_manager();
@@ -583,31 +352,6 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
   }
 }
 
-void RandomShuffleQueue::Close(OpKernelContext* ctx,
-                               bool cancel_pending_enqueues,
-                               DoneCallback callback) {
-  if (cancel_pending_enqueues) {
-    CloseAndCancel();
-    callback();
-  } else {
-    {
-      mutex_lock lock(mu_);
-      enqueue_attempts_.emplace_back(
-          0, callback, ctx, CancellationManager::kInvalidToken,
-          [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-            if (closed_) {
-              attempt->context->SetStatus(errors::Aborted(
-                  "RandomShuffleQueue '", name_, "' is already closed."));
-            } else {
-              closed_ = true;
-            }
-            return kComplete;
-          });
-    }
-    FlushUnlocked();
-  }
-}
-
 Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
   TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "RandomShuffleQueue"));
   TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
@@ -640,8 +384,6 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
   return Status::OK();
 }
 
-typedef std::shared_ptr<QueueInterface> QueueInterfacePtr;
-
 // Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one
 // backed by RandomShuffleQueue) that persists across different graph
 // executions, and sessions. Running this op produces a single-element
@@ -171,8 +171,8 @@ class SliceOp : public OpKernel {
 template <int NDIM>
 void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                 const gtl::ArraySlice<int64>& size, Tensor* result) {
-  Eigen::DSizes<ptrdiff_t, NDIM> indices;
-  Eigen::DSizes<ptrdiff_t, NDIM> sizes;
+  Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
+  Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
   for (int i = 0; i < NDIM; ++i) {
     indices[i] = begin[i];
     sizes[i] = size[i];
@@ -205,8 +205,8 @@ namespace functor {
   void Slice<GPUDevice, T, NDIM>::operator()(                       \
       const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output,  \
      typename TTypes<T, NDIM>::ConstTensor input,                   \
-     const Eigen::DSizes<ptrdiff_t, NDIM>& indices,                 \
-     const Eigen::DSizes<ptrdiff_t, NDIM>& sizes);                  \
+     const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,         \
+     const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);          \
   extern template struct Slice<GPUDevice, T, NDIM>;
 
 #define DECLARE_FOR_N(T) \
@@ -13,8 +13,8 @@ template <typename Device, typename T, int NDIMS>
 struct Slice {
   void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
                   typename TTypes<T, NDIMS>::ConstTensor input,
-                  const Eigen::DSizes<ptrdiff_t, NDIMS>& slice_indices,
-                  const Eigen::DSizes<ptrdiff_t, NDIMS>& slice_sizes) {
+                  const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_indices,
+                  const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_sizes) {
     output.device(d) = input.slice(slice_indices, slice_sizes);
   }
 };
@@ -90,17 +90,17 @@ class SplitOp : public OpKernel {
     TensorShape output_shape(input_shape);
     output_shape.set_dim(split_dim, split_dim_output_size);
 
-    Eigen::DSizes<ptrdiff_t, 3> indices{0, 0, 0};
-    Eigen::DSizes<ptrdiff_t, 3> sizes{prefix_dim_size, split_dim_output_size,
-                                      suffix_dim_size};
+    Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
+    Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+        prefix_dim_size, split_dim_output_size, suffix_dim_size};
 
     for (int i = 0; i < num_split; ++i) {
       Tensor* result = nullptr;
       OP_REQUIRES_OK(context,
                      context->allocate_output(i, output_shape, &result));
       if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
-        Eigen::DSizes<ptrdiff_t, 3> slice_indices;
-        Eigen::DSizes<ptrdiff_t, 3> slice_sizes;
+        Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices;
+        Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes;
         for (int j = 0; j < 3; ++j) {
           slice_indices[j] = indices[j];
           slice_sizes[j] = sizes[j];
@@ -12,8 +12,8 @@ template <typename Device, typename T>
 struct Split {
   void operator()(const Device& d, typename TTypes<T, 3>::Tensor output,
                   typename TTypes<T, 3>::ConstTensor input,
-                  const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
-                  const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes);
+                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes);
 };
 
 template <typename T>
@@ -21,8 +21,8 @@ struct Split<Eigen::ThreadPoolDevice, T> {
   void operator()(const Eigen::ThreadPoolDevice& d,
                   typename TTypes<T, 3>::Tensor output,
                   typename TTypes<T, 3>::ConstTensor input,
-                  const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
-                  const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes);
+                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes);
 };
 
 }  // namespace functor
@@ -13,8 +13,8 @@ template <typename T>
 void Split<Eigen::ThreadPoolDevice, T>::operator()(
     const Eigen::ThreadPoolDevice& d, typename TTypes<T, 3>::Tensor output,
     typename TTypes<T, 3>::ConstTensor input,
-    const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
-    const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes) {
+    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
   if (output.size() < 131072) {
     output = input.slice(slice_indices, slice_sizes);
   } else {
@@ -16,8 +16,8 @@ template <typename Device, typename T>
 void Split<Device, T>::operator()(
     const Device& d, typename TTypes<T, 3>::Tensor output,
     typename TTypes<T, 3>::ConstTensor input,
-    const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
-    const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes) {
+    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
   output.device(d) = input.slice(slice_indices, slice_sizes);
 }
@ -273,8 +273,8 @@ class TileGradientOp : public OpKernel {
|
||||
#undef HANDLE_DIM
|
||||
}
|
||||
|
||||
Eigen::DSizes<ptrdiff_t, NDIM> indices;
|
||||
Eigen::DSizes<ptrdiff_t, NDIM> sizes;
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
|
||||
|
||||
// Accumulate slices along the dimensions into the output. The number of
|
||||
// slices along dimension 'i' is simply the multiple along dimension 'i'
|
||||
@ -309,8 +309,8 @@ class TileGradientOp : public OpKernel {
|
||||
void HandleReduce(OpKernelContext* context,
|
||||
const std::vector<int32>& reduce_dim_in, Tensor* result) {
|
||||
static_assert(NDIM >= REDUCENDIM, "Too many reduced dimensions");
|
||||
Eigen::DSizes<ptrdiff_t, REDUCENDIM> reduce_dim;
|
||||
Eigen::DSizes<ptrdiff_t, NDIM> reshape_dim;
|
||||
Eigen::DSizes<Eigen::DenseIndex, REDUCENDIM> reduce_dim;
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIM> reshape_dim;
|
||||
|
||||
for (int i = 0; i < REDUCENDIM; ++i) {
|
||||
reduce_dim[i] = reduce_dim_in[i];
|
||||
@ -392,26 +392,26 @@ REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
DEFINE_GPU_DIM(T, 4) \
|
||||
DEFINE_GPU_DIM(T, 5)
|
||||
|
||||
#define DEFINE_GPU_DIM(T, NDIM) \
|
||||
template <> \
|
||||
void Tile<GPUDevice, T, NDIM>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
|
||||
typename TTypes<T, NDIM>::ConstTensor in, \
|
||||
const Eigen::array<int32, NDIM>& broadcast_array) const; \
|
||||
extern template struct Tile<GPUDevice, T, NDIM>; \
|
||||
template <> \
|
||||
void TileGrad<GPUDevice, T, NDIM>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
|
||||
typename TTypes<T, NDIM>::ConstTensor in, \
|
||||
const Eigen::DSizes<ptrdiff_t, NDIM>& indices, \
|
||||
const Eigen::DSizes<ptrdiff_t, NDIM>& sizes, bool first) const; \
|
||||
extern template struct TileGrad<GPUDevice, T, NDIM>; \
|
||||
template <> \
|
||||
void ReduceAndReshape<GPUDevice, T, NDIM, 1>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
|
||||
typename TTypes<T, NDIM>::ConstTensor in, \
|
||||
const Eigen::DSizes<ptrdiff_t, 1>& reduce_dim, \
|
||||
const Eigen::DSizes<ptrdiff_t, NDIM>& reshape_dim) const; \
|
||||
#define DEFINE_GPU_DIM(T, NDIM) \
|
||||
template <> \
|
||||
void Tile<GPUDevice, T, NDIM>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
|
||||
typename TTypes<T, NDIM>::ConstTensor in, \
|
||||
const Eigen::array<int32, NDIM>& broadcast_array) const; \
|
||||
extern template struct Tile<GPUDevice, T, NDIM>; \
|
||||
template <> \
|
||||
void TileGrad<GPUDevice, T, NDIM>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
|
||||
typename TTypes<T, NDIM>::ConstTensor in, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes, bool first) const; \
|
||||
extern template struct TileGrad<GPUDevice, T, NDIM>; \
|
||||
template <> \
|
||||
void ReduceAndReshape<GPUDevice, T, NDIM, 1>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
|
||||
typename TTypes<T, NDIM>::ConstTensor in, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 1>& reduce_dim, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const; \
|
||||
extern template struct ReduceAndReshape<GPUDevice, T, NDIM, 1>;
|
||||
|
||||
namespace functor {
|
||||
|
@ -31,8 +31,8 @@ template <typename Device, typename T, int NDIM>
|
||||
struct TileGrad {
|
||||
void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
|
||||
typename TTypes<T, NDIM>::ConstTensor in,
|
||||
const Eigen::DSizes<ptrdiff_t, NDIM>& indices,
|
||||
const Eigen::DSizes<ptrdiff_t, NDIM>& sizes,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes,
|
||||
bool first) const {
|
||||
if (first) {
|
||||
out.device(d) = in.slice(indices, sizes);
|
||||
@ -58,10 +58,11 @@ struct TileGrad<Device, T, 0> {
|
||||
|
||||
template <typename Device, typename T, int NDIM, int REDUCEDNDIM>
|
||||
struct ReduceAndReshape {
|
||||
void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
|
||||
typename TTypes<T, NDIM>::ConstTensor in,
|
||||
const Eigen::DSizes<ptrdiff_t, REDUCEDNDIM>& reduce_dim,
|
||||
const Eigen::DSizes<ptrdiff_t, NDIM>& reshape_dim) const {
|
||||
void operator()(
|
||||
const Device& d, typename TTypes<T, NDIM>::Tensor out,
|
||||
typename TTypes<T, NDIM>::ConstTensor in,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, REDUCEDNDIM>& reduce_dim,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const {
|
||||
out.device(d) = in.sum(reduce_dim).reshape(reshape_dim);
|
||||
}
|
||||
};
|
||||
|
tensorflow/core/kernels/typed_queue.h (new file, 54 lines)

@ -0,0 +1,54 @@
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_

#include <vector>

#include "tensorflow/core/kernels/queue_base.h"

namespace tensorflow {

// TypedQueue builds on QueueBase, with backing class (SubQueue)
// known and stored within. Shared methods that need to have access
// to the backed data sit in this class.
template <typename SubQueue>
class TypedQueue : public QueueBase {
 public:
  TypedQueue(const int32 capacity, const DataTypeVector& component_dtypes,
             const std::vector<TensorShape>& component_shapes,
             const string& name);

  virtual Status Initialize();  // Must be called before any other method.

 protected:
  std::vector<SubQueue> queues_ GUARDED_BY(mu_);
};  // class TypedQueue

template <typename SubQueue>
TypedQueue<SubQueue>::TypedQueue(
    int32 capacity, const DataTypeVector& component_dtypes,
    const std::vector<TensorShape>& component_shapes, const string& name)
    : QueueBase(capacity, component_dtypes, component_shapes, name) {}

template <typename SubQueue>
Status TypedQueue<SubQueue>::Initialize() {
  if (component_dtypes_.empty()) {
    return errors::InvalidArgument("Empty component types for queue ", name_);
  }
  if (!component_shapes_.empty() &&
      component_dtypes_.size() != component_shapes_.size()) {
    return errors::InvalidArgument("Different number of component types (",
                                   component_dtypes_.size(), ") vs. shapes (",
                                   component_shapes_.size(), ").");
  }

  mutex_lock lock(mu_);
  queues_.reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    queues_.push_back(SubQueue());
  }
  return Status::OK();
}

} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
@ -63,8 +63,8 @@ class UnpackOp : public OpKernel {
                     context->allocate_output(i, output_shape, &output));
      auto output_shaped = output->shaped<T, 3>({1, 1, output_size});

      Eigen::DSizes<ptrdiff_t, 3> indices{0, i, 0};
      Eigen::DSizes<ptrdiff_t, 3> sizes{1, 1, output_size};
      Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, i, 0};
      Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, 1, output_size};
      functor::Split<Device, T>()(context->eigen_device<Device>(),
                                  output_shaped, input_reshaped, indices,
                                  sizes);
@ -23,28 +23,37 @@ write the graph to a file.

1. Run the graph with a call to `session->Run()`

##Classes <a class="md-anchor" id="AUTOGENERATED-classes"></a>
## Env <a class="md-anchor" id="AUTOGENERATED-env"></a>

* [tensorflow::Env](../../api_docs/cc/ClassEnv.md)
* [tensorflow::RandomAccessFile](../../api_docs/cc/ClassRandomAccessFile.md)
* [tensorflow::WritableFile](../../api_docs/cc/ClassWritableFile.md)
* [tensorflow::EnvWrapper](../../api_docs/cc/ClassEnvWrapper.md)

## Session <a class="md-anchor" id="AUTOGENERATED-session"></a>

* [tensorflow::Session](../../api_docs/cc/ClassSession.md)
* [tensorflow::SessionOptions](../../api_docs/cc/StructSessionOptions.md)

## Status <a class="md-anchor" id="AUTOGENERATED-status"></a>

* [tensorflow::Status](../../api_docs/cc/ClassStatus.md)
* [tensorflow::Status::State](../../api_docs/cc/StructState.md)

## Tensor <a class="md-anchor" id="AUTOGENERATED-tensor"></a>

* [tensorflow::Tensor](../../api_docs/cc/ClassTensor.md)
* [tensorflow::TensorShape](../../api_docs/cc/ClassTensorShape.md)
* [tensorflow::TensorShapeUtils](../../api_docs/cc/ClassTensorShapeUtils.md)
* [tensorflow::Thread](../../api_docs/cc/ClassThread.md)

##Structs <a class="md-anchor" id="AUTOGENERATED-structs"></a>

* [tensorflow::SessionOptions](../../api_docs/cc/StructSessionOptions.md)
* [tensorflow::Status::State](../../api_docs/cc/StructState.md)
* [tensorflow::TensorShapeDim](../../api_docs/cc/StructTensorShapeDim.md)
* [tensorflow::TensorShapeUtils](../../api_docs/cc/ClassTensorShapeUtils.md)

## Thread <a class="md-anchor" id="AUTOGENERATED-thread"></a>

* [tensorflow::Thread](../../api_docs/cc/ClassThread.md)
* [tensorflow::ThreadOptions](../../api_docs/cc/StructThreadOptions.md)


<div class='sections-order' style="display: none;">
<!--
<!-- ClassEnv.md -->

@ -52,14 +61,14 @@ write the graph to a file.
<!-- ClassWritableFile.md -->
<!-- ClassEnvWrapper.md -->
<!-- ClassSession.md -->
<!-- StructSessionOptions.md -->
<!-- ClassStatus.md -->
<!-- StructState.md -->
<!-- ClassTensor.md -->
<!-- ClassTensorShape.md -->
<!-- StructTensorShapeDim.md -->
<!-- ClassTensorShapeUtils.md -->
<!-- ClassThread.md -->
<!-- StructSessionOptions.md -->
<!-- StructState.md -->
<!-- StructTensorShapeDim.md -->
<!-- StructThreadOptions.md -->
-->
</div>
@ -597,7 +597,7 @@ For so-called "global normalization" needed for convolutional filters pass

##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>

Two `Tensors`: `mean` and `variance`.
Two `Tensor` objects: `mean` and `variance`.
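For context, the hunk above matches the `tf.nn.moments` documentation. A minimal sketch of a call, assuming this release's Python API (the shape and the `axes` value are illustrative choices, not taken from the diff):

```python
import tensorflow as tf

# Per-channel mean and variance over batch, height and width -- the
# "global normalization" axes mentioned in the docstring above.
images = tf.placeholder(tf.float32, [100, 32, 32, 3])
mean, variance = tf.nn.moments(images, axes=[0, 1, 2])
```

Both results are `Tensor` objects, matching the wording change in the diff.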
@ -175,20 +175,20 @@ depends on.

Follow instructions [here](http://bazel.io/docs/install.html) to install the
dependencies for Bazel. Then download and build the Bazel source with the
following commands:
dependencies for Bazel. Then download bazel version 0.1.1 using the
[installer for your system](https://github.com/bazelbuild/bazel/releases) and
run the installer as mentioned there:

```bash
$ git clone https://github.com/bazelbuild/bazel.git
$ cd bazel
$ git checkout tags/0.1.0
$ ./compile.sh
$ chmod +x PATH_TO_INSTALL.SH
$ ./PATH_TO_INSTALL.SH --user
```

These commands use the commit tag `0.1.0`, which is known to work with
TensorFlow. `HEAD` may be unstable.
Remember to replace `PATH_TO_INSTALL.SH` to point to the location where you
downloaded the installer.

Add the executable `output/bazel` to your `$PATH` environment variable.
Finally, follow the instructions in that script to place bazel into your binary
path.

#### Install other dependencies <a class="md-anchor" id="AUTOGENERATED-install-other-dependencies"></a>

@ -15,6 +15,11 @@ system, we suggest you cite the paper above.
You can use this [BibTeX entry](../resources/bib.md). As the project progresses, we
may update the suggested citation with new papers.

Please only use the TensorFlow name and marks when accurately referencing this
software distribution, and do not use our marks in a way that suggests you are
endorsed by or otherwise affiliated with Google. When referring to our marks,
please include the following attribution statement: "TensorFlow, the TensorFlow
logo and any related marks are trademarks of Google Inc."

## Community <a class="md-anchor" id="AUTOGENERATED-community"></a>
@ -12,6 +12,7 @@ py_binary(
    srcs = [
        "word2vec.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":gen_word2vec",
        "//tensorflow:tensorflow_py",

@ -24,6 +25,7 @@ py_binary(
    srcs = [
        "word2vec_optimized.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":gen_word2vec",
        "//tensorflow:tensorflow_py",

@ -35,6 +37,7 @@ py_test(
    name = "word2vec_test",
    size = "small",
    srcs = ["word2vec_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":word2vec",
        "//tensorflow:tensorflow_py",

@ -45,6 +48,7 @@ py_test(
    name = "word2vec_optimized_test",
    size = "small",
    srcs = ["word2vec_optimized_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":word2vec_optimized",
        "//tensorflow:tensorflow_py",

@ -10,6 +10,7 @@ py_binary(
    srcs = [
        "alexnet_benchmark.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],

@ -8,6 +8,7 @@ exports_files(["LICENSE"])
py_library(
    name = "cifar10_input",
    srcs = ["cifar10_input.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],

@ -16,6 +17,7 @@ py_library(
py_test(
    name = "cifar10_input_test",
    srcs = ["cifar10_input_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":cifar10_input",
        "//tensorflow:tensorflow_py",

@ -27,6 +29,7 @@ py_test(
py_library(
    name = "cifar10",
    srcs = ["cifar10.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":cifar10_input",
        "//tensorflow:tensorflow_py",

@ -38,6 +41,7 @@ py_binary(
    srcs = [
        "cifar10_eval.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":cifar10",

@ -49,6 +53,7 @@ py_binary(
    srcs = [
        "cifar10_train.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":cifar10",

@ -60,6 +65,7 @@ py_binary(
    srcs = [
        "cifar10_multi_gpu_train.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":cifar10",

@ -10,6 +10,7 @@ py_binary(
    srcs = [
        "convolutional.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = ["//tensorflow:tensorflow_py"],
)

@ -24,6 +25,7 @@ py_test(
        "--self_test=True",
    ],
    main = "convolutional.py",
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)


@ -14,6 +14,7 @@ py_library(
    srcs = [
        "linear.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],

@ -23,6 +24,7 @@ py_test(
    name = "linear_test",
    size = "small",
    srcs = ["linear_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":linear",
        "//tensorflow:tensorflow_py",

@ -34,6 +36,7 @@ py_library(
    srcs = [
        "rnn_cell.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":linear",
        "//tensorflow:tensorflow_py",

@ -44,6 +47,7 @@ py_test(
    name = "rnn_cell_test",
    size = "small",
    srcs = ["rnn_cell_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":rnn_cell",
        "//tensorflow:tensorflow_py",

@ -55,6 +59,7 @@ py_library(
    srcs = [
        "__init__.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":rnn",
        ":rnn_cell",

@ -67,6 +72,7 @@ py_library(
    srcs = [
        "rnn.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":rnn_cell",
        "//tensorflow:tensorflow_py",

@ -88,6 +94,7 @@ py_library(
    srcs = [
        "seq2seq.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":rnn",
        "//tensorflow:tensorflow_py",

@ -99,6 +106,7 @@ py_test(
    srcs = [
        "seq2seq_test.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":seq2seq",
        "//tensorflow:tensorflow_py",

@ -10,12 +10,14 @@ exports_files(["LICENSE"])
py_library(
    name = "reader",
    srcs = ["reader.py"],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)

py_test(
    name = "reader_test",
    srcs = ["reader_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":reader",
        "//tensorflow:tensorflow_py",

@ -27,6 +29,7 @@ py_binary(
    srcs = [
        "ptb_word_lm.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":reader",
        "//tensorflow:tensorflow_py",

@ -12,6 +12,7 @@ py_library(
    srcs = [
        "data_utils.py",
    ],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)

@ -20,6 +21,7 @@ py_library(
    srcs = [
        "seq2seq_model.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":data_utils",
        "//tensorflow:tensorflow_py",

@ -32,6 +34,7 @@ py_binary(
    srcs = [
        "translate.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":data_utils",
        ":seq2seq_model",

@ -49,6 +52,7 @@ py_test(
        "--self_test=True",
    ],
    main = "translate.py",
    srcs_version = "PY2AND3",
    deps = [
        ":data_utils",
        ":seq2seq_model",
@ -27,6 +27,7 @@ numpy_macosx_include_dir = select({
py_library(
    name = "python",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__pkg__"],
    deps = [
        ":client",

@ -43,6 +44,7 @@ py_library(
py_library(
    name = "platform",
    srcs = glob(["platform/**/*.py"]),
    srcs_version = "PY2AND3",
)

py_library(

@ -51,6 +53,7 @@ py_library(
        "platform/default/_googletest.py",
        "platform/googletest.py",
    ],
    srcs_version = "PY2AND3",
    deps = [":platform"],
)

@ -94,6 +97,7 @@ py_test(
    name = "pywrap_status_test",
    size = "small",
    srcs = ["lib/core/pywrap_status_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":platform_test",

@ -133,6 +137,7 @@ py_library(
        "framework/tensor_util.py",
        "ops/common_shapes.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":platform",
        "//tensorflow/core:protos_all_py",

@ -143,6 +148,7 @@ py_library(

py_library(
    name = "extra_py_tests_deps",
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)

@ -151,6 +157,7 @@ py_library(
    srcs = [
        "framework/test_util.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":framework",
        ":platform_test",

@ -165,6 +172,7 @@ py_library(
    srcs = [
        "platform/test.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":platform_test",

@ -175,6 +183,7 @@ py_test(
    name = "framework_errors_test",
    srcs = ["framework/errors_test.py"],
    main = "framework/errors_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":platform_test",

@ -187,6 +196,7 @@ py_test(
    name = "framework_importer_test",
    srcs = ["framework/importer_test.py"],
    main = "framework/importer_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":ops",

@ -213,6 +223,7 @@ py_test(
    name = "framework_ops_test",
    srcs = ["framework/ops_test.py"],
    main = "framework/ops_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":ops",

@ -226,6 +237,19 @@ py_test(
    name = "framework_tensor_shape_test",
    srcs = ["framework/tensor_shape_test.py"],
    main = "framework/tensor_shape_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":platform_test",
        "//tensorflow/core:protos_all_py",
    ],
)

py_test(
    name = "framework_tensor_shape_div_test",
    srcs = ["framework/tensor_shape_div_test.py"],
    main = "framework/tensor_shape_div_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":platform_test",

@ -237,6 +261,7 @@ py_test(
    name = "framework_tensor_util_test",
    srcs = ["framework/tensor_util_test.py"],
    main = "framework/tensor_util_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":ops",

@ -248,6 +273,7 @@ py_test(
    name = "framework_test_util_test",
    srcs = ["framework/test_util_test.py"],
    main = "framework/test_util_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":ops",

@ -259,6 +285,7 @@ py_test(
    name = "framework_types_test",
    srcs = ["framework/types_test.py"],
    main = "framework/types_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":platform_test",

@ -271,6 +298,7 @@ py_test(
    name = "op_def_library_test",
    srcs = ["ops/op_def_library_test.py"],
    main = "ops/op_def_library_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":ops",

@ -565,6 +593,7 @@ py_library(
        "ops/variables.py",
        "user_ops/user_ops.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":array_ops",
        ":candidate_sampling_ops",

@ -591,6 +620,7 @@ py_library(
        ["training/**/*.py"],
        exclude = ["**/*test*"],
    ),
    srcs_version = "PY2AND3",
    deps = [
        ":client",
        ":framework",

@ -609,6 +639,7 @@ py_library(
        ["client/**/*.py"],
        exclude = ["**/*test*"],
    ),
    srcs_version = "PY2AND3",
    deps = [
        ":framework",
        ":ops",

@ -620,6 +651,7 @@ py_library(
py_library(
    name = "util",
    srcs = glob(["util/**/*.py"]),
    srcs_version = "PY2AND3",
    deps = ["//google/protobuf:protobuf_python"],
)

@ -641,6 +673,7 @@ py_test(
    name = "protobuf_compare_test",
    srcs = ["util/protobuf/compare_test.py"],
    main = "util/protobuf/compare_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":compare_test_proto_py",
        ":platform_test",

@ -654,6 +687,7 @@ py_test(
    srcs = [
        "client/events_writer_test.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":framework_test_lib",
        ":lib",

@ -719,6 +753,7 @@ tf_py_wrap_cc(
py_library(
    name = "lib",
    srcs = glob(["lib/**/*.py"]),
    srcs_version = "PY2AND3",
    deps = [
        ":pywrap_tensorflow",
    ],

@ -727,6 +762,7 @@ py_library(
py_library(
    name = "session",
    srcs = ["client/session.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":framework",
        ":ops",

@ -750,6 +786,7 @@ tf_cuda_library(
py_test(
    name = "session_test",
    srcs = ["client/session_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":framework",
        ":framework_test_lib",

@ -760,6 +797,7 @@ py_test(
py_test(
    name = "graph_util_test",
    srcs = ["client/graph_util_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":framework",
        ":framework_test_lib",

@ -770,6 +808,7 @@ py_test(
py_library(
    name = "kernel_tests/gradient_checker",
    srcs = ["kernel_tests/gradient_checker.py"],
    srcs_version = "PY2AND3",
)

cpu_only_kernel_test_list = glob([

@ -899,6 +938,7 @@ py_library(
        ["summary/**/*.py"],
        exclude = ["**/*test*"],
    ),
    srcs_version = "PY2AND3",
    deps = [
        ":client",
        ":framework",

@ -921,6 +961,7 @@ py_library(
    srcs = [
        "framework/docs.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":platform",
    ],

@ -932,6 +973,7 @@ py_binary(
        "framework/gen_docs_combined.py",
    ],
    main = "framework/gen_docs_combined.py",
    srcs_version = "PY2AND3",
    deps = [
        ":docs",
        ":platform",
@ -170,18 +170,18 @@ class Dimension(object):
  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    Dimensions are summed as follows:
    Dimensions are divided as follows:

    Dimension(m) / Dimension(n) == Dimension(m / n)
    Dimension(m) / Dimension(None) == Dimension(None)
    Dimension(None) / Dimension(n) == Dimension(None)
    Dimension(None) / Dimension(None) == Dimension(None)
    Dimension(m) // Dimension(n) == Dimension(m // n)
    Dimension(m) // Dimension(None) == Dimension(None)
    Dimension(None) // Dimension(n) == Dimension(None)
    Dimension(None) // Dimension(None) == Dimension(None)

    Args:
      other: Another Dimension.
      other: Another `Dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:

@ -189,6 +189,22 @@ class Dimension(object):
    else:
      return Dimension(self._value // other.value)

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    This function exists only for backwards compatibility purposes; new code
    should use `__floordiv__` via the syntax `x // y`. Using `x // y`
    communicates clearly that the result rounds down, and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    return self // other

  def __mod__(self, other):
    """Returns `self` modulo `other.
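To make the floor-division semantics above concrete, a small sketch (the values are arbitrary; under Python 2 semantics, plain `/` routes through the deprecated `__div__` alias shown above):

```python
from tensorflow.python.framework import tensor_shape

seven = tensor_shape.Dimension(7)
two = tensor_shape.Dimension(2)
unknown = tensor_shape.Dimension(None)

print((seven // two).value)      # 3: the quotient, rounded down
print((seven // unknown).value)  # None: an unknown operand gives an unknown result
print((seven / two).value)       # also 3, via __div__ (Python 2 only)
```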
tensorflow/python/framework/tensor_shape_div_test.py (new file, 24 lines)

@ -0,0 +1,24 @@
"""Test that old style division works for Dimension."""
from __future__ import absolute_import
# from __future__ import division # Intentionally skip this import
from __future__ import print_function

import tensorflow.python.platform

from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


class DimensionDivTest(test_util.TensorFlowTestCase):

  def testDivSucceeds(self):
    """Without from __future__ import division, __div__ should work."""
    values = [tensor_shape.Dimension(x) for x in 3, 7, 11, None]
    for x in values:
      for y in values:
        self.assertEqual((x / y).value, (x // y).value)


if __name__ == "__main__":
  googletest.main()
@ -233,6 +233,12 @@ class ShapeTest(test_util.TensorFlowTestCase):
    tensor_shape.TensorShape(
        [94, 43]).assert_is_compatible_with(tensor_shape.matrix(94, 43))

  def testTruedivFails(self):
    unknown = tensor_shape.Dimension(None)
    self.assertEqual((unknown // unknown).value, None)
    with self.assertRaisesRegexp(TypeError, r"unsupported operand type"):
      unknown / unknown  # pylint: disable=pointless-statement


if __name__ == "__main__":
  googletest.main()
@ -409,7 +409,7 @@ def split(split_dim, num_split, value, name="split"):
  Args:
    split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
      Must be in the range `[0, rank(value))`.
    num_split: A 0-D `int32` `Tensor`. The number of ways to split.
    num_split: A Python integer. The number of ways to split.
    value: The `Tensor` to split.
    name: A name for the operation (optional).
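As a quick illustration of the signature documented above, a sketch against this release's `tf.split(split_dim, num_split, value)` argument order (the input shape is an illustrative choice):

```python
import tensorflow as tf

# Split a 4x6 matrix into three 4x2 pieces along dimension 1.
# num_split is a plain Python integer, per the doc fix above.
value = tf.zeros([4, 6])
pieces = tf.split(1, 3, value)

assert len(pieces) == 3
print(pieces[0].get_shape())  # (4, 2)
```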
@ -138,7 +138,7 @@ class Optimizer(object):
    self._slots = {}

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, name=None):
               gate_gradients=GATE_OP, aggregation_method=None, name=None):
    """Add operations to minimize 'loss' by updating 'var_list'.

    This method simply combines calls compute_gradients() and

@ -155,6 +155,8 @@ class Optimizer(object):
        under the key GraphKeys.TRAINABLE_VARIABLES.
      gate_gradients: How to gate the computation of gradients. Can be
        GATE_NONE, GATE_OP, or GATE_GRAPH.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      name: Optional name for the returned operation.

    Returns:

@ -164,12 +166,14 @@ class Optimizer(object):
    Raises:
      ValueError: if some of the variables are not variables.Variable objects.
    """
    grads_and_vars = self.compute_gradients(loss, var_list=var_list,
                                            gate_gradients=gate_gradients)
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method)
    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)

  def compute_gradients(self, loss, var_list=None, gate_gradients=GATE_OP):
  def compute_gradients(self, loss, var_list=None, gate_gradients=GATE_OP,
                        aggregation_method=None):
    """Compute gradients of "loss" for the variables in "var_list".

    This is the first part of minimize(). It returns a list

@ -185,6 +189,8 @@ class Optimizer(object):
        under the key GraphKey.TRAINABLE_VARIABLES.
      gate_gradients: How to gate the computation of gradients. Can be
        GATE_NONE, GATE_OP, or GATE_GRAPH.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.

    Returns:
      A list of (gradient, variable) pairs.

@ -205,7 +211,8 @@ class Optimizer(object):
      if not isinstance(var, variables.Variable):
        raise TypeError("Argument is not a variables.Variable: %s" % var)
    grads = gradients.gradients(
        loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP))
        loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
tensorflow/python/training/optimizer_test.py (new file, 54 lines)

@ -0,0 +1,54 @@
"""Functional test for optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.python.platform

import tensorflow as tf


class OptimizerTest(tf.test.TestCase):

  def testBasic(self):
    with self.test_session():
      var0 = tf.Variable([1.0, 2.0])
      var1 = tf.Variable([3.0, 4.0])
      cost = 5 * var0 + 3 * var1
      global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
      sgd_op = tf.train.GradientDescentOptimizer(3.0)
      opt_op = sgd_op.minimize(cost, global_step, [var0, var1])

      tf.initialize_all_variables().run()
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([3.0, 4.0], var1.eval())
      # Run 1 step of sgd through optimizer
      opt_op.run()
      # Validate updated params
      self.assertAllClose([-14., -13.], var0.eval())
      self.assertAllClose([-6., -5.], var1.eval())

  def testAggregationMethod(self):
    with self.test_session():
      var0 = tf.Variable([1.0, 2.0])
      var1 = tf.Variable([3.0, 4.0])
      cost = 5 * var0 + 3 * var1
      global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
      sgd_op = tf.train.GradientDescentOptimizer(3.0)
      opt_op = sgd_op.minimize(
          cost, global_step, [var0, var1], aggregation_method=
          tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

      tf.initialize_all_variables().run()
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([3.0, 4.0], var1.eval())
      # Run 1 step of sgd through optimizer
      opt_op.run()
      # Validate updated params
      self.assertAllClose([-14., -13.], var0.eval())
      self.assertAllClose([-6., -5.], var1.eval())


if __name__ == "__main__":
  tf.test.main()
@ -20,11 +20,13 @@ py_library(
        "//tensorflow/python:platform",
        "//tensorflow/python:summary",
    ],
    srcs_version = "PY2AND3",
)

py_library(
    name = "float_wrapper",
    srcs = ["float_wrapper.py"],
    srcs_version = "PY2AND3",
)

py_test(

@ -35,6 +37,7 @@ py_test(
        ":float_wrapper",
        "//tensorflow/python:platform_test",
    ],
    srcs_version = "PY2AND3",
)

py_binary(

@ -46,4 +49,5 @@ py_binary(
        "//tensorflow/python:platform",
        "//tensorflow/python:summary",
    ],
    srcs_version = "PY2AND3",
)
@ -332,7 +332,8 @@ def py_tests(name,
      deps=[
          "//tensorflow/python:extra_py_tests_deps",
          "//tensorflow/python:kernel_tests/gradient_checker",
      ] + additional_deps)
      ] + additional_deps,
      srcs_version="PY2AND3")


def cuda_py_tests(name, srcs, additional_deps=[], data=[], shard_count=1):
@ -10,6 +10,7 @@ exports_files(["LICENSE"])
py_binary(
    name = "simple_console",
    srcs = ["simple_console.py"],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)


@ -7,6 +7,7 @@ py_binary(
    name = "simple_console",
    srcs = ["simple_console.py"],
    deps = ["//tensorflow:tensorflow_py"],
    srcs_version = "PY2AND3",
)

sh_binary(
@ -1195,7 +1195,7 @@ for the last dimension).
    input.setRandom();
    kernel.setRandom();

    Eigen::array<ptrdiff_t, 2> dims({1, 2});  // Specify second and third dimension for convolution.
    Eigen::array<Eigen::DenseIndex, 2> dims({1, 2});  // Specify second and third dimension for convolution.
    output = input.convolve(kernel, dims);

    for (int i = 0; i < 3; ++i) {

@ -1577,7 +1577,7 @@ For example, given the following input tensor:
Six 2x2 patches can be extracted and indexed using the following code:

    Eigen::Tensor<float, 3, DataLayout> patch;
    Eigen::array<ptrdiff_t, 2> patch_dims;
    Eigen::array<Eigen::DenseIndex, 2> patch_dims;
    patch_dims[0] = 2;
    patch_dims[1] = 2;
    patch = tensor.extract_patches(patch_dims);
@ -232,9 +232,14 @@ def InvokeNvcc(argv, log=False):
  srcs = ' '.join(src_files)
  out = ' -o ' + out_file[0]

  nvccopts = ' '.join([
      r'-gencode=arch=compute_35,\"code=sm_35,compute_35\"',
      r'-gencode=arch=compute_52,\"code=sm_52,compute_52\"',])
  # "configure" uses the specific format to substitute the following string.
  # If you change it, make sure you modify "configure" as well.
  supported_cuda_compute_capabilities = [ "3.5", "5.2" ]
  nvccopts = ''
  for capability in supported_cuda_compute_capabilities:
    capability = capability.replace('.', '')
    nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
        capability, capability, capability)
  nvccopts += ' ' + nvcc_compiler_options
  nvccopts += undefines
  nvccopts += defines
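To see what the loop above generates, a standalone sketch of the same flag construction (the capability list mirrors the default `supported_cuda_compute_capabilities` in the diff; "configure" rewrites that list in place):

```python
# Standalone illustration of the -gencode flag construction above.
supported_cuda_compute_capabilities = ["3.5", "5.2"]

nvccopts = ''
for capability in supported_cuda_compute_capabilities:
    capability = capability.replace('.', '')  # "3.5" -> "35"
    nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
        capability, capability, capability)

print(nvccopts)
# -gencode=arch=compute_35,\"code=sm_35,compute_35\" -gencode=arch=compute_52,\"code=sm_52,compute_52\"
```

Each capability listed adds one `-gencode` clause, which is why the configure script warns that every additional compute capability increases build time and binary size.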
@ -260,8 +265,8 @@ def InvokeNvcc(argv, log=False):
         ' -I .' +
         ' -x cu ' + opt + includes + ' -c ' + srcs + out)

  # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
  # Need to investigate and fix.
  # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
  # Need to investigate and fix.
  cmd = 'PATH=' + PREFIX_DIR + ' ' + cmd
  if log: Log(cmd)
  return os.system(cmd)