TensorFlow: conv improvements, label_image example, and a few other changes.

Changes:
- Some improvements to convolution by using 32-bit indices by @benoitsteiner. Not all calls converted yet. Also some improvements to pooling by @benoitsteiner.
- Improvements to sparse matmul CPU implementation by Ashish
- Some fixes to warnings by @vrv
- Doc fixes to padding by @Yangqing
- Some improvements to Tensor wrappers by Eider
- Speed up of matrix inverse on CPU by Rasmus
- Add an example of doing image inference from a pre-trained model by @petewarden
- Fixed formula in mnist example by nodir
- Updates to event accumulator by Cassandra
- Slight changes to tensor C API by @mrry
- Handling of strings in listdiff by Phil
- Fix negative fraction-of-queue-full stats by Frank
- Type-checking improvement to importer by Yaroslav
- logdir recursive search for TensorBoard by @danmane
- Session.run() checks for empty graph by Manoj

Base CL: 108013706
Parent: 56313def00
Commit: 4213ac97be
@@ -141,9 +141,9 @@ void TF_SetTarget(TF_SessionOptions* options, const char* target) {
  options->options.target = target;
}

void TF_SetConfig(TF_SessionOptions* options, const char* config,
                  size_t config_len, TF_Status* status) {
  if (!options->options.config.ParseFromArray(config, config_len)) {
void TF_SetConfig(TF_SessionOptions* options, const void* proto,
                  size_t proto_len, TF_Status* status) {
  if (!options->options.config.ParseFromArray(proto, proto_len)) {
    status->status =
        tensorflow::errors::InvalidArgument("Unparseable ConfigProto");
}
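In the updated signature the configuration is passed as an opaque serialized ConfigProto buffer. A hedged sketch of how a caller might use it (function and variable names inside the sketch are illustrative, not from this commit):

    void ConfigureSession() {
      tensorflow::ConfigProto config;
      config.set_allow_soft_placement(true);  // any option; purely illustrative
      std::string serialized;
      config.SerializeToString(&serialized);

      TF_SessionOptions* opts = TF_NewSessionOptions();
      TF_Status* status = TF_NewStatus();
      // A malformed buffer would produce the
      // InvalidArgument("Unparseable ConfigProto") status set above.
      TF_SetConfig(opts, serialized.data(), serialized.size(), status);
    }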
@@ -30,7 +30,7 @@ TEST(GPUBFCAllocatorTest, NoDups) {
  std::sort(ptrs.begin(), ptrs.end());

  // Make sure none of them are equal, and that none of them overlap.
  for (int i = 0; i < ptrs.size(); i++) {
  for (size_t i = 0; i < ptrs.size(); i++) {
    if (i > 0) {
      ASSERT_NE(ptrs[i], ptrs[i - 1]);  // No dups
      size_t req_size = a.RequestedSize(ptrs[i - 1]);
@@ -40,7 +40,7 @@ TEST(GPUBFCAllocatorTest, NoDups) {
    }
  }

  for (int i = 0; i < ptrs.size(); i++) {
  for (size_t i = 0; i < ptrs.size(); i++) {
    a.DeallocateRaw(ptrs[i]);
  }
}
@@ -63,7 +63,7 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {

  // Deallocate half of the memory, and keep track of the others.
  std::vector<void*> existing_ptrs;
  for (int i = 0; i < initial_ptrs.size(); i++) {
  for (size_t i = 0; i < initial_ptrs.size(); i++) {
    if (i % 2 == 1) {
      a.DeallocateRaw(initial_ptrs[i]);
    } else {
@@ -81,7 +81,7 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {

  std::sort(existing_ptrs.begin(), existing_ptrs.end());
  // Make sure none of them are equal
  for (int i = 0; i < existing_ptrs.size(); i++) {
  for (size_t i = 0; i < existing_ptrs.size(); i++) {
    if (i > 0) {
      CHECK_NE(existing_ptrs[i], existing_ptrs[i - 1]);  // No dups

@@ -95,7 +95,7 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
    }
  }

  for (int i = 0; i < existing_ptrs.size(); i++) {
  for (size_t i = 0; i < existing_ptrs.size(); i++) {
    a.DeallocateRaw(existing_ptrs[i]);
  }
}
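The int to size_t changes above keep the loop counters consistent with std::vector::size(), which returns size_t; comparing a signed int against it triggers a signed/unsigned comparison warning. A minimal illustration (hypothetical vector, not from this test):

    void LoopCounterSketch() {
      std::vector<void*> ptrs(10, nullptr);
      for (int i = 0; i < ptrs.size(); i++) {}     // -Wsign-compare warning
      for (size_t i = 0; i < ptrs.size(); i++) {}  // clean
    }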
@ -6,66 +6,73 @@
|
||||
namespace tensorflow {
|
||||
|
||||
// Helper to define Tensor types given that the scalar is of type T.
|
||||
template <typename T, int NDIMS = 1>
|
||||
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
|
||||
struct TTypes {
|
||||
// Rank-<NDIMS> tensor of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor>,
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned> Tensor;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor>,
|
||||
Eigen::Aligned> ConstTensor;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
||||
ConstTensor;
|
||||
|
||||
// Unaligned Rank-<NDIMS> tensor of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor> >
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType> >
|
||||
UnalignedTensor;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor> >
|
||||
UnalignedConstTensor;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor,
|
||||
IndexType> > UnalignedConstTensor;
|
||||
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
|
||||
Eigen::Aligned> Tensor32Bit;
|
||||
|
||||
// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor>,
|
||||
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned> Scalar;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::TensorFixedSize<const T, Eigen::Sizes<>, Eigen::RowMajor>,
|
||||
Eigen::Aligned> ConstScalar;
|
||||
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
|
||||
Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned> ConstScalar;
|
||||
|
||||
// Unaligned Scalar tensor of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::TensorFixedSize<
|
||||
T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedScalar;
|
||||
typedef Eigen::TensorMap<Eigen::TensorFixedSize<
|
||||
const T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedConstScalar;
|
||||
T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> > UnalignedScalar;
|
||||
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
|
||||
Eigen::RowMajor, IndexType> >
|
||||
UnalignedConstScalar;
|
||||
|
||||
// Rank-1 tensor (vector) of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned>
|
||||
Flat;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
|
||||
Eigen::Aligned> ConstFlat;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned>
|
||||
Vec;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
|
||||
Eigen::Aligned> ConstVec;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned> Flat;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
||||
ConstFlat;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned> Vec;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
||||
ConstVec;
|
||||
|
||||
// Unaligned Rank-1 tensor (vector) of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedFlat;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> >
|
||||
UnalignedConstFlat;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedVec;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> >
|
||||
UnalignedConstVec;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
|
||||
UnalignedFlat;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor,
|
||||
IndexType> > UnalignedConstFlat;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
|
||||
UnalignedVec;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> > UnalignedConstVec;
|
||||
|
||||
// Rank-2 tensor (matrix) of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>
|
||||
Matrix;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
|
||||
Eigen::Aligned> ConstMatrix;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
|
||||
Eigen::Aligned> Matrix;
|
||||
typedef Eigen::TensorMap<
|
||||
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
|
||||
ConstMatrix;
|
||||
|
||||
// Unaligned Rank-2 tensor (matrix) of scalar type T.
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor> >
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType> >
|
||||
UnalignedMatrix;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor> >
|
||||
UnalignedConstMatrix;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor,
|
||||
IndexType> > UnalignedConstMatrix;
|
||||
};
|
||||
|
||||
typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32;
|
||||
|
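The extra IndexType template parameter (defaulting to Eigen::DenseIndex) lets kernels opt into 32-bit indexing; Tensor32Bit and the Index32 typedef are shorthand for that variant. A rough sketch of the aliases in use (buffer and function name are illustrative only):

    void Example32BitView() {
      // The same buffer viewed with 64-bit (default) and 32-bit indices.
      float data[6] = {0, 1, 2, 3, 4, 5};
      tensorflow::TTypes<float, 2>::UnalignedTensor t64(data, 2, 3);       // Eigen::DenseIndex
      tensorflow::TTypes<float, 2, int>::UnalignedTensor t32(data, 2, 3);  // int indices
      // 32-bit index arithmetic is cheaper on GPU, which is what the
      // To32Bit(...) call sites later in this commit rely on.
    }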
@ -11,24 +11,25 @@ namespace functor {
|
||||
// TODO(yangke): revisit these operations and in particular, see if we can
|
||||
// combine all of them into just one operation without causing nvcc to
|
||||
// timeout.
|
||||
template <typename Device, typename T, int Dims>
|
||||
template <typename Device, typename T, int Dims, typename IndexType>
|
||||
struct ShuffleAndReverse {
|
||||
void operator()(const Device& d, typename TTypes<T, Dims>::ConstTensor input,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, Dims>& order,
|
||||
void operator()(const Device& d,
|
||||
typename TTypes<T, Dims, IndexType>::ConstTensor input,
|
||||
const Eigen::DSizes<IndexType, Dims>& order,
|
||||
const Eigen::array<bool, Dims>& reverse_dims,
|
||||
typename TTypes<T, Dims>::Tensor output) {
|
||||
typename TTypes<T, Dims, IndexType>::Tensor output) {
|
||||
output.device(d) = input.shuffle(order).reverse(reverse_dims);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T, int Dims>
|
||||
template <typename Device, typename T, int Dims, typename IndexType>
|
||||
struct InflatePadAndShuffle {
|
||||
void operator()(
|
||||
const Device& d, typename TTypes<T, Dims>::ConstTensor input,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, Dims>& strides,
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, Dims>& pad_dims,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, Dims>& order,
|
||||
typename TTypes<T, Dims>::Tensor output) {
|
||||
const Device& d, typename TTypes<T, Dims, IndexType>::ConstTensor input,
|
||||
const Eigen::DSizes<IndexType, Dims>& strides,
|
||||
const Eigen::array<Eigen::IndexPair<IndexType>, Dims>& pad_dims,
|
||||
const Eigen::DSizes<IndexType, Dims>& order,
|
||||
typename TTypes<T, Dims, IndexType>::Tensor output) {
|
||||
output.device(d) = input.inflate(strides).pad(pad_dims).shuffle(order);
|
||||
}
|
||||
};
|
||||
@ -89,30 +90,92 @@ struct MatMulConvFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename IndexType>
|
||||
struct TransformFilter {
|
||||
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
|
||||
typename TTypes<T, 4>::Tensor out) {
|
||||
out.device(d) = in.shuffle(Eigen::DSizes<Eigen::DenseIndex, 4>(3, 2, 0, 1));
|
||||
void operator()(const Device& d,
|
||||
typename TTypes<T, 4, IndexType>::ConstTensor in,
|
||||
typename TTypes<T, 4, IndexType>::Tensor out) {
|
||||
// We want a 3, 2, 0, 1 shuffle. We can merge dimensions 0 and 1 together
|
||||
// to help speedup the shuffle operation.
|
||||
Eigen::DSizes<IndexType, 3> merged_dims;
|
||||
merged_dims[0] = in.dimension(0) * in.dimension(1);
|
||||
merged_dims[1] = in.dimension(2);
|
||||
merged_dims[2] = in.dimension(3);
|
||||
|
||||
Eigen::DSizes<IndexType, 4> expanded_dims;
|
||||
expanded_dims[0] = in.dimension(3);
|
||||
expanded_dims[1] = in.dimension(2);
|
||||
expanded_dims[2] = in.dimension(0);
|
||||
expanded_dims[3] = in.dimension(1);
|
||||
|
||||
out.device(d) = in.reshape(merged_dims)
|
||||
.shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
|
||||
.reshape(expanded_dims);
|
||||
}
|
||||
};
|
||||
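The comment in TransformFilter above describes the trick: in a {3, 2, 0, 1} shuffle, input dimensions 0 and 1 stay adjacent and in order, so they can be collapsed into one merged dimension, shuffled as a rank-3 tensor, and reshaped back, which saves index arithmetic. A standalone sketch of that equivalence in plain Eigen (dimensions invented for illustration); TransformDepth below applies the same merging idea case by case:

    #include <unsupported/Eigen/CXX11/Tensor>

    void MergedShuffleSketch() {
      Eigen::Tensor<float, 4, Eigen::RowMajor> in(2, 3, 4, 5);
      in.setRandom();

      // Direct 4-D shuffle to {3, 2, 0, 1}.
      Eigen::array<int, 4> direct{{3, 2, 0, 1}};
      Eigen::Tensor<float, 4, Eigen::RowMajor> a = in.shuffle(direct);

      // Merged version: collapse dims 0 and 1, shuffle rank 3, expand back.
      Eigen::array<Eigen::DenseIndex, 3> merged{{2 * 3, 4, 5}};
      Eigen::array<int, 3> shuffle3{{2, 1, 0}};
      Eigen::array<Eigen::DenseIndex, 4> expanded{{5, 4, 2, 3}};
      Eigen::Tensor<float, 4, Eigen::RowMajor> b =
          in.reshape(merged).shuffle(shuffle3).reshape(expanded);
      // a and b hold identical values; the merged form shuffles fewer dimensions.
    }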
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename IndexType>
|
||||
struct TransformDepth {
|
||||
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& shuffle,
|
||||
typename TTypes<T, 4>::Tensor out) {
|
||||
out.device(d) = in.shuffle(shuffle);
|
||||
void operator()(const Device& d,
|
||||
typename TTypes<T, 4, IndexType>::ConstTensor in,
|
||||
const Eigen::DSizes<IndexType, 4>& shuffle,
|
||||
typename TTypes<T, 4, IndexType>::Tensor out) {
|
||||
Eigen::DSizes<IndexType, 3> merged_dims;
|
||||
Eigen::DSizes<IndexType, 4> expanded_dims;
|
||||
Eigen::DSizes<IndexType, 3> new_shuffle;
|
||||
|
||||
// Merge dimensions that won't be shuffled together to speed things up.
|
||||
if (shuffle[1] == 2 && shuffle[2] == 3) {
|
||||
merged_dims[0] = in.dimension(0);
|
||||
merged_dims[1] = in.dimension(1);
|
||||
merged_dims[2] = in.dimension(2) * in.dimension(3);
|
||||
new_shuffle[0] = shuffle[0];
|
||||
new_shuffle[1] = 2;
|
||||
new_shuffle[2] = shuffle[3];
|
||||
expanded_dims[0] = in.dimension(shuffle[0]);
|
||||
expanded_dims[1] = in.dimension(2);
|
||||
expanded_dims[2] = in.dimension(3);
|
||||
expanded_dims[3] = in.dimension(shuffle[3]);
|
||||
} else if (shuffle[0] == 2 && shuffle[1] == 3) {
|
||||
merged_dims[0] = in.dimension(0);
|
||||
merged_dims[1] = in.dimension(1);
|
||||
merged_dims[2] = in.dimension(2) * in.dimension(3);
|
||||
new_shuffle[0] = 2;
|
||||
new_shuffle[1] = shuffle[2];
|
||||
new_shuffle[2] = shuffle[3];
|
||||
expanded_dims[0] = in.dimension(2);
|
||||
expanded_dims[1] = in.dimension(3);
|
||||
expanded_dims[2] = in.dimension(shuffle[2]);
|
||||
expanded_dims[3] = in.dimension(shuffle[3]);
|
||||
} else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 &&
|
||||
shuffle[3] == 2) {
|
||||
merged_dims[0] = in.dimension(0);
|
||||
merged_dims[1] = in.dimension(1) * in.dimension(2);
|
||||
merged_dims[2] = in.dimension(3);
|
||||
new_shuffle[0] = 0;
|
||||
new_shuffle[1] = 2;
|
||||
new_shuffle[2] = 1;
|
||||
expanded_dims[0] = in.dimension(0);
|
||||
expanded_dims[1] = in.dimension(3);
|
||||
expanded_dims[2] = in.dimension(1);
|
||||
expanded_dims[3] = in.dimension(2);
|
||||
} else {
|
||||
assert(false && "unexpected shuffle");
|
||||
}
|
||||
|
||||
out.device(d) =
|
||||
in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename IndexType>
|
||||
struct PadInput {
|
||||
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
|
||||
void operator()(const Device& d,
|
||||
typename TTypes<T, 4, IndexType>::ConstTensor in,
|
||||
int padding_rows_left, int padding_rows_right,
|
||||
int padding_cols_left, int padding_cols_right,
|
||||
typename TTypes<T, 4>::Tensor out) {
|
||||
Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, 4> padding;
|
||||
typename TTypes<T, 4, IndexType>::Tensor out) {
|
||||
Eigen::array<std::pair<IndexType, IndexType>, 4> padding;
|
||||
padding[0] = std::make_pair(0, 0);
|
||||
padding[1] = std::make_pair(padding_rows_left, padding_rows_right);
|
||||
padding[2] = std::make_pair(padding_cols_left, padding_cols_right);
|
||||
|
@ -783,9 +783,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
|
||||
TensorShape({out_depth, in_depth, filter_rows, filter_cols}),
|
||||
&transformed_filter));
|
||||
|
||||
functor::TransformFilter<Device, T>()(context->eigen_device<Device>(),
|
||||
filter.tensor<T, 4>(),
|
||||
transformed_filter.tensor<T, 4>());
|
||||
functor::TransformFilter<Device, T, int>()(
|
||||
context->eigen_device<Device>(), To32Bit(filter.tensor<T, 4>()),
|
||||
To32Bit(transformed_filter.tensor<T, 4>()));
|
||||
|
||||
Tensor transformed_out_backprop;
|
||||
OP_REQUIRES_OK(
|
||||
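To32Bit presumably re-wraps an existing TensorMap so the functor indexes with 32-bit ints rather than 64-bit Eigen::DenseIndex, which is only safe when the tensor's extents fit in int32. A minimal sketch of such a helper (hypothetical name, not the actual TensorFlow implementation):

    #include <unsupported/Eigen/CXX11/Tensor>

    // Rebuild a TensorMap with int (32-bit) indices over the same data.
    template <typename T, int NDIMS>
    Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>, Eigen::Aligned>
    ToInt32Map(Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor>,
                                Eigen::Aligned> in) {
      Eigen::DSizes<int, NDIMS> dims;
      for (int i = 0; i < NDIMS; ++i) {
        dims[i] = static_cast<int>(in.dimension(i));  // caller guarantees it fits
      }
      return Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
                              Eigen::Aligned>(in.data(), dims);
    }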
@ -795,10 +795,10 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
|
||||
TensorShape({batch, out_depth, output_rows, output_cols}),
|
||||
&transformed_out_backprop));
|
||||
|
||||
functor::TransformDepth<Device, T>()(
|
||||
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
|
||||
Eigen::DSizes<Eigen::DenseIndex, 4>(0, 3, 1, 2),
|
||||
transformed_out_backprop.tensor<T, 4>());
|
||||
functor::TransformDepth<Device, T, int>()(
|
||||
context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()),
|
||||
Eigen::DSizes<int, 4>(0, 3, 1, 2),
|
||||
To32Bit(transformed_out_backprop.tensor<T, 4>()));
|
||||
|
||||
Tensor pre_transformed_in_backprop;
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -831,11 +831,12 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
|
||||
}
|
||||
|
||||
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
|
||||
functor::TransformDepth<Device, T>()(
|
||||
functor::TransformDepth<Device, T, int>()(
|
||||
context->eigen_device<Device>(),
|
||||
toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
|
||||
Eigen::DSizes<Eigen::DenseIndex, 4>(0, 2, 3, 1),
|
||||
in_backprop->tensor<T, 4>());
|
||||
To32Bit(toConstTensor(pre_transformed_in_backprop)
|
||||
.template tensor<T, 4>()),
|
||||
Eigen::DSizes<int, 4>(0, 2, 3, 1),
|
||||
To32Bit(in_backprop->tensor<T, 4>()));
|
||||
} else {
|
||||
// We fill out a padded out_backprop
|
||||
TensorShape padded_out_shape(
|
||||
@ -852,7 +853,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
|
||||
{left_pad_cols, right_pad_cols},
|
||||
{0, 0}}};
|
||||
|
||||
functor::InflatePadAndShuffle<Device, T, 4>()(
|
||||
functor::InflatePadAndShuffle<Device, T, 4, Eigen::DenseIndex>()(
|
||||
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides,
|
||||
pad_dims, trivial_order, padded_output.tensor<T, 4>());
|
||||
const Tensor& padded_output_cref = padded_output;
|
||||
@ -869,7 +870,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
|
||||
|
||||
Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{0, 1, 3, 2};
|
||||
Eigen::array<bool, 4> filter_rev_dims{true, true, false, false};
|
||||
functor::ShuffleAndReverse<Device, T, 4>()(
|
||||
functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()(
|
||||
context->eigen_device<Device>(), filter.tensor<T, 4>(), filter_order,
|
||||
filter_rev_dims, r_filter.tensor<T, 4>());
|
||||
const Tensor& r_filter_cref = r_filter;
|
||||
@ -1033,10 +1034,10 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
||||
TensorShape({batch, out_depth, output_rows, output_cols}),
|
||||
&transformed_out_backprop));
|
||||
|
||||
functor::TransformDepth<Device, T>()(
|
||||
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
|
||||
Eigen::DSizes<Eigen::DenseIndex, 4>(0, 3, 1, 2),
|
||||
transformed_out_backprop.tensor<T, 4>());
|
||||
functor::TransformDepth<Device, T, int>()(
|
||||
context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()),
|
||||
Eigen::DSizes<int, 4>(0, 3, 1, 2),
|
||||
To32Bit(transformed_out_backprop.tensor<T, 4>()));
|
||||
|
||||
Tensor transformed_input;
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -1045,10 +1046,10 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
||||
TensorShape({batch, in_depth, input_rows, input_cols}),
|
||||
&transformed_input));
|
||||
|
||||
functor::TransformDepth<Device, T>()(
|
||||
context->eigen_device<Device>(), input.tensor<T, 4>(),
|
||||
Eigen::DSizes<Eigen::DenseIndex, 4>(0, 3, 1, 2),
|
||||
transformed_input.tensor<T, 4>());
|
||||
functor::TransformDepth<Device, T, int>()(
|
||||
context->eigen_device<Device>(), To32Bit(input.tensor<T, 4>()),
|
||||
Eigen::DSizes<int, 4>(0, 3, 1, 2),
|
||||
To32Bit(transformed_input.tensor<T, 4>()));
|
||||
|
||||
auto out_backprop_ptr =
|
||||
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
|
||||
@ -1074,12 +1075,12 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
||||
}
|
||||
|
||||
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
|
||||
functor::TransformDepth<Device, T>()(
|
||||
functor::TransformDepth<Device, T, int>()(
|
||||
context->eigen_device<Device>(),
|
||||
toConstTensor(pre_transformed_filter_backprop)
|
||||
.template tensor<T, 4>(),
|
||||
Eigen::DSizes<Eigen::DenseIndex, 4>(2, 3, 1, 0),
|
||||
filter_backprop->tensor<T, 4>());
|
||||
To32Bit(toConstTensor(pre_transformed_filter_backprop)
|
||||
.template tensor<T, 4>()),
|
||||
Eigen::DSizes<int, 4>(2, 3, 1, 0),
|
||||
To32Bit(filter_backprop->tensor<T, 4>()));
|
||||
} else {
|
||||
// Fall back to the non-cudnn code path
|
||||
|
||||
@ -1102,7 +1103,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
||||
{top_pad_rows, bottom_pad_rows},
|
||||
{left_pad_cols, right_pad_cols},
|
||||
{0, 0}}};
|
||||
functor::InflatePadAndShuffle<Device, T, 4>()(
|
||||
functor::InflatePadAndShuffle<Device, T, 4, Eigen::DenseIndex>()(
|
||||
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides,
|
||||
pad_dims, out_order, padded_output.tensor<T, 4>());
|
||||
const Tensor& padded_output_cref = padded_output;
|
||||
@ -1121,7 +1122,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
||||
|
||||
// No need for reversing this time.
|
||||
Eigen::array<bool, 4> trivial_dims{false, false, false, false};
|
||||
functor::ShuffleAndReverse<Device, T, 4>()(
|
||||
functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()(
|
||||
context->eigen_device<Device>(), input.tensor<T, 4>(), in_order,
|
||||
trivial_dims, in_shuffle.tensor<T, 4>());
|
||||
const Tensor& in_shuffle_cref = in_shuffle;
|
||||
@ -1149,7 +1150,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
||||
Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{1, 2, 3, 0};
|
||||
Eigen::array<bool, 4> filter_rev_dims{true, true, false, false};
|
||||
const Tensor& filter_shuffle_cref = filter_shuffle;
|
||||
functor::ShuffleAndReverse<Device, T, 4>()(
|
||||
functor::ShuffleAndReverse<Device, T, 4, Eigen::DenseIndex>()(
|
||||
context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 4>(),
|
||||
filter_order, filter_rev_dims, filter_backprop->tensor<T, 4>());
|
||||
}
|
||||
@ -1165,46 +1166,65 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
||||
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void ShuffleAndReverse<GPUDevice, T, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
|
||||
const Eigen::array<bool, 4>& reverse_dims, \
|
||||
typename TTypes<T, 4>::Tensor output); \
|
||||
extern template struct ShuffleAndReverse<GPUDevice, T, 4>; \
|
||||
template <> \
|
||||
void InflatePadAndShuffle<GPUDevice, T, 4>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& strides, \
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4>& pad_dims, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
|
||||
typename TTypes<T, 4>::Tensor output); \
|
||||
extern template struct InflatePadAndShuffle<GPUDevice, T, 4>; \
|
||||
template <> \
|
||||
void TransformFilter<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
|
||||
typename TTypes<T, 4>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T>; \
|
||||
template <> \
|
||||
void TransformDepth<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& shuffle, \
|
||||
typename TTypes<T, 4>::Tensor out); \
|
||||
extern template struct TransformDepth<GPUDevice, T>; \
|
||||
template <> \
|
||||
void SpatialConvolution<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::Tensor output, \
|
||||
typename TTypes<T, 4>::ConstTensor input, \
|
||||
typename TTypes<T, 4>::ConstTensor filter, int stride, \
|
||||
const Eigen::PaddingType& padding); \
|
||||
extern template struct SpatialConvolution<GPUDevice, T>; \
|
||||
template <> \
|
||||
void SpatialConvolutionBackwardInput<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::Tensor in_backprop, \
|
||||
typename TTypes<T, 4>::ConstTensor filter, \
|
||||
typename TTypes<T, 4>::ConstTensor output_backprop, int input_rows, \
|
||||
int input_cols, int stride); \
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void ShuffleAndReverse<GPUDevice, T, 4, Eigen::DenseIndex>::operator()( \
|
||||
const GPUDevice& d, \
|
||||
typename TTypes<T, 4, Eigen::DenseIndex>::ConstTensor input, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
|
||||
const Eigen::array<bool, 4>& reverse_dims, \
|
||||
typename TTypes<T, 4, Eigen::DenseIndex>::Tensor output); \
|
||||
extern template struct ShuffleAndReverse<GPUDevice, T, 4, \
|
||||
Eigen::DenseIndex>; \
|
||||
template <> \
|
||||
void InflatePadAndShuffle<GPUDevice, T, 4, Eigen::DenseIndex>::operator()( \
|
||||
const GPUDevice& d, \
|
||||
typename TTypes<T, 4, Eigen::DenseIndex>::ConstTensor input, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& strides, \
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4>& pad_dims, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
|
||||
typename TTypes<T, 4, Eigen::DenseIndex>::Tensor output); \
|
||||
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, \
|
||||
Eigen::DenseIndex>; \
|
||||
template <> \
|
||||
void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
|
||||
const Eigen::DSizes<int, 4>& order, \
|
||||
const Eigen::array<bool, 4>& reverse_dims, \
|
||||
typename TTypes<T, 4, int>::Tensor output); \
|
||||
extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>; \
|
||||
template <> \
|
||||
void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
|
||||
const Eigen::DSizes<int, 4>& strides, \
|
||||
const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims, \
|
||||
const Eigen::DSizes<int, 4>& order, \
|
||||
typename TTypes<T, 4, int>::Tensor output); \
|
||||
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
|
||||
template <> \
|
||||
void TransformFilter<GPUDevice, T, int>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T, int>; \
|
||||
template <> \
|
||||
void TransformDepth<GPUDevice, T, int>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
const Eigen::DSizes<int, 4>& shuffle, \
|
||||
typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct TransformDepth<GPUDevice, T, int>; \
|
||||
template <> \
|
||||
void SpatialConvolution<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::Tensor output, \
|
||||
typename TTypes<T, 4>::ConstTensor input, \
|
||||
typename TTypes<T, 4>::ConstTensor filter, int stride, \
|
||||
const Eigen::PaddingType& padding); \
|
||||
extern template struct SpatialConvolution<GPUDevice, T>; \
|
||||
template <> \
|
||||
void SpatialConvolutionBackwardInput<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::Tensor in_backprop, \
|
||||
typename TTypes<T, 4>::ConstTensor filter, \
|
||||
typename TTypes<T, 4>::ConstTensor output_backprop, int input_rows, \
|
||||
int input_cols, int stride); \
|
||||
extern template struct SpatialConvolutionBackwardInput<GPUDevice, T>
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
|
@ -167,6 +167,10 @@ class Conv2DOp : public BinaryOp<T> {
|
||||
<< ", filter_rows = " << filter_rows << ", stride = " << stride
|
||||
<< ", out_depth = " << out_depth;
|
||||
|
||||
// If there is nothing to compute, return.
|
||||
if (out_shape.num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
LaunchConvOp<Device, T>::launch(context, use_cudnn_, input, filter, stride,
|
||||
BrainPadding2EigenPadding(padding_),
|
||||
output);
|
||||
@ -260,10 +264,11 @@ struct LaunchConvOp<GPUDevice, T> {
|
||||
input.dim_size(2) + padding_cols, input.dim_size(3)}),
|
||||
&transformed_input));
|
||||
|
||||
functor::PadInput<GPUDevice, T>()(
|
||||
ctx->eigen_device<GPUDevice>(), input_param.tensor<T, 4>(),
|
||||
functor::PadInput<GPUDevice, T, int>()(
|
||||
ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
|
||||
padding_rows / 2, padding_rows - padding_rows / 2, padding_cols / 2,
|
||||
padding_cols - padding_cols / 2, transformed_input.tensor<T, 4>());
|
||||
padding_cols - padding_cols / 2,
|
||||
To32Bit(transformed_input.tensor<T, 4>()));
|
||||
input = transformed_input;
|
||||
}
|
||||
|
||||
@ -296,9 +301,9 @@ struct LaunchConvOp<GPUDevice, T> {
|
||||
filter.dim_size(0), filter.dim_size(1)}),
|
||||
&transformed_filter));
|
||||
|
||||
functor::TransformFilter<GPUDevice, T>()(
|
||||
ctx->eigen_device<GPUDevice>(), filter.tensor<T, 4>(),
|
||||
transformed_filter.tensor<T, 4>());
|
||||
functor::TransformFilter<GPUDevice, T, int>()(
|
||||
ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
|
||||
To32Bit(transformed_filter.tensor<T, 4>()));
|
||||
|
||||
auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
|
||||
input.template flat<T>().size());
|
||||
@ -346,16 +351,16 @@ namespace functor {
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair); \
|
||||
extern template struct MatMulConvFunctor<GPUDevice, T>; \
|
||||
template <> \
|
||||
void TransformFilter<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
|
||||
typename TTypes<T, 4>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T>; \
|
||||
void TransformFilter<GPUDevice, T, int>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct TransformFilter<GPUDevice, T, int>; \
|
||||
template <> \
|
||||
void PadInput<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
|
||||
void PadInput<GPUDevice, T, int>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
|
||||
int padding_rows_left, int padding_rows_right, int padding_cols_left, \
|
||||
int padding_cols_right, typename TTypes<T, 4>::Tensor out); \
|
||||
extern template struct PadInput<GPUDevice, T>
|
||||
int padding_cols_right, typename TTypes<T, 4, int>::Tensor out); \
|
||||
extern template struct PadInput<GPUDevice, T, int>
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
|
@@ -5,12 +5,14 @@
#include "tensorflow/core/kernels/conv_2d.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/public/tensor.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
template struct functor::InflatePadAndShuffle<GPUDevice, float, 4>;

template struct functor::InflatePadAndShuffle<GPUDevice, float, 4, int>;
template struct functor::InflatePadAndShuffle<GPUDevice, float, 4,
                                               Eigen::DenseIndex>;
}  // namespace tensorflow

#endif  // GOOGLE_CUDA
@@ -9,13 +9,18 @@
namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
template struct functor::ShuffleAndReverse<GPUDevice, float, 4>;
template struct functor::ShuffleAndReverse<GPUDevice, float, 4, int>;
template struct functor::ShuffleAndReverse<GPUDevice, float, 4,
                                            Eigen::DenseIndex>;

template struct functor::TransformFilter<GPUDevice, float>;
template struct functor::TransformFilter<GPUDevice, float, int>;

template struct functor::PadInput<GPUDevice, float>;
template struct functor::PadInput<GPUDevice, float, int>;

template struct functor::TransformDepth<GPUDevice, float>;
template struct functor::TransformDepth<GPUDevice, float, int>;
// TODO(jiayq): currently pooling ops still use DenseIndex, so I am keeping it
// here.
template struct functor::TransformDepth<GPUDevice, float, Eigen::DenseIndex>;

}  // namespace tensorflow
@@ -86,11 +86,10 @@ class LinearAlgebraOp : public LinearAlgebraOpBase {
  explicit LinearAlgebraOp(OpKernelConstruction* context)
      : LinearAlgebraOpBase(context) {}

  using ConstMatrixMap =
      Eigen::Map<const Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic,
                                     Eigen::RowMajor>>;
  using MatrixMap = Eigen::Map<
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  // Perform the actual computation on the input matrix, and store the results
  // in the output. This will be called repeatedly for a single call to
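Defining Matrix once and deriving both map types from it keeps the three aliases consistent. A small, hedged illustration of what they expand to and how a kernel might use them (Scalar = float assumed, buffers and function name are illustrative):

    #include <Eigen/Dense>

    using Matrix =
        Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    using ConstMatrixMap = Eigen::Map<const Matrix>;  // read-only view, no copy
    using MatrixMap = Eigen::Map<Matrix>;             // writable view, no copy

    void InvertIntoBuffer(const float* in_buf, float* out_buf, int n) {
      ConstMatrixMap in(in_buf, n, n);
      MatrixMap out(out_buf, n, n);
      out.noalias() = in.inverse();  // result lands directly in out_buf
    }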
@@ -1,3 +1,4 @@
#include <string>
#include <unordered_set>
#include <utility>

@@ -70,6 +71,7 @@ class ListDiffOp : public OpKernel {
                    ListDiffOp<type>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF);
REGISTER_LISTDIFF(string);
#undef REGISTER_LISTDIFF

}  // namespace tensorflow
@@ -1,6 +1,7 @@
// See docs in ../ops/linalg_ops.cc.
#include <cmath>

#include "third_party/eigen3/Eigen/Cholesky"
#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -35,6 +36,7 @@ class MatrixInverseOp
    }
  }

  using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::Matrix;
  using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
  using
      typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ConstMatrixMap;
@@ -44,15 +46,36 @@ class MatrixInverseOp
    OP_REQUIRES(context, input.rows() == input.cols(),
                errors::InvalidArgument("Input matrix must be square."));
    if (input.rows() == 0) {
      // By definition, an empty matrix's inverse is an emptry matrix.
      // By definition, an empty matrix's inverse is an empty matrix.
      return;
    }
    Eigen::FullPivLU<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic,
                                   Eigen::RowMajor>> lu_decomposition(input);
    OP_REQUIRES(context, lu_decomposition.isInvertible(),
    if (input.isApprox(input.transpose())) {
      // Matrix is symmetric, compute Cholesky factorization
      // input = L * L^T.
      Eigen::LLT<Matrix, Eigen::Lower> cholesky_decomposition(input);
      if (cholesky_decomposition.info() == Eigen::Success) {
        // Cholesky succeeded => Matrix was SPD.
        output->noalias() = cholesky_decomposition.solve(
            Matrix::Identity(input.rows(), input.cols()));
        return;
      }
    }
    Eigen::PartialPivLU<Matrix> lu_decomposition(input);
    // While PartialPivLU cannot give strong guarantees on invertibility,
    // we can at least guard against exact zero pivots. This can occur as
    // a result of basic user mistakes, such as providing integer valued
    // matrices that are exactly singular, or due to underflow if this
    // code is run with denormals being flushed to zero.
    // TODO(rmlarsen): Add check based on condition number estimation.
    const Scalar min_abs_pivot =
        lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > Scalar(0),
                errors::InvalidArgument("Input is not invertible."));
    *output = lu_decomposition.inverse();
    output->noalias() = lu_decomposition.inverse();
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixInverseOp);
};

REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<float, false>), float);
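The new path first checks for symmetry and tries a Cholesky (LLT) solve, then falls back to partial-pivoting LU with an exact-zero-pivot guard. The same logic as a standalone Eigen sketch (a free function for illustration, not the kernel itself):

    #include <Eigen/Cholesky>
    #include <Eigen/LU>

    using Mat = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

    // Returns false if an exact zero pivot makes the input non-invertible.
    bool InvertDense(const Mat& input, Mat* output) {
      if (input.isApprox(input.transpose())) {
        Eigen::LLT<Mat, Eigen::Lower> llt(input);
        if (llt.info() == Eigen::Success) {  // symmetric positive definite
          output->noalias() =
              llt.solve(Mat::Identity(input.rows(), input.cols()));
          return true;
        }
      }
      Eigen::PartialPivLU<Mat> lu(input);
      const float min_abs_pivot =
          lu.matrixLU().diagonal().cwiseAbs().minCoeff();
      if (!(min_abs_pivot > 0.0f)) return false;  // exactly singular or underflow
      output->noalias() = lu.inverse();
      return true;
    }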
@ -99,13 +99,13 @@ perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
|
||||
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void TransformDepth<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& shuffle, \
|
||||
typename TTypes<T, 4>::Tensor out); \
|
||||
extern template struct TransformDepth<GPUDevice, T>;
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void TransformDepth<GPUDevice, T, Eigen::DenseIndex>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 4>& shuffle, \
|
||||
typename TTypes<T, 4>::Tensor out); \
|
||||
extern template struct TransformDepth<GPUDevice, T, Eigen::DenseIndex>;
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
@ -172,7 +172,7 @@ void DnnPoolingGradOp<T>::Compute(
|
||||
// For AvgPoolGrad, the original input tensor is not necessary. However,
|
||||
// cudnn still requires them to run, although they do not affect the
|
||||
// results.
|
||||
functor::TransformDepth<GPUDevice, T>()(
|
||||
functor::TransformDepth<GPUDevice, T, Eigen::DenseIndex>()(
|
||||
context->eigen_device<Device>(), tensor_in->tensor<T, 4>(),
|
||||
nhwc_to_nchw, transformed_input.tensor<T, 4>());
|
||||
}
|
||||
@ -180,11 +180,11 @@ void DnnPoolingGradOp<T>::Compute(
|
||||
// For AvgPoolGrad, the original output tensor is not necessary. However,
|
||||
// cudnn still requires them to run, although they do not affect the
|
||||
// results.
|
||||
functor::TransformDepth<GPUDevice, T>()(
|
||||
functor::TransformDepth<GPUDevice, T, Eigen::DenseIndex>()(
|
||||
context->eigen_device<Device>(), tensor_out->tensor<T, 4>(),
|
||||
nhwc_to_nchw, transformed_output.tensor<T, 4>());
|
||||
}
|
||||
functor::TransformDepth<GPUDevice, T>()(
|
||||
functor::TransformDepth<GPUDevice, T, Eigen::DenseIndex>()(
|
||||
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
|
||||
nhwc_to_nchw, transformed_output_backprop.tensor<T, 4>());
|
||||
|
||||
@ -239,7 +239,7 @@ void DnnPoolingGradOp<T>::Compute(
|
||||
/// Transform the output data from NCHW back to NHWC
|
||||
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
|
||||
auto nchw_to_nhwc = Eigen::DSizes<Eigen::DenseIndex, 4>(0, 2, 3, 1);
|
||||
functor::TransformDepth<GPUDevice, T>()(
|
||||
functor::TransformDepth<GPUDevice, T, Eigen::DenseIndex>()(
|
||||
context->eigen_device<Device>(),
|
||||
toConstTensor(transformed_input_backprop).template tensor<T, 4>(),
|
||||
nchw_to_nhwc, output->tensor<T, 4>());
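The shuffle vectors used with TransformDepth in these kernels encode layout changes: {0, 3, 1, 2} maps NHWC to NCHW and {0, 2, 3, 1} maps NCHW back to NHWC. A small standalone illustration in plain Eigen (dimensions invented):

    #include <unsupported/Eigen/CXX11/Tensor>

    void LayoutShuffleSketch() {
      Eigen::Tensor<float, 4, Eigen::RowMajor> nhwc(2, 5, 7, 3);  // N, H, W, C
      nhwc.setRandom();

      Eigen::array<int, 4> to_nchw{{0, 3, 1, 2}};
      Eigen::array<int, 4> to_nhwc{{0, 2, 3, 1}};
      Eigen::Tensor<float, 4, Eigen::RowMajor> nchw = nhwc.shuffle(to_nchw);
      // nchw has dimensions (2, 3, 5, 7); the two permutations are inverses,
      // so shuffling back recovers the original tensor.
      Eigen::Tensor<float, 4, Eigen::RowMajor> round_trip = nchw.shuffle(to_nhwc);
    }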
|
||||
|
@ -3,71 +3,427 @@
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/port.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
#include "tensorflow/core/platform/port.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
typedef Eigen::Tensor<float, 2, Eigen::RowMajor> Matrix;
|
||||
typedef Eigen::DSizes<Eigen::DenseIndex, 2> DSizes;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>,
|
||||
Eigen::Aligned> MatrixMap;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor>,
|
||||
Eigen::Aligned> ConstMatrixMap;
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
template <typename T>
|
||||
void PrefetchBlockNTA(const T& tensor, int si, int ei, int sj, int ej) {
|
||||
for (int i = si; i < ei; ++i) {
|
||||
for (int j = sj; j < ej; j = j + 16) {
|
||||
port::prefetch<port::PREFETCH_HINT_NTA>(&tensor(i, j));
|
||||
}
|
||||
// Blocksizes
|
||||
// TODO(agarwal): compute these sizes based on cache sizes.
|
||||
static const int K = 64;
|
||||
static const int M = 64;
|
||||
static const int N = 128;
|
||||
|
||||
// This stores a sparse representation of a slice of a matrix with size
|
||||
// (num_rows, num_cols). The slice is represented as a series of blocks of size
|
||||
// (num_rows, b), where b = block_size for all but the last block, which may
// have fewer columns.
|
||||
//
|
||||
// num_rows and block_size are assumed to be <= 256. This allows storing
|
||||
// different indices as uint8.
|
||||
//
|
||||
// For each block, we store all the non zero entries in data/data3 vector and
|
||||
// the corresponding coordinates of the element in index/index3 vectors. index3
|
||||
// vector stores index of 3 elements in the same row so that these elements can
|
||||
// share the same row coordinate. Each entry in Index3 corresponds to 3 entries
|
||||
// in data3.
|
||||
//
|
||||
// Note that all the data/indices of all the blocks are stored in the same
|
||||
// vectors respectively. To identify block boundaries, we store the block
|
||||
// offsets using index3_offset/index_offset. If there are n blocks in the slice,
|
||||
// index3_offset and index_offset have n entires. The indices for the ith block
|
||||
// are the values in the following range:
|
||||
// [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for
|
||||
// index_offset.
|
||||
struct SparseSlice {
|
||||
public:
|
||||
// Indices of three elements on the same row.
|
||||
struct Index3 {
|
||||
uint8 m; // row
|
||||
// columns
|
||||
uint8 k1;
|
||||
uint8 k2;
|
||||
uint8 k3;
|
||||
};
|
||||
|
||||
// Index of one element.
|
||||
struct Index {
|
||||
uint8 m;
|
||||
uint8 k;
|
||||
};
|
||||
|
||||
SparseSlice(int nrows, int ncols, int bsize)
|
||||
: num_rows(nrows), num_cols(ncols), block_size(bsize) {
|
||||
DCHECK_LE(nrows, 256);
|
||||
DCHECK_LE(block_size, 256);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PrefetchBlockT1(const T& tensor, int si, int ei, int sj, int ej) {
|
||||
for (int i = si; i < ei; ++i) {
|
||||
for (int j = sj; j < ej; j = j + 16) {
|
||||
port::prefetch<port::PREFETCH_HINT_T1>(&tensor(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Initializes the slice with data starting at mat(0, col_offset) and with
|
||||
// size (num_rows, num_cols).
|
||||
// If Transpose is true, implicitly transposes mat.
|
||||
template <bool Transpose = false>
|
||||
void Initialize(const ConstMatrixMap& mat, int col_offset);
|
||||
|
||||
struct Block {
|
||||
Block(int sm, int em, int sk, int ek, int sn, int en)
|
||||
: startm(sm), endm(em), startk(sk), endk(ek), startn(sn), endn(en) {}
|
||||
void Clear();
|
||||
|
||||
int startm;
|
||||
int endm;
|
||||
int startk;
|
||||
int endk;
|
||||
int startn;
|
||||
int endn;
|
||||
// See comments above.
|
||||
std::vector<int> index3_offset;
|
||||
std::vector<Index3> index3;
|
||||
std::vector<float> data3;
|
||||
|
||||
// See comments above. Similar to "index3" except that each element in "index"
|
||||
// corresponds to one element in data.
|
||||
std::vector<int> index_offset;
|
||||
std::vector<Index> index;
|
||||
std::vector<float> data;
|
||||
|
||||
// Number of rows and columns for the slice.
|
||||
const int num_rows;
|
||||
const int num_cols;
|
||||
|
||||
// Block size used to initialize from a matrix.
|
||||
const int block_size;
|
||||
};
|
||||
|
||||
bool NextBlock(const int Bm, const int Bk, const int Bn, const int m_start,
|
||||
const int m, const int k, const int n, const Block& b,
|
||||
Block* next) {
|
||||
*next = b;
|
||||
if (b.endk < k) {
|
||||
next->startk = b.endk;
|
||||
next->endk = std::min(b.endk + Bk, k);
|
||||
} else {
|
||||
next->startk = 0;
|
||||
next->endk = std::min(Bk, k);
|
||||
if (b.endm < m) {
|
||||
next->startm = b.endm;
|
||||
next->endm = std::min(b.endm + Bm, m);
|
||||
} else {
|
||||
next->startm = m_start;
|
||||
next->endm = std::min(m_start + Bm, m);
|
||||
next->startn = b.endn;
|
||||
next->endn = std::min(b.endn + Bn, n);
|
||||
template <bool Transpose>
|
||||
void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
|
||||
const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
|
||||
const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
|
||||
DCHECK_LE(num_rows, mat_rows);
|
||||
DCHECK_LE(num_cols + col_offset, mat_cols);
|
||||
|
||||
int num_blocks = (num_cols + block_size - 1) / block_size;
|
||||
int mat_size = num_rows * num_cols;
|
||||
|
||||
index3_offset.reserve(num_blocks);
|
||||
data3.reserve(mat_size);
|
||||
index3.reserve(mat_size / 3);
|
||||
|
||||
index_offset.reserve(num_blocks);
|
||||
data.reserve(num_blocks * num_rows * 2);
|
||||
index.reserve(num_blocks * num_rows * 2);
|
||||
|
||||
Index3 idx3;
|
||||
Index idx;
|
||||
int data3_size = 0;
|
||||
for (int i = 0; i < num_blocks; ++i) {
|
||||
int num_block_cols =
|
||||
std::min(block_size, num_cols - block_size * (num_blocks - 1));
|
||||
for (int row = 0; row < num_rows; ++row) {
|
||||
idx3.m = static_cast<uint8>(row);
|
||||
const float* start =
|
||||
Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
|
||||
const float* curr = start;
|
||||
const int stride = Transpose ? mat.dimension(1) : 1;
|
||||
const float* end = start + stride * num_block_cols;
|
||||
uint8 k = 0;
|
||||
#define NEXT_ELEM \
|
||||
curr += stride; \
|
||||
++k;
|
||||
while (true) {
|
||||
while (curr < end && (*curr == 0)) {
|
||||
NEXT_ELEM;
|
||||
}
|
||||
if (curr >= end) break;
|
||||
idx3.k1 = k;
|
||||
data3.push_back(*curr);
|
||||
NEXT_ELEM;
|
||||
|
||||
while (curr < end && (*curr == 0)) {
|
||||
NEXT_ELEM;
|
||||
}
|
||||
if (curr >= end) break;
|
||||
idx3.k2 = k;
|
||||
data3.push_back(*curr);
|
||||
NEXT_ELEM;
|
||||
|
||||
while (curr < end && (*curr == 0)) {
|
||||
NEXT_ELEM;
|
||||
}
|
||||
if (curr >= end) break;
|
||||
idx3.k3 = k;
|
||||
data3.push_back(*curr);
|
||||
NEXT_ELEM;
|
||||
index3.push_back(idx3);
|
||||
#undef NEXT_ELEM
|
||||
}
|
||||
int num_inserted_mod = data3.size() % 3;
|
||||
// Move some elements to index and data if needed.
|
||||
data3_size = data3.size() - num_inserted_mod;
|
||||
idx.m = idx3.m;
|
||||
switch (num_inserted_mod) {
|
||||
case 2:
|
||||
idx.k = idx3.k2;
|
||||
data.push_back(data3[data3_size + 1]);
|
||||
index.push_back(idx);
|
||||
TF_FALLTHROUGH_INTENDED;
|
||||
case 1:
|
||||
idx.k = idx3.k1;
|
||||
data.push_back(data3[data3_size]);
|
||||
index.push_back(idx);
|
||||
data3.resize(data3_size);
|
||||
}
|
||||
}
|
||||
col_offset += block_size;
|
||||
index3_offset.push_back(index3.size());
|
||||
index_offset.push_back(index.size());
|
||||
}
|
||||
DCHECK_EQ(index3_offset.size(), num_blocks);
|
||||
DCHECK_EQ(index_offset.size(), num_blocks);
|
||||
DCHECK_EQ(3 * index3.size(), data3.size());
|
||||
DCHECK_EQ(index.size(), data.size());
|
||||
}
|
||||
|
||||
void SparseSlice::Clear() {
|
||||
index3_offset.clear();
|
||||
index3.clear();
|
||||
data3.clear();
|
||||
index_offset.clear();
|
||||
index.clear();
|
||||
data.clear();
|
||||
}
|
||||
|
||||
#define SCALAR_MULADD(a, inp, out) *out++ += *a * *inp++;
|
||||
|
||||
#define SCALAR_MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \
|
||||
*out++ += *a1 * *inp1++ + *a2 * *inp2++ + *a3 * *inp3++;
|
||||
|
||||
typedef Eigen::internal::packet_traits<float>::type Packet;
|
||||
static const int kNumOperands = (sizeof(Packet) / sizeof(float));
|
||||
#define LOAD(x) Eigen::internal::pload<Packet>(x);
|
||||
#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
|
||||
#define LOAD_SCALAR(x, y) const auto y = Eigen::internal::pload1<Packet>(x);
|
||||
#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);
|
||||
|
||||
// Vectorized version of SCALAR_MULADD.
|
||||
#define MULADD(a, inp, out) \
|
||||
do { \
|
||||
const auto b = LOAD(inp); \
|
||||
inp += kNumOperands; \
|
||||
auto c = LOAD(out); \
|
||||
FMA(a, b, c, c); \
|
||||
STORE(out, c); \
|
||||
out += kNumOperands; \
|
||||
} while (false)
|
||||
|
||||
// Vectorized version of SCALAR_MULADD3WAY.
|
||||
#define MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \
|
||||
do { \
|
||||
auto c = LOAD(out); \
|
||||
const auto b1 = LOAD(inp1); \
|
||||
inp1 += kNumOperands; \
|
||||
const auto b2 = LOAD(inp2); \
|
||||
inp2 += kNumOperands; \
|
||||
const auto b3 = LOAD(inp3); \
|
||||
inp3 += kNumOperands; \
|
||||
FMA(a1, b1, c, c); \
|
||||
FMA(a2, b2, c, c); \
|
||||
FMA(a3, b3, c, c); \
|
||||
STORE(out, c); \
|
||||
out += kNumOperands; \
|
||||
} while (false)
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_AVX2
|
||||
// Unroll MULADD3WAY for two iterations
|
||||
#define MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out) \
|
||||
do { \
|
||||
auto c1 = LOAD(out); \
|
||||
const auto b1 = LOAD(inp1); \
|
||||
const auto b2 = LOAD(inp2); \
|
||||
const auto b3 = LOAD(inp3); \
|
||||
\
|
||||
auto c2 = LOAD(out + kNumOperands); \
|
||||
const auto b4 = LOAD(inp1 + kNumOperands); \
|
||||
const auto b5 = LOAD(inp2 + kNumOperands); \
|
||||
const auto b6 = LOAD(inp3 + kNumOperands); \
|
||||
\
|
||||
FMA(a1, b1, c1, c1); \
|
||||
FMA(a1, b4, c2, c2); \
|
||||
FMA(a2, b2, c1, c1); \
|
||||
FMA(a2, b5, c2, c2); \
|
||||
FMA(a3, b3, c1, c1); \
|
||||
FMA(a3, b6, c2, c2); \
|
||||
STORE(out, c1); \
|
||||
STORE(out + kNumOperands, c2); \
|
||||
out += 2 * kNumOperands; \
|
||||
inp1 += 2 * kNumOperands; \
|
||||
inp2 += 2 * kNumOperands; \
|
||||
inp3 += 2 * kNumOperands; \
|
||||
} while (false)
|
||||
// Further unroll MULADD3WAY.
|
||||
#define MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out) \
|
||||
MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out); \
|
||||
MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out);
|
||||
#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \
|
||||
MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \
|
||||
MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \
|
||||
MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \
|
||||
MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out);
|
||||
#else
|
||||
#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \
|
||||
for (int __i = 0; __i < 128 / (4 * kNumOperands); ++__i) { \
|
||||
MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
|
||||
MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
|
||||
MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
|
||||
MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
|
||||
}
|
||||
#endif
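MULADD and MULADD3WAY above are packet (SIMD) versions of the scalar macros: load kNumOperands floats, fused-multiply-add them into the accumulator, and store. A self-contained sketch of the single-source case using the same Eigen primitives (unaligned loads are used here for safety; function name illustrative):

    #include <Eigen/Core>

    // out[0..n) += a * inp[0..n), vectorized in packet-sized chunks.
    void MulAdd(float a, const float* inp, float* out, int n) {
      typedef Eigen::internal::packet_traits<float>::type Packet;
      const int kNumOperands = sizeof(Packet) / sizeof(float);
      const Packet pa = Eigen::internal::pset1<Packet>(a);
      int i = 0;
      for (; i + kNumOperands <= n; i += kNumOperands) {
        Packet b = Eigen::internal::ploadu<Packet>(inp + i);
        Packet c = Eigen::internal::ploadu<Packet>(out + i);
        Eigen::internal::pstoreu(out + i, Eigen::internal::pmadd(pa, b, c));
      }
      for (; i < n; ++i) out[i] += a * inp[i];  // scalar tail
    }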
|
||||
|
||||
// Computes product of "left_slices" with "num_cols" columns of "right", and
|
||||
// stores the output in *"output".
|
||||
// Note that left_slices is a list of SparseSlices, which are conceptually
|
||||
// assumed to be concatenated along the column dimension. Also each SparseSlice
|
||||
// is encoded as a list of blocks with up to N columns. See SparseSlice for more
|
||||
// details.
|
||||
template <int Cols>
|
||||
inline void GEPP(const std::vector<SparseSlice*>& left_slices,
|
||||
const ConstMatrixMap& right, const int num_cols,
|
||||
Matrix* output) {
|
||||
const int cols = (Cols == -1) ? num_cols : Cols;
|
||||
DCHECK_EQ(num_cols, cols);
|
||||
const int right_num_cols = right.dimension(1);
|
||||
const int output_num_cols = output->dimension(1);
|
||||
const int cols_mod = cols % kNumOperands;
|
||||
int k_offset = 0;
|
||||
// Pre-compute pointers for output matrix.
|
||||
float* out_ptrs[M];
|
||||
float* const out_start = &(*output)(0, 0);
|
||||
for (int j = 0; j < M; ++j) {
|
||||
out_ptrs[j] = out_start + output_num_cols * j;
|
||||
}
|
||||
for (const auto* left_slice : left_slices) {
|
||||
const auto& left = *left_slice;
|
||||
const float* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr;
|
||||
const float* data = (left.data.size() > 0) ? &left.data[0] : nullptr;
|
||||
const int num_blocks = left.index3_offset.size();
|
||||
int begin3 = 0;
|
||||
int begin = 0;
|
||||
for (int i = 0; i < num_blocks; ++i) {
|
||||
// Pre-compute pointers for right matrix
|
||||
const float* right_ptrs[K];
|
||||
const float* const right_start = &right(k_offset, 0);
|
||||
DCHECK_LT(k_offset, right.dimension(0));
|
||||
for (int j = 0; j < K; ++j) {
|
||||
right_ptrs[j] = right_start + right_num_cols * j;
|
||||
}
|
||||
|
||||
const int end3 = left.index3_offset[i];
|
||||
int j = begin3;
|
||||
// Loop unrolled for 2 iterations.
|
||||
for (; j + 1 < end3; j += 2) {
|
||||
const float* sl1 = data3++;
|
||||
LOAD_SCALAR(sl1, l1);
|
||||
const float* sl2 = data3++;
|
||||
LOAD_SCALAR(sl2, l2);
|
||||
const float* sl3 = data3++;
|
||||
LOAD_SCALAR(sl3, l3);
|
||||
const float* nsl1 = data3++;
|
||||
LOAD_SCALAR(nsl1, nl1);
|
||||
const float* nsl2 = data3++;
|
||||
LOAD_SCALAR(nsl2, nl2);
|
||||
const float* nsl3 = data3++;
|
||||
LOAD_SCALAR(nsl3, nl3);
|
||||
const SparseSlice::Index3& index = left.index3[j];
|
||||
const SparseSlice::Index3& nindex = left.index3[j + 1];
|
||||
float* out = out_ptrs[index.m];
|
||||
float* nout = out_ptrs[nindex.m];
|
||||
const float* r1 = right_ptrs[index.k1];
|
||||
const float* r2 = right_ptrs[index.k2];
|
||||
const float* r3 = right_ptrs[index.k3];
|
||||
const float* nr1 = right_ptrs[nindex.k1];
|
||||
const float* nr2 = right_ptrs[nindex.k2];
|
||||
const float* nr3 = right_ptrs[nindex.k3];
|
||||
if (cols == 128) {
|
||||
MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out);
|
||||
MULADD3WAY_128(nl1, nl2, nl3, nr1, nr2, nr3, nout);
|
||||
} else {
|
||||
for (int n = 0; n < cols / kNumOperands; ++n) {
|
||||
MULADD3WAY(l1, l2, l3, r1, r2, r3, out);
|
||||
MULADD3WAY(nl1, nl2, nl3, nr1, nr2, nr3, nout);
|
||||
}
|
||||
for (int k = 0; k < cols_mod; ++k) {
|
||||
SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out);
|
||||
SCALAR_MULADD3WAY(nsl1, nsl2, nsl3, nr1, nr2, nr3, nout);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (j < end3) {
|
||||
const float* sl1 = data3++;
|
||||
LOAD_SCALAR(sl1, l1);
|
||||
const float* sl2 = data3++;
|
||||
LOAD_SCALAR(sl2, l2);
|
||||
const float* sl3 = data3++;
|
||||
LOAD_SCALAR(sl3, l3);
|
||||
const SparseSlice::Index3& index = left.index3[j];
|
||||
float* out = out_ptrs[index.m];
|
||||
const float* r1 = right_ptrs[index.k1];
|
||||
const float* r2 = right_ptrs[index.k2];
|
||||
const float* r3 = right_ptrs[index.k3];
|
||||
if (cols == 128) {
|
||||
MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out);
|
||||
} else {
|
||||
for (int n = 0; n < cols / kNumOperands; ++n) {
|
||||
MULADD3WAY(l1, l2, l3, r1, r2, r3, out);
|
||||
}
|
||||
for (int k = 0; k < cols_mod; ++k) {
|
||||
SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
begin3 = end3;
|
||||
int end = left.index_offset[i];
|
||||
for (int j = begin; j < end; ++j) {
|
||||
const float* sl = data++;
|
||||
LOAD_SCALAR(sl, l);
|
||||
const SparseSlice::Index& index = left.index[j];
|
||||
const float* r = right_ptrs[index.k];
|
||||
float* out = out_ptrs[index.m];
|
||||
for (int n = 0; n < cols / kNumOperands; ++n) {
|
||||
MULADD(l, r, out);
|
||||
}
|
||||
for (int k = 0; k < cols_mod; ++k) {
|
||||
SCALAR_MULADD(sl, r, out);
|
||||
}
|
||||
}
|
||||
k_offset += left.block_size;
|
||||
begin = end;
|
||||
}
|
||||
}
|
||||
return next->startn == next->endn;
|
||||
}
|
||||
|
||||
#undef SCALAR_MULADD
|
||||
#undef SCALAR_MULADD3WAY
|
||||
#undef LOAD
|
||||
#undef STORE
|
||||
#undef LOAD_SCALAR
|
||||
#undef FMA
|
||||
#undef MULADD
|
||||
#undef MULADD3WAY
|
||||
#undef MULADD3WAY_16
|
||||
#undef MULADD3WAY_32
|
||||
#undef MULADD3WAY_128
|
||||
|
||||
} // namespace
|
||||
|
||||
class SparseMatMulOp : public OpKernel {
|
||||
public:
|
||||
explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
@ -80,7 +436,6 @@ class SparseMatMulOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& a = ctx->input(0);
const Tensor& b = ctx->input(1);

OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
errors::InvalidArgument("a is not a matrix"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
@ -115,10 +470,11 @@ class SparseMatMulOp : public OpKernel {
left.contract(right_mat, dim_pair);
return;
}
typedef Eigen::Tensor<float, 2, Eigen::RowMajor> Matrix;
std::unique_ptr<Matrix> right_tr_mat;
std::unique_ptr<TTypes<float>::ConstMatrix> right_tr_map;
if (transpose_b_) {
// TODO(agarwal): avoid transposing the matrix here and directly handle
// transpose in CreateDenseSlices.
right_tr_mat.reset(new Matrix(k, n));
Eigen::array<int, 2> perm({1, 0});
right_tr_mat->device(ctx->template eigen_device<CPUDevice>()) =
@ -129,56 +485,75 @@ class SparseMatMulOp : public OpKernel {
TTypes<float>::ConstMatrix& right =
transpose_b_ ? *right_tr_map : right_mat;

const bool transpose_a = transpose_a_;

typedef Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
Eigen::Unaligned> TensorMap;
typedef Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>,
Eigen::Unaligned> ConstTensorMap;
typedef Eigen::DSizes<Eigen::DenseIndex, 1> DSizes;
const int Bm = 16;
const int Bk = 16;
const int Bn = 1024;

auto work_shard = [m, n, k, transpose_a, Bm, Bk, Bn, &left, &right, &out](
int64 start64, int64 end64) {
const int start = static_cast<int>(start64);
const int end = static_cast<int>(end64);
Block curr(start, std::min(start + Bm, end), 0, std::min(Bk, k), 0,
std::min(Bn, n));
Block next(curr);
bool done = false;
for (int i = start; i < end; ++i) {
out.chip<0>(i).setZero();
}
while (true) {
done = NextBlock(Bm, Bk, Bn, start, end, k, n, curr, &next);

PrefetchBlockT1(right, curr.startk, curr.endk, curr.startn, curr.endn);

// Process current block
for (int i = curr.startm; i < curr.endm; ++i) {
PrefetchBlockNTA(left, i, i + 1, curr.startk, curr.endk);
PrefetchBlockNTA(out, i, i + 1, curr.startn, curr.endn);
DSizes out_slice_shape(curr.endn - curr.startn);
TensorMap out_i(&out(i, curr.startn), out_slice_shape);
for (int j = curr.startk; j < curr.endk; ++j) {
const float l = transpose_a ? left(j, i) : left(i, j);
if (l == 0) continue;
ConstTensorMap right_j(&right(j, curr.startn), out_slice_shape);
out_i += right_j * l;
}
}
if (done) break;
curr = next;
}
};
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers, m, 2 * k * n,
work_shard);
SparseMatMul(left, right, transpose_a_,
ctx->device()->tensorflow_cpu_worker_threads(), &out);
}

 private:
// Perform matrix multiplication of "left" and "right", and store the result
// in *"output".
static inline void SparseMatMul(
const ConstMatrixMap& left, const ConstMatrixMap& right,
bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
MatrixMap* output);

// Computes multiplication of left and num_cols columns of right, and stores
// the output block in *"output" at offsets "output_row_offset" and
// "output_col_offset". If assign is true, assigns the value to that block,
// else adds the values to the existing values.
static inline void ComputeOutputBlock(const std::vector<SparseSlice*>& left,
const ConstMatrixMap& right,
const int num_cols,
int output_row_offset,
int output_col_offset, bool assign,
MatrixMap* output);

// Encodes "mat" using a sparse representation and stores that in
// "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
// "slice_num_cols", each grid element is converted into a SparseSlice and
// stored in mat_slices. "slice_block_size" is used to perform further column
// blocking of each slice.
static inline BlockingCounter* CreateSparseSlices(
const ConstMatrixMap& mat, bool transpose, int slice_num_rows,
int slice_block_size, int slice_num_cols,
std::vector<std::vector<SparseSlice*>>* mat_slices,
const DeviceBase::CpuWorkerThreads* thread_pool);

// This function chops "mat" along the column dimension into pieces with at
// most N columns, and concatenates the pieces one after the other in
// "buffer". It returns the list of the pieces in "slices". It returns a
// BlockingCounter which should be used to wait for the shuffle operations to
// complete.
static inline BlockingCounter* CreateDenseSlices(
const ConstMatrixMap& mat, int row_start, int num_rows, int col_start,
int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
Matrix* buffer, std::vector<ConstMatrixMap*>* slices);

// Helper function for CreateDenseSlices to move the data around. It returns a
// BlockingCounter which should be used to wait for the shuffle operations to
// complete.
static inline BlockingCounter* ShuffleMatrix(
const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows,
int slice_col_start, int slice_num_cols, const int N,
const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer);

// Helper function for CreateDenseSlices to create slices.
static inline void SliceMatrix(const Matrix& mat, const int num_rows,
const int num_slices,
std::vector<ConstMatrixMap*>* slices);

// Heuristics to compute various block sizes.
// KR, NR: block sizes for "right". We run blocking iterations that operate on
// matrices with at most this size.
// KL: grid size along the column dimension used while encoding left.
// IB, JB: number of left and right slices to multiply together. This is used
// for ordering different ComputeOutputBlock operations inside each blocking
// iteration so as to potentially reduce the working set size.
static inline void ComputeBlockSizes(const ConstMatrixMap& left,
const ConstMatrixMap& right,
bool transpose_left, int num_threads,
int* KR, int* NR, int* KL, int* JB,
int* IB);

bool transpose_a_;
bool transpose_b_;
bool a_is_sparse_;
@ -186,6 +561,329 @@ class SparseMatMulOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
};

inline void SparseMatMulOp::ComputeOutputBlock(
const std::vector<SparseSlice*>& left, const ConstMatrixMap& right,
const int num_cols, int output_row_offset, int output_col_offset,
bool assign, MatrixMap* output) {
const int num_rows = left[0]->num_rows;
const int rhs_num_cols = right.dimension(1);
DCHECK_LE(num_cols, rhs_num_cols);
Matrix out(num_rows, rhs_num_cols);
out.setZero();
if (num_cols == N) {
GEPP<N>(left, right, num_cols, &out);
} else {
GEPP<-1>(left, right, num_cols, &out);
}
if (!assign) {
const Eigen::array<int, 2> begin = {output_row_offset, output_col_offset};
const Eigen::array<int, 2> sizes = {num_rows, num_cols};
if (num_cols == rhs_num_cols) {
output->slice(begin, sizes) += out;
} else {
static const Eigen::array<int, 2> zero = {0, 0};
output->slice(begin, sizes) += out.slice(zero, sizes);
}
} else {
// output->slice(begin, sizes) = out.slice(zero, sizes), implemented
// using memcpy.
for (int i = 0; i < num_rows; ++i) {
memcpy(&(*output)(output_row_offset + i, output_col_offset), &out(i, 0),
num_cols * sizeof(float));
}
}
}

inline BlockingCounter* SparseMatMulOp::CreateSparseSlices(
const ConstMatrixMap& mat, bool transpose, int slice_num_rows,
int slice_block_size, int slice_num_cols,
std::vector<std::vector<SparseSlice*>>* mat_slices,
const DeviceBase::CpuWorkerThreads* thread_pool) {
const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
const int num_slices_dim0 =
std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
const int num_slices_dim1 =
std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
mat_slices->resize(num_slices_dim0);
BlockingCounter* counter =
new BlockingCounter(num_slices_dim0 * num_slices_dim1);
auto work = [counter, transpose](SparseSlice* sparse_slice,
ConstMatrixMap* slice, int col_offset) {
if (transpose) {
sparse_slice->Initialize<true>(*slice, col_offset);
} else {
sparse_slice->Initialize<false>(*slice, col_offset);
}
delete slice;
counter->DecrementCount();
};
for (int i = 0; i < num_slices_dim0; ++i) {
(*mat_slices)[i].resize(num_slices_dim1);
int num_rows =
std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
for (int j = 0; j < num_slices_dim1; ++j) {
int num_cols =
std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
ConstMatrixMap* slice = nullptr;
if (transpose) {
slice =
new ConstMatrixMap(&mat(0, i * slice_num_rows), mat.dimensions());
} else {
DSizes d(num_rows, mat_num_cols);
slice = new ConstMatrixMap(&mat(i * slice_num_rows, 0), d);
}
SparseSlice* sparse_slice =
new SparseSlice(num_rows, num_cols, slice_block_size);
(*mat_slices)[i][j] = sparse_slice;
thread_pool->workers->Schedule(
std::bind(work, sparse_slice, slice, slice_num_cols * j));
}
}
return counter;
}

inline BlockingCounter* SparseMatMulOp::ShuffleMatrix(
const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows,
int slice_col_start, int slice_num_cols, const int N,
const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer) {
int num_threads = std::min(thread_pool->num_threads, 16);
BlockingCounter* counter = new BlockingCounter(num_threads);
DCHECK_EQ(N, buffer->dimension(1));
auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
slice_num_cols, N, buffer, counter](int s, int e) {
const int row_start = s % slice_num_rows + slice_row_start;
const int col_start = s / slice_num_rows * N + slice_col_start;
float* out_start = &(*buffer)(s, 0);
const float* input_start = &mat(row_start, col_start);
const float* input_end = &mat(slice_row_start + slice_num_rows - 1,
slice_col_start + slice_num_cols - 1);
const int mat_num_cols = mat.dimension(1);
const int row_slice_size = slice_num_rows * mat_num_cols;

const int aligned_end = slice_num_cols / N * slice_num_rows;
const int e1 = std::min(e, aligned_end);
while (s < e1) {
memcpy(out_start, input_start, N * sizeof(float));
out_start += N;
input_start += mat_num_cols;
if (input_start > input_end) {
input_start = input_start - row_slice_size + N;
}
++s;
}
int s1 = std::max(s, aligned_end);
const int copy_num_cols = slice_num_cols % N;
while (s1 < e) {
memcpy(out_start, input_start, copy_num_cols * sizeof(float));
out_start += N;
input_start += mat_num_cols;
++s1;
}
if (counter) counter->DecrementCount();
};

int start = 0;
int end = 0;
int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
DCHECK_LE(num_out_rows, buffer->dimension(0));
for (int i = std::max(1, num_threads); i > 0; --i) {
end = start + num_out_rows / i;
thread_pool->workers->Schedule(std::bind(shuffle_work, start, end));
num_out_rows -= (end - start);
start = end;
}
return counter;
}

inline void SparseMatMulOp::SliceMatrix(const Matrix& mat, const int num_rows,
const int num_slices,
std::vector<ConstMatrixMap*>* slices) {
slices->resize(num_slices);
DSizes d(num_rows, mat.dimension(1));
DCHECK_LE(num_rows * num_slices, mat.dimension(0));
for (int i = 0; i < num_slices; ++i) {
(*slices)[i] = new ConstMatrixMap(&mat(i * num_rows, 0), d);
}
}

inline BlockingCounter* SparseMatMulOp::CreateDenseSlices(
const ConstMatrixMap& mat, int row_start, int num_rows, int col_start,
int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
Matrix* buffer, std::vector<ConstMatrixMap*>* slices) {
BlockingCounter* shuffle_counter = ShuffleMatrix(
mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer);
const int num_slices = (num_cols + N - 1) / N;
SliceMatrix(*buffer, num_rows, num_slices, slices);
return shuffle_counter;
}

inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left,
const ConstMatrixMap& right,
bool transpose_left,
int num_threads, int* KR, int* NR,
int* KL, int* JB, int* IB) {
// Heuristics for calculating block sizes.
// Assume two hyperthreads per core.
const int est_num_cores = std::max(1, (num_threads + 1) / 2);
// Use block of rhs with at most 128K floats per core.
const int mem = est_num_cores * 128 * 1024;
*KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
*NR = right.dimension(1);
if (*KR * *NR > mem) {
// 4096 may be enough to amortize the cost of writes.
*KR = std::min<int>(*KR, 4096);
}
// Use sizes that are multiples of K and 256.
*KR = std::max(1, *KR / K) * K;
*NR = std::max(1, *NR / 256) * 256;
if (*KR * *NR > mem) {
*NR = mem / *KR;
}
*NR = std::max(1, *NR / 256) * 256;

const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
for (*KL = 1024; *KL > K; *KL /= 2) {
if (*KR % *KL == 0 &&
std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
break;
}
}
DCHECK_EQ(*KL % K, 0);
DCHECK_GE(*KR, *KL);
if (*KR < right.dimension(0)) {
CHECK_EQ(*KR % *KL, 0);
}

*JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
*IB = 8 * *JB;
DCHECK_EQ(N * sizeof(float) % 64, 0);
}

// Here is an overview of the SparseMatMul code. Note that we assume that the
// left matrix is sparse.
//
// The matrix "left" is divided into a grid with blocksize of (M, KL). Each
// block is encoded as a SparseSlice. These grid elements are stored as
// std::vector<std::vector<SparseSlice>>. Each element of the outer vector
// represents M rows of the left matrix. Let's call these elements l_i and
// let's call each element of the inner vector L_mk.
//
// The matrix "right" is divided into a grid with block size KR * NR. Let's
// denote the blocks on the right as R_kn. Note that we ensure that KL divides
// KR so that for each element R_kn, we don't need to multiply it with any
// partial L_mk blocks.
//
// We then multiply each right side block R_kn with the full "left" matrix and
// update the output. These iterations are run sequentially since R_kn are
// packed into the same underlying temporary buffer.
//
// In each iteration we do the following:
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
// (=128) columns and then concatenate these slices into a buffer. This is
// done so that each slice r_j of R_kn is stored contiguously in memory. Note
// that if R_kn has dimensions (KR, NR), we create NR / N slices, and the
// buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
// 2. For each (l_i, r_j), we compute the inner product using the GEPP function
// and update the output block o_ij. These calls are further blocked to
// reduce the working set size. In each iteration we take IB elements from
// {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
inline void SparseMatMulOp::SparseMatMul(
const ConstMatrixMap& left, const ConstMatrixMap& right,
bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
MatrixMap* output) {
const int num_threads = thread_pool->num_threads;
int KR, NR, KL, JB, IB;
ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
&JB, &IB);

// Slice the left matrix
std::vector<std::vector<SparseSlice*>> left_slices;
std::unique_ptr<BlockingCounter> sparse_slice_counter;
sparse_slice_counter.reset(
CreateSparseSlices(ConstMatrixMap(left.data(), left.dimensions()),
transpose_left, M, K, KL, &left_slices, thread_pool));
const int num_left_slices = left_slices.size();

const int right_dim0 = right.dimension(0);
const int right_dim1 = right.dimension(1);
// Allocate buffer for storing slices of right matrix.
// Note buffer needs enough space to hold at most a KR * NR matrix since that
// is the block size per iteration.
const int buffer_num_rows =
std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N;
Matrix buffer(buffer_num_rows, N);
std::vector<ConstMatrixMap*> right_slices;

std::vector<SparseSlice*> block_left_slices;
std::vector<std::function<void(void)>> tasks;
// Number of blocks based on block sizes of KR * NR.
const int num_k_blocks = (right_dim0 + KR - 1) / KR;
const int num_n_blocks = (right_dim1 + NR - 1) / NR;
std::unique_ptr<BlockingCounter> dense_slice_counter;

for (int nb = 0; nb < num_n_blocks; ++nb) {
const int right_num_cols =
std::min(NR, static_cast<int>(right_dim1 - NR * nb));
for (int kb = 0; kb < num_k_blocks; ++kb) {
const int right_num_rows =
std::min(KR, static_cast<int>(right_dim0 - KR * kb));
dense_slice_counter.reset(CreateDenseSlices(
right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
&buffer, &right_slices));
const int num_right_slices = right_slices.size();
tasks.reserve(num_left_slices * num_right_slices);
for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
for (int j_inner = j_outer;
j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
const int num_cols = std::min(N, right_num_cols - N * j_inner);
for (int i_inner = i_outer;
i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
// Figure out which left slices to use.
block_left_slices.clear();
int begin = kb * KR / KL;
int end = std::min<int>((kb + 1) * KR / KL,
(right.dimension(0) + KL - 1) / KL);
DCHECK_LT(begin, end);
block_left_slices.insert(block_left_slices.begin(),
left_slices[i_inner].begin() + begin,
left_slices[i_inner].begin() + end);
tasks.push_back(std::bind(
&SparseMatMulOp::ComputeOutputBlock, block_left_slices,
std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
N * j_inner + nb * NR, kb == 0, output));
}
}
}
}
if (sparse_slice_counter) {
sparse_slice_counter->Wait();
sparse_slice_counter.reset(nullptr);
}
if (dense_slice_counter) {
dense_slice_counter->Wait();
dense_slice_counter.reset(nullptr);
}
BlockingCounter bc(tasks.size());
for (const auto& t : tasks) {
thread_pool->workers->Schedule([&bc, &t]() {
t();
bc.DecrementCount();
});
}
bc.Wait();
tasks.clear();
gtl::STLDeleteElements(&right_slices);
right_slices.clear();
}
}
for (auto& left_slice : left_slices) {
gtl::STLDeleteElements(&left_slice);
}
}

REGISTER_KERNEL_BUILDER(Name("SparseMatMul").Device(DEVICE_CPU),
SparseMatMulOp);
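
As a usage sketch only (an assumption for illustration: that the Python `tf.matmul` wrapper's `a_is_sparse`/`b_is_sparse` hints are what dispatch to this CPU kernel), exercising the op from Python could look like this:

```python
import numpy as np
import tensorflow as tf

# A mostly-zero left operand; the sparsity hint lets the runtime pick the
# sparse-aware CPU kernel instead of the dense Eigen contraction.
a = np.random.rand(1024, 1024).astype(np.float32)
a[a < 0.9] = 0.0
b = np.random.rand(1024, 1024).astype(np.float32)

with tf.Session() as sess:
  c = tf.matmul(tf.constant(a), tf.constant(b), a_is_sparse=True)
  print(sess.run(c).shape)  # (1024, 1024)
```
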
@ -20,6 +20,8 @@ void Sparsify(Tensor* t, float sparsity) {
for (int64 i = 0; i < N; ++i) {
if (rnd.Uniform(K) < sparsity * K) {
flat(i) = 0;
} else if (flat(i) == 0) {
flat(i) = 0.1;
}
}
}
@ -86,19 +88,29 @@ static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_a,
return g;
}

#define BM_SPARSE(M, K, N, S) \
static void BM_Sparse##_##M##_##K##_##N##_##S(int iters) { \
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
std::string label = strings::Printf("%d_%d_%d_%0.2f", M, K, N, S / 100.0); \
testing::SetLabel(label); \
test::Benchmark("cpu", SparseMatMul(M, N, K, S / 100.0, false, false)) \
.Run(iters); \
} \
#define BM_SPARSE(M, K, N, S) \
static void BM_Sparse##_##M##_##K##_##N##_##S(int iters) { \
testing::StopTiming(); \
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
std::string label; \
if (S == 0) { \
label = strings::Printf("%d*%d*%d_Eigen", M, K, N); \
} else { \
label = strings::Printf("%d*%d*%d_sparsity:%0.2f", M, K, N, S / 100.0); \
} \
testing::SetLabel(label); \
testing::UseRealTime(); \
auto g = SparseMatMul(M, N, K, S / 100.0, false, false); \
testing::StartTiming(); \
test::Benchmark("cpu", g).Run(iters); \
} \
BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S);

BM_SPARSE(2048, 2048, 2048, 0);
BM_SPARSE(2048, 2048, 2048, 1);
BM_SPARSE(2048, 2048, 2048, 50);
BM_SPARSE(2048, 2048, 2048, 85);
BM_SPARSE(2048, 2048, 2048, 99);

BM_SPARSE(1024, 1024, 1024, 0);
BM_SPARSE(1024, 1024, 1024, 1);
@ -107,28 +119,34 @@ BM_SPARSE(1024, 1024, 1024, 85);
BM_SPARSE(256, 256, 256, 1);
BM_SPARSE(512, 512, 512, 1);

#define BM_SPARSE_MULTI(M, K, N, S1, S2) \
static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2(int iters) { \
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2 * 3); \
std::string label = strings::Printf("%d_%d_%d_%0.2f_%0.2f", M, K, N, \
S1 / 100.0, S2 / 100.0); \
testing::SetLabel(label); \
test::Benchmark("cpu", MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0)) \
.Run(iters); \
} \
#define BM_SPARSE_MULTI(M, K, N, S1, S2) \
static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2(int iters) { \
testing::StopTiming(); \
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2 * 3); \
std::string label = strings::Printf("%d_%d_%d_%0.2f_%0.2f", M, K, N, \
S1 / 100.0, S2 / 100.0); \
testing::SetLabel(label); \
testing::UseRealTime(); \
auto g = MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0); \
testing::StartTiming(); \
test::Benchmark("cpu", g).Run(iters); \
} \
BENCHMARK(BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2);

BM_SPARSE_MULTI(512, 2140, 4096, 0, 82);
BM_SPARSE_MULTI(512, 4096, 2048, 83, 83);
BM_SPARSE_MULTI(1024, 2140, 4096, 0, 82);
BM_SPARSE_MULTI(1024, 4096, 2048, 83, 83);

#define BM_SPARSE_TR(M, K, N, S, TA, TB) \
static void BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB(int iters) { \
testing::StopTiming(); \
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
std::string label = \
strings::Printf("%d_%d_%d_%d_%d_%0.2f", M, K, N, TA, TB, S / 100.0); \
testing::SetLabel(label); \
test::Benchmark("cpu", SparseMatMul(M, N, K, S / 100.0, TA, TB)) \
.Run(iters); \
testing::UseRealTime(); \
auto g = SparseMatMul(M, N, K, S / 100.0, TA, TB); \
testing::StartTiming(); \
test::Benchmark("cpu", g).Run(iters); \
} \
BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB);

@ -25,7 +25,7 @@ class StringToHashBucketOp : public OpKernel {
&output_tensor));
auto output_flat = output_tensor->flat<int64>();

for (std::size_t i = 0; i < input_flat.size(); ++i) {
for (size_t i = 0; i < input_flat.size(); ++i) {
const uint64 input_hash = Hash64(input_flat(i));
const uint64 bucket_id = input_hash % num_buckets_;
// The number of buckets is always in the positive range of int64 so is
@ -17,6 +17,12 @@ bool StringToValue<int32>(const string& content, int* value) {
return str_util::NumericParse32(content, value);
}

template <>
bool StringToValue<string>(const string& content, string* value) {
*value = content;
return true;
}

// Parse a single argument by linearly searching through the command table.
// The input format is: --argument=value.
// Return OK if the argument is used. It stores the extracted value into the
@ -27,7 +33,7 @@ bool StringToValue<int32>(const string& content, int* value) {
template <typename T>
Status ParseArgument(const string& argument) {
for (auto& command :
internal::CommandLineFlagRegistry<int>::Instance()->commands) {
internal::CommandLineFlagRegistry<T>::Instance()->commands) {
string prefix = strings::StrCat("--", command.name, "=");
if (tensorflow::StringPiece(argument).starts_with(prefix)) {
string content = argument.substr(prefix.length());
@ -62,6 +68,7 @@ Status ParseArgument<bool>(const string& argument) {
return Status(error::NOT_FOUND,
strings::StrCat("Unknown command: ", argument));
}

}  // namespace

Status ParseCommandLineFlags(int* argc, char* argv[]) {
@ -81,6 +88,11 @@ Status ParseCommandLineFlags(int* argc, char* argv[]) {
if (s.ok()) {
continue;
}
// Search string commands.
s = ParseArgument<string>(argv[index]);
if (s.ok()) {
continue;
}
if (s.code() != error::NOT_FOUND) {
return s;
}
@ -43,11 +43,14 @@ struct CommandLineFlagRegister {
}  // namespace internal

#define TF_DEFINE_int32(name, default_value, text) \
TF_DEFINE_variable(int32, name, default_value, text);
TF_DEFINE_variable(tensorflow::int32, name, default_value, text);

#define TF_DEFINE_bool(name, default_value, text) \
TF_DEFINE_variable(bool, name, default_value, text);

#define TF_DEFINE_string(name, default_value, text) \
TF_DEFINE_variable(string, name, default_value, text);

// Parse argv[1]..argv[*argc-1] to options. Remove used arguments from the argv.
// Returns the number of unused arguments in *argc.
// Return error Status if the parsing encounters errors.
@ -859,10 +859,10 @@ REGISTER_OP("ListDiff")
.Output("idx: int32")
.Attr("T: type")
.Doc(R"doc(
Computes the difference between two lists of numbers.
Computes the difference between two lists of numbers or strings.

Given a list `x` and a list `y`, this operation returns a list `out` that
represents all numbers that are in `x` but not in `y`. The returned list `out`
represents all values that are in `x` but not in `y`. The returned list `out`
is sorted in the same order that the numbers appear in `x` (duplicates are
preserved). This operation also returns a list `idx` that represents the
position of each `out` element in `x`. In other words:
@ -35,7 +35,14 @@ REGISTER_OP("MatrixInverse")
.Output("output: T")
.Attr("T: {float, double}")
.Doc(R"doc(
Calculates the inverse of a square invertible matrix. Checks for invertibility.
Calculates the inverse of a square invertible matrix.

The op uses the Cholesky decomposition if the matrix is symmetric positive
definite and LU decomposition with partial pivoting otherwise.

If the matrix is not invertible there is no guarantee what the op does. It
may detect the condition and raise an exception or it may simply return a
garbage result.

input: Shape is `[M, M]`.
output: Shape is `[M, M]` containing the matrix inverse of the input.
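
A short usage sketch (assuming the Python wrappers `tf.matrix_inverse` and `tf.batch_matrix_inverse` expose these ops):

```python
import tensorflow as tf

x = tf.constant([[2.0, 0.0],
                 [0.0, 4.0]])
with tf.Session() as sess:
  # Inverse of a diagonal matrix is the reciprocal of its diagonal.
  print(sess.run(tf.matrix_inverse(x)))  # [[0.5, 0.], [0., 0.25]]
```
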
@ -47,12 +54,19 @@ REGISTER_OP("BatchMatrixInverse")
.Output("output: T")
.Attr("T: {float, double}")
.Doc(R"doc(
Calculates the inverse of square invertible matrices. Checks for invertibility.
Calculates the inverse of square invertible matrices.

The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices. The output is a tensor of the same shape as the input
containing the inverse for all input submatrices `[..., :, :]`.

The op uses the Cholesky decomposition if the matrices are symmetric positive
definite and LU decomposition with partial pivoting otherwise.

If a matrix is not invertible there is no guarantee what the op does. It
may detect the condition and raise an exception or it may simply return a
garbage result.

input: Shape is `[..., M, M]`.
output: Shape is `[..., M, M]`.
T: The type of values in the input and output.
@ -186,8 +186,8 @@ extern void TF_SetTarget(TF_SessionOptions* options, const char* target);
// config should be a serialized brain.ConfigProto proto.
// If config was not parsed successfully as a ConfigProto, record the
// error information in *status.
extern void TF_SetConfig(TF_SessionOptions* options, const char* config,
size_t config_len, TF_Status* status);
extern void TF_SetConfig(TF_SessionOptions* options, const void* proto,
size_t proto_len, TF_Status* status);

// Destroy an options object.
extern void TF_DeleteSessionOptions(TF_SessionOptions*);
@ -10,11 +10,11 @@
#include <fstream>
#include <sstream>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message_lite.h"
#include "tensorflow/core/platform/logging.h"
#include "google/protobuf/src/google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/src/google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/src/google/protobuf/io/coded_stream.h"
#include "google/protobuf/src/google/protobuf/message_lite.h"

static const char* const ASSET_PREFIX = "file:///android_asset/";

tensorflow/examples/label_image/BUILD
Normal file
@ -0,0 +1,30 @@
# Description:
#   Tensorflow C++ inference example for labeling images.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

cc_binary(
    name = "label_image",
    srcs = ["main.cc"],
    linkopts = ["-lm"],
    deps = [
        "//tensorflow/cc:cc_ops",
        "//tensorflow/core:tensorflow",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
            "bin/**",
            "gen/**",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/examples/label_image/README.md
Normal file
@ -0,0 +1,49 @@
# Tensorflow C++ Image Recognition Demo

This example shows how you can load a pre-trained TensorFlow network and use it
to recognize objects in images.

## Description

This demo uses a Google Inception model to classify image files that are passed
in on the command line. See
[`googlenet_labels.txt`](data/googlenet_labels.txt)
for the possible classifications, which are the 1,000 categories used in the
Imagenet competition.

## To build/install/run

As long as you've managed to build the main TensorFlow framework, you should
have everything you need to run this example installed already.

To build it, run this command:

```bash
$ bazel build tensorflow/examples/label_image/...
```

That should build a binary executable that you can then run like this:

```bash
$ bazel-bin/tensorflow/examples/label_image/label_image
```

This uses the default example image that ships with the framework, and should
output something similar to this:

```
I tensorflow/examples/label_image/main.cc:200] military uniform (866): 0.902268
I tensorflow/examples/label_image/main.cc:200] bow tie (817): 0.05407
I tensorflow/examples/label_image/main.cc:200] suit (794): 0.0113195
I tensorflow/examples/label_image/main.cc:200] bulletproof vest (833): 0.0100269
I tensorflow/examples/label_image/main.cc:200] bearskin (849): 0.00649746
```
In this case, we're using the default image of Admiral Grace Hopper, and you can
see that the network correctly spots she's wearing a military uniform, with a high
score of 0.9.

Next, try it out on your own images by supplying the --image= argument, e.g.

```bash
$ bazel-bin/tensorflow/examples/label_image/label_image --image=my_image.png
```
tensorflow/examples/label_image/data/googlenet_labels.txt
Normal file (diff suppressed because it is too large)
tensorflow/examples/label_image/data/grace_hopper.jpg
Normal file (binary file not shown; 60 KiB)
tensorflow/examples/label_image/main.cc
Normal file
@ -0,0 +1,295 @@
// A minimal but useful C++ example showing how to load an Imagenet-style object
// recognition TensorFlow model, prepare input images for it, run them through
// the graph, and interpret the results.
//
// It's designed to have as few dependencies and be as clear as possible, so
// it's more verbose than it could be in production code. In particular, using
// auto for the types of a lot of the returned values from TensorFlow calls can
// remove a lot of boilerplate, but I find explicit types useful in sample
// code to make it simple to look up the classes involved.
//
// To use it, compile and then run in a working directory with the
// learning/brain/tutorials/label_image/data/ folder below it, and you should
// see the top five labels for the example Lena image output. You can then
// customize it to use your own models or images by changing the file names at
// the top of the main() function.
//
// The googlenet_graph.pb file included by default is created from Inception.

#include <fstream>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/command_line_flags.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/tensor.h"

// These are all common classes it's handy to reference with no namespace.
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;

// These are the command-line flags the program can understand.
// They define where the graph and input data is located, and what kind of
// input the model expects. If you train your own model, or use something
// other than GoogLeNet you'll need to update these.
TF_DEFINE_string(image,
"tensorflow/examples/label_image/data/grace_hopper.jpg",
"The image to classify (JPEG or PNG).");
TF_DEFINE_string(graph,
"tensorflow/examples/label_image/data/googlenet_graph.pb",
"The location of the GraphDef file containing the protobuf"
" definition of the network.");
TF_DEFINE_string(labels,
"tensorflow/examples/label_image/data/googlenet_labels.txt",
"A text file containing the labels of all the categories, one"
" per line.");
TF_DEFINE_int32(input_width, 224, "Width of the image the network expects.");
TF_DEFINE_int32(input_height, 224, "Height of the image the network expects.");
TF_DEFINE_int32(input_mean, 117, "How much to subtract from input values.");
TF_DEFINE_int32(input_std, 1, "What to divide the input values by.");
TF_DEFINE_string(input_layer, "input", "The name of the input node.");
TF_DEFINE_string(output_layer, "softmax2", "The name of the output node.");
TF_DEFINE_bool(self_test, false, "Whether to run a sanity check on the results.");
TF_DEFINE_string(root_dir, "", "The directory at the root of the data files.");

// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
Status ReadLabelsFile(string file_name, std::vector<string>* result) {
std::ifstream file(file_name);
result->clear();
string line;
while (std::getline(file, line)) {
result->push_back(line);
}
const int padding = 16;
while (result->size() % padding) {
result->emplace_back();
}
return Status::OK();
}

// Given an image file name, read in the data, try to decode it as an image,
// resize it to the requested size, and then scale the values as desired.
Status ReadTensorFromImageFile(string file_name, const int input_height,
const int input_width, const float input_mean,
const float input_std,
std::vector<Tensor>* out_tensors) {
tensorflow::GraphDefBuilder b;
string input_name = "file_reader";
string output_name = "normalized";
tensorflow::Node* file_reader =
tensorflow::ops::ReadFile(tensorflow::ops::Const(file_name, b.opts()),
b.opts().WithName(input_name));
// Now try to figure out what kind of file it is and decode it.
const int wanted_channels = 3;
tensorflow::Node* image_reader;
if (tensorflow::StringPiece(file_name).ends_with(".png")) {
image_reader = tensorflow::ops::DecodePng(
file_reader,
b.opts().WithAttr("channels", wanted_channels).WithName("png_reader"));
} else {
// Assume if it's not a PNG then it must be a JPEG.
image_reader = tensorflow::ops::DecodeJpeg(
file_reader,
b.opts().WithAttr("channels", wanted_channels).WithName("jpeg_reader"));
}
// Now cast the image data to float so we can do normal math on it.
tensorflow::Node* float_caster = tensorflow::ops::Cast(
image_reader, tensorflow::DT_FLOAT, b.opts().WithName("float_caster"));
// The convention for image ops in TensorFlow is that all images are expected
// to be in batches, so that they're four-dimensional arrays with indices of
// [batch, height, width, channel]. Because we only have a single image, we
// have to add a batch dimension of 1 to the start with ExpandDims().
tensorflow::Node* dims_expander = tensorflow::ops::ExpandDims(
float_caster, tensorflow::ops::Const(0, b.opts()), b.opts());
// Bilinearly resize the image to fit the required dimensions.
tensorflow::Node* resized = tensorflow::ops::ResizeBilinear(
dims_expander, tensorflow::ops::Const({input_height, input_width},
b.opts().WithName("size")),
b.opts());
// Subtract the mean and divide by the scale.
tensorflow::ops::Div(
tensorflow::ops::Sub(
resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()),
tensorflow::ops::Const({input_std}, b.opts()),
b.opts().WithName(output_name));

// This runs the GraphDef network definition that we've just constructed, and
// returns the results in the output tensor.
tensorflow::GraphDef graph;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
std::unique_ptr<tensorflow::Session> session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_RETURN_IF_ERROR(session->Create(graph));
TF_RETURN_IF_ERROR(session->Run({}, {output_name}, {}, out_tensors));
return Status::OK();
}

// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(string graph_file_name,
std::unique_ptr<tensorflow::Session>* session) {
tensorflow::GraphDef graph_def;
Status load_graph_status =
ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
if (!load_graph_status.ok()) {
return tensorflow::errors::NotFound("Failed to load compute graph at '",
graph_file_name, "'");
}

session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
Status session_create_status = (*session)->Create(graph_def);
if (!session_create_status.ok()) {
return session_create_status;
}
return Status::OK();
}

// Analyzes the output of the Inception graph to retrieve the highest scores and
// their positions in the tensor, which correspond to categories.
Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
Tensor* indices, Tensor* scores) {
tensorflow::GraphDefBuilder b;
string output_name = "top_k";
tensorflow::ops::TopK(tensorflow::ops::Const(outputs[0], b.opts()),
how_many_labels, b.opts().WithName(output_name));
// This runs the GraphDef network definition that we've just constructed, and
// returns the results in the output tensors.
tensorflow::GraphDef graph;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
std::unique_ptr<tensorflow::Session> session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_RETURN_IF_ERROR(session->Create(graph));
// The TopK node returns two outputs, the scores and their original indices,
// so we have to append :0 and :1 to specify them both.
std::vector<Tensor> out_tensors;
TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
{}, &out_tensors));
*scores = out_tensors[0];
*indices = out_tensors[1];
return Status::OK();
}

// Given the output of a model run, and the name of a file containing the labels,
// this prints out the top five highest-scoring values.
Status PrintTopLabels(const std::vector<Tensor>& outputs,
string labels_file_name) {
std::vector<string> labels;
Status read_labels_status = ReadLabelsFile(labels_file_name, &labels);
if (!read_labels_status.ok()) {
LOG(ERROR) << read_labels_status;
return read_labels_status;
}
const int how_many_labels = 5;
Tensor indices;
Tensor scores;
TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
for (int pos = 0; pos < how_many_labels; ++pos) {
const int label_index = indices_flat(pos);
const float score = scores_flat(pos);
LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score;
}
return Status::OK();
}

// This is a testing function that returns whether the top label index is the
// one that's expected.
Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
bool* is_expected) {
*is_expected = false;
Tensor indices;
Tensor scores;
const int how_many_labels = 1;
TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
if (indices_flat(0) != expected) {
LOG(ERROR) << "Expected label #" << expected << " but got #"
<< indices_flat(0);
*is_expected = false;
} else {
*is_expected = true;
}
return Status::OK();
}

int main(int argc, char* argv[]) {
// We need to call this to set up global state for TensorFlow.
tensorflow::port::InitMain(argv[0], &argc, &argv);
Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
if (!s.ok()) {
LOG(ERROR) << "Error parsing command line flags: " << s.ToString();
return -1;
}

// First we load and initialize the model.
std::unique_ptr<tensorflow::Session> session;
string graph_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_graph);
Status load_graph_status = LoadGraph(graph_path, &session);
if (!load_graph_status.ok()) {
LOG(ERROR) << load_graph_status;
return -1;
}

// Get the image from disk as a float array of numbers, resized and normalized
// to the specifications the main graph expects.
std::vector<Tensor> resized_tensors;
string image_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_image);
Status read_tensor_status = ReadTensorFromImageFile(
image_path, FLAGS_input_height, FLAGS_input_width, FLAGS_input_mean,
FLAGS_input_std, &resized_tensors);
if (!read_tensor_status.ok()) {
LOG(ERROR) << read_tensor_status;
return -1;
}
const Tensor& resized_tensor = resized_tensors[0];

// Actually run the image through the model.
std::vector<Tensor> outputs;
Status run_status = session->Run({{FLAGS_input_layer, resized_tensor}},
{FLAGS_output_layer}, {}, &outputs);
if (!run_status.ok()) {
LOG(ERROR) << "Running model failed: " << run_status;
return -1;
}

// This is for automated testing to make sure we get the expected result with
// the default settings. We know that label 866 (military uniform) should be
// the top label for the Admiral Hopper image.
if (FLAGS_self_test) {
bool expected_matches;
Status check_status = CheckTopLabel(outputs, 866, &expected_matches);
if (!check_status.ok()) {
LOG(ERROR) << "Running check failed: " << check_status;
return -1;
}
if (!expected_matches) {
LOG(ERROR) << "Self-test failed!";
return -1;
}
}

// Do something interesting with the results we've generated.
Status print_status = PrintTopLabels(outputs, FLAGS_labels);
if (!print_status.ok()) {
LOG(ERROR) << "Running print failed: " << print_status;
return -1;
}

return 0;
}
@ -14,3 +14,5 @@ Note: Many practical aspects of usage are covered in the Mechanics tab, and
some additional documentation not specific to any particular language API is
available in the Resources tab.

* [Python API](python/index.md)
* [C++ API](cc/index.md)
@ -245,31 +245,54 @@ strided according to the `strides` argument. `strides = [1, 1, 1, 1]` applies
the filter to a patch at every offset, `strides = [1, 2, 2, 1]` applies the
filter to every other image patch in each dimension, etc.

Ignoring channels for the moment, the spatial semantics of the convolution ops
are as follows. If the 4-D `input` has shape
Ignoring channels for the moment, assume that the 4-D `input` has shape
`[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
`[filter_height, filter_width, ...]`, then
`[filter_height, filter_width, ...]`, then the spatial semantics of the
convolution ops are as follows: first, according to the padding scheme chosen
as `'SAME'` or `'VALID'`, the output size and the padding pixels are computed.
For the `'SAME'` padding, the output height and width are computed as:

shape(output) = [batch,
(in_height - filter_height + 1) / strides[1],
(in_width - filter_width + 1) / strides[2],
...]
out_height = ceil(float(in_height) / float(strides[1]))
out_width = ceil(float(in_width) / float(strides[2]))

and the padding on the top and left are computed as:

pad_along_height = ((out_height - 1) * strides[1] +
filter_height - in_height)
pad_along_width = ((out_width - 1) * strides[2] +
filter_width - in_width)
pad_top = pad_along_height / 2
pad_left = pad_along_width / 2

Note that the division by 2 means that there might be cases when the padding on
the two sides (top vs. bottom, right vs. left) differs by one. In this case, the
bottom and right sides always get the one additional padded pixel. For example,
when `pad_along_height` is 5, we pad 2 pixels at the top and 3 pixels at the
bottom. Note that this is different from existing libraries such as cuDNN and
Caffe, which explicitly specify the number of padded pixels and always pad the
same number of pixels on both sides.

For the `'VALID'` padding, the output height and width are computed as:

out_height = ceil(float(in_height - filter_height + 1) / float(strides[1]))
out_width = ceil(float(in_width - filter_width + 1) / float(strides[2]))

and the padding values are always zero. The output is then computed as

output[b, i, j, :] =
sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
sum_{di, dj} input[b, strides[1] * i + di - pad_top,
strides[2] * j + dj - pad_left, ...] *
filter[di, dj, ...]

where any values outside the original input image region are considered zero
(i.e. we pad zero values around the border of the image).
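
To make the arithmetic above concrete, here is a small standalone sketch (plain Python mirroring the formulas just given, not any particular TensorFlow helper) that computes the `'SAME'` output size and padding:

```python
import math

def same_padding(in_height, in_width, filter_height, filter_width, strides):
  # Output size for 'SAME' padding.
  out_height = int(math.ceil(float(in_height) / float(strides[1])))
  out_width = int(math.ceil(float(in_width) / float(strides[2])))
  # Total padding needed along each spatial dimension.
  pad_along_height = (out_height - 1) * strides[1] + filter_height - in_height
  pad_along_width = (out_width - 1) * strides[2] + filter_width - in_width
  # The extra pixel (if any) goes to the bottom/right.
  pad_top = pad_along_height // 2
  pad_left = pad_along_width // 2
  return out_height, out_width, pad_top, pad_left

# A 7x7 input with a 3x3 filter and stride 2 gives a 4x4 output,
# with 1 pixel of padding at the top and left (and 1 at the bottom and right).
print(same_padding(7, 7, 3, 3, [1, 2, 2, 1]))  # (4, 4, 1, 1)
```
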

Since `input` is 4-D, each `input[b, i, j, :]` is a vector. For `conv2d`, these
vectors are multiplied by the `filter[di, dj, :, :]` matrices to produce new
vectors. For `depthwise_conv_2d`, each scalar component `input[b, i, j, k]`
is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
concatenated.

In the formula for `shape(output)`, the rounding direction depends on padding:

* `padding = 'SAME'`: Round down (only full size windows are considered).
* `padding = 'VALID'`: Round up (partial windows are included).

- - -

### `tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)` <a class="md-anchor" id="conv2d"></a>
@ -412,14 +435,8 @@ In detail, the output is

output[i] = reduce(value[strides * i:strides * i + ksize])

for each tuple of indices `i`. The output shape is

shape(output) = (shape(value) - ksize + 1) / strides

where the rounding direction depends on padding:

* `padding = 'SAME'`: Round down (only full size windows are considered).
* `padding = 'VALID'`: Round up (partial windows are included).
where the indices also take into consideration the padding values. Please refer
to the `Convolution` section for details about the padding calculation.

- - -

@ -136,7 +136,7 @@ that the evidence for a class \\(i\\) given an input \\(x\\) is:

$$\text{evidence}_i = \sum_j W_{i,~ j} x_j + b_i$$

where \\(W\_i\\) is the weights and \\(b\_i\\) is the bias for class \\(i\\),
where \\(W_i\\) is the weights and \\(b_i\\) is the bias for class \\(i\\),
and \\(j\\) is an index for summing over the pixels in our input image \\(x\\).
We then convert the evidence tallies into our predicted probabilities
\\(y\\) using the "softmax" function:
|
||||
"kernel_tests/save_restore_ops_test.py",
|
||||
"kernel_tests/segment_reduction_ops_test.py",
|
||||
"kernel_tests/sparse_concat_op_test.py",
|
||||
"kernel_tests/sparse_matmul_op_test.py",
|
||||
"kernel_tests/sparse_reorder_op_test.py",
|
||||
"kernel_tests/sparse_to_dense_op_test.py",
|
||||
"kernel_tests/sparsemask_op_test.py",
|
||||
|
@ -333,6 +333,9 @@ class BaseSession(SessionInterface):
|
||||
# Check session.
|
||||
if self._closed:
|
||||
raise RuntimeError('Attempted to use a closed Session.')
|
||||
if self.graph.version == 0:
|
||||
raise RuntimeError('The Session graph is empty. Add operations to the '
|
||||
'graph before calling run().')
|
||||
|
||||
# Validate and process fetches.
|
||||
is_list_fetch = isinstance(fetches, (list, tuple))
|
||||
|
@ -445,6 +445,12 @@ class SessionTest(test_util.TensorFlowTestCase):
sess.close()
t.join()

def testUseEmptyGraph(self):
with session.Session() as sess:
with self.assertRaisesWithPredicateMatch(
RuntimeError, lambda e: 'The Session graph is empty.' in str(e)):
sess.run([])

def testNotEntered(self):
# pylint: disable=protected-access
self.assertEqual(ops._default_session_stack.get_default(), None)
@ -214,7 +214,7 @@ import_array();
"but got %s" % type(config))
status = TF_NewStatus()
config_str = config.SerializeToString()
_TF_SetConfig(opts, config_str, len(config_str), status)
_TF_SetConfig(opts, config_str, status)
if TF_GetCode(status) != 0:
raise ValueError(TF_Message(status))
return opts
@ -167,7 +167,14 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
"""
# Type checks for inputs.
if not isinstance(graph_def, graph_pb2.GraphDef):
raise TypeError('graph_def must be a GraphDef proto.')
# `graph_def` could be a dynamically-created message, so try a duck-typed
# approach
try:
old_graph_def = graph_def
graph_def = graph_pb2.GraphDef()
graph_def.MergeFrom(old_graph_def)
except TypeError:
raise TypeError('graph_def must be a GraphDef proto.')
if input_map is None:
input_map = {}
else:
@ -2458,6 +2458,7 @@ class Graph(object):
control_ops = []
current = self._current_control_dependencies()
for c in control_inputs:
c = self.as_graph_element(c)
if isinstance(c, Tensor):
c = c.op
elif not isinstance(c, Operation):
@ -632,6 +632,20 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
# e should be dominated by c.
self.assertEqual(e.op.control_inputs, [])

def testBasicWithConversion(self):
g = ops.Graph()
a = _apply_op(g, "const", [], [types.float32])

class ConvertibleObj(object):

def _as_graph_element(self):
return a

with g.control_dependencies([ConvertibleObj()]):
c = _apply_op(g, "const", [], [types.float32])

self.assertEqual(c.op.control_inputs, [a.op])

def testNested(self):
g = ops.Graph()
a_1 = _apply_op(g, "const", [], [types.float32])
@ -183,6 +183,13 @@ class Conv2DTest(tf.test.TestCase):
stride=1, padding="VALID",
expected=expected_output)

def testConv2DEmpty(self):
expected_output = []
self._VerifyValues(tensor_in_sizes=[0, 2, 3, 3],
filter_in_sizes=[1, 1, 3, 3],
stride=1, padding="VALID",
expected=expected_output)

def testConv2D2x2Filter(self):
# The outputs are computed using third_party/py/IPython/notebook.
expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0]
@ -1008,4 +1015,5 @@ if __name__ == "__main__":
setattr(Conv2DTest, "testInceptionBackFilter_" + str(index),
GetInceptionBackFilterTest(input_size_, filter_size_, output_size_,
stride_, padding_))

tf.test.main()
@ -10,57 +10,56 @@ import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
_TYPES = [tf.int32, tf.int64, tf.float32, tf.float64, tf.string]
|
||||
|
||||
|
||||
class ListDiffTest(tf.test.TestCase):
|
||||
|
||||
def _testListDiff(self, x, y, out, idx, dtype=np.int32):
|
||||
x = np.array(x, dtype=dtype)
|
||||
y = np.array(y, dtype=dtype)
|
||||
out = np.array(out, dtype=dtype)
|
||||
idx = np.array(idx, dtype=dtype)
|
||||
def _testListDiff(self, x, y, out, idx):
|
||||
for dtype in _TYPES:
|
||||
if dtype == tf.string:
|
||||
x = [str(a) for a in x]
|
||||
y = [str(a) for a in y]
|
||||
out = [str(a) for a in out]
|
||||
|
||||
with self.test_session() as sess:
|
||||
x_tensor = tf.convert_to_tensor(x)
|
||||
y_tensor = tf.convert_to_tensor(y)
|
||||
out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
|
||||
tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
|
||||
with self.test_session() as sess:
|
||||
x_tensor = tf.convert_to_tensor(x, dtype=dtype)
|
||||
y_tensor = tf.convert_to_tensor(y, dtype=dtype)
|
||||
out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
|
||||
tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
|
||||
|
||||
self.assertAllEqual(tf_out, out)
|
||||
self.assertAllEqual(tf_idx, idx)
|
||||
self.assertEqual(1, out_tensor.get_shape().ndims)
|
||||
self.assertEqual(1, idx_tensor.get_shape().ndims)
|
||||
self.assertAllEqual(tf_out, out)
|
||||
self.assertAllEqual(tf_idx, idx)
|
||||
self.assertEqual(1, out_tensor.get_shape().ndims)
|
||||
self.assertEqual(1, idx_tensor.get_shape().ndims)
|
||||
|
||||
def testBasic1(self):
|
||||
x = [1, 2, 3, 4]
|
||||
y = [1, 2]
|
||||
out = [3, 4]
|
||||
idx = [2, 3]
|
||||
for t in [np.int32, np.int64, np.float, np.double]:
|
||||
self._testListDiff(x, y, out, idx, dtype=t)
|
||||
self._testListDiff(x, y, out, idx)
|
||||
|
||||
def testBasic2(self):
|
||||
x = [1, 2, 3, 4]
|
||||
y = [2]
|
||||
out = [1, 3, 4]
|
||||
idx = [0, 2, 3]
|
||||
for t in [np.int32, np.int64, np.float, np.double]:
|
||||
self._testListDiff(x, y, out, idx, dtype=t)
|
||||
self._testListDiff(x, y, out, idx)
|
||||
|
||||
def testBasic3(self):
|
||||
x = [1, 4, 3, 2]
|
||||
y = [4, 2]
|
||||
out = [1, 3]
|
||||
idx = [0, 2]
|
||||
for t in [np.int32, np.int64, np.float, np.double]:
|
||||
self._testListDiff(x, y, out, idx, dtype=t)
|
||||
self._testListDiff(x, y, out, idx)
|
||||
|
||||
def testDuplicates(self):
|
||||
x = [1, 2, 4, 3, 2, 3, 3, 1]
|
||||
y = [4, 2]
|
||||
out = [1, 3, 3, 3, 1]
|
||||
idx = [0, 3, 5, 6, 7]
|
||||
for t in [np.int32, np.int64, np.float, np.double]:
|
||||
self._testListDiff(x, y, out, idx, dtype=t)
|
||||
self._testListDiff(x, y, out, idx)
|
||||
|
||||
def testRandom(self):
|
||||
num_random_tests = 10
|
||||
@ -78,38 +77,37 @@ class ListDiffTest(tf.test.TestCase):
|
||||
else:
|
||||
out = []
|
||||
idx = []
|
||||
for t in [np.int32, np.int64, np.float, np.double]:
|
||||
self._testListDiff(x, y, out, idx, dtype=t)
|
||||
self._testListDiff(list(x), list(y), out, idx)
|
||||
|
||||
def testInt32FullyOverlapping(self):
|
||||
def testFullyOverlapping(self):
|
||||
x = [1, 2, 3, 4]
|
||||
y = [1, 2, 3, 4]
|
||||
out = []
|
||||
idx = []
|
||||
self._testListDiff(x, y, out, idx)
|
||||
|
||||
def testInt32NonOverlapping(self):
|
||||
def testNonOverlapping(self):
|
||||
x = [1, 2, 3, 4]
|
||||
y = [5, 6]
|
||||
out = x
|
||||
idx = np.arange(len(x))
|
||||
self._testListDiff(x, y, out, idx)
|
||||
|
||||
def testInt32EmptyX(self):
|
||||
def testEmptyX(self):
|
||||
x = []
|
||||
y = [1, 2]
|
||||
out = []
|
||||
idx = []
|
||||
self._testListDiff(x, y, out, idx)
|
||||
|
||||
def testInt32EmptyY(self):
|
||||
def testEmptyY(self):
|
||||
x = [1, 2, 3, 4]
|
||||
y = []
|
||||
out = x
|
||||
idx = np.arange(len(x))
|
||||
self._testListDiff(x, y, out, idx)
|
||||
|
||||
def testInt32EmptyXY(self):
|
||||
def testEmptyXY(self):
|
||||
x = []
|
||||
y = []
|
||||
out = []
|
||||
|
@ -30,7 +30,7 @@ class InverseOpTest(tf.test.TestCase):
|
||||
self.assertAllClose(np_ans, out)
|
||||
self.assertShapeEqual(y, tf_ans)
|
||||
|
||||
def testBasic(self):
|
||||
def testNonsymmetric(self):
|
||||
# 2x2 matrices
|
||||
matrix1 = np.array([[1., 2.], [3., 4.]])
|
||||
matrix2 = np.array([[1., 3.], [3., 5.]])
|
||||
@ -42,6 +42,18 @@ class InverseOpTest(tf.test.TestCase):
|
||||
matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
|
||||
self._verifyInverse(matrix_batch)
|
||||
|
||||
def testSymmetricPositiveDefinite(self):
|
||||
# 2x2 matrices
|
||||
matrix1 = np.array([[2., 1.], [1., 2.]])
|
||||
matrix2 = np.array([[3., -1.], [-1., 3.]])
|
||||
self._verifyInverse(matrix1)
|
||||
self._verifyInverse(matrix2)
|
||||
# A multidimensional batch of 2x2 matrices
|
||||
matrix_batch = np.concatenate([np.expand_dims(matrix1, 0), np.expand_dims(
|
||||
matrix2, 0)])
|
||||
matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
|
||||
self._verifyInverse(matrix_batch)
|
||||
|
||||
def testNonSquareMatrix(self):
|
||||
# When the inverse of a non-square matrix is attempted we should return
|
||||
# an error
|
||||
@ -58,22 +70,10 @@ class InverseOpTest(tf.test.TestCase):
|
||||
# The input should be invertible.
|
||||
with self.test_session():
|
||||
with self.assertRaisesOpError("Input is not invertible."):
|
||||
# All rows of the matrix below add to zero
|
||||
# All rows of the matrix below add to zero.
|
||||
tensor3 = tf.constant([[1., 0., -1.], [-1., 1., 0.], [0., -1., 1.]])
|
||||
tf.matrix_inverse(tensor3).eval()
|
||||
|
||||
with self.test_session():
|
||||
with self.assertRaisesOpError("Input is not invertible."):
|
||||
# Determinant of the matrix below is zero
|
||||
tensor3 = tf.constant([[1., 1.], [1., 1.]])
|
||||
tf.matrix_inverse(tensor3).eval()
|
||||
|
||||
with self.test_session():
|
||||
with self.assertRaisesOpError("Input is not invertible."):
|
||||
# Determinant of the matrix below is zero
|
||||
tensor3 = tf.constant([[np.inf, 1.], [1., 1.]])
|
||||
tf.matrix_inverse(tensor3).eval()
|
||||
|
||||
def testEmpty(self):
|
||||
self._verifyInverse(np.empty([0, 2, 2]))
|
||||
self._verifyInverse(np.empty([2, 0, 0]))
|
||||
|
@ -175,7 +175,8 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
with ops.op_scope(t_list + [clip_norm], name, "clip_by_global_norm") as name:
# Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
scale = clip_norm * math_ops.minimum(
1.0 / use_norm, constant_op.constant(1.0 / clip_norm))
1.0 / use_norm,
constant_op.constant(1.0 / clip_norm, dtype=use_norm.dtype))

values = [
ops.convert_to_tensor(
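A plain-number sketch of the scale computed above (values invented; no TensorFlow involved):

    clip_norm = 5.0
    use_norm = 20.0     # global L2-norm of all tensors in t_list
    scale = clip_norm * min(1.0 / use_norm, 1.0 / clip_norm)
    # scale == 0.25, so every tensor is multiplied by 0.25 and the global norm drops to 5.0.
    # When use_norm <= clip_norm, min() picks 1.0 / clip_norm and scale == 1.0 (no clipping);
    # the dtype= argument added above just keeps the constant in the same dtype as use_norm.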
@ -89,8 +89,8 @@ def bias_add_shape(op):
|
||||
return [output_shape]
|
||||
|
||||
|
||||
def _Get2DOutputSize(input_height, input_width, filter_height, filter_width,
|
||||
row_stride, col_stride, padding_type):
|
||||
def get2d_conv_output_size(input_height, input_width, filter_height,
|
||||
filter_width, row_stride, col_stride, padding_type):
|
||||
"""Returns the number of rows and columns in a convolution/pooling output."""
|
||||
input_height = tensor_shape.as_dimension(input_height)
|
||||
input_width = tensor_shape.as_dimension(input_width)
|
||||
@ -184,7 +184,7 @@ def conv2d_shape(op):
|
||||
# in the kernel implementation.
|
||||
stride = stride_r
|
||||
padding = op.get_attr("padding")
|
||||
out_rows, out_cols = _Get2DOutputSize(
|
||||
out_rows, out_cols = get2d_conv_output_size(
|
||||
in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding)
|
||||
|
||||
return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
|
||||
@ -246,7 +246,7 @@ def separable_conv2d_shape(op):
|
||||
# in the kernel implementation.
|
||||
stride = stride_r
|
||||
padding = op.get_attr("padding")
|
||||
out_rows, out_cols = _Get2DOutputSize(
|
||||
out_rows, out_cols = get2d_conv_output_size(
|
||||
in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding)
|
||||
|
||||
return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
|
||||
@ -294,7 +294,7 @@ def avg_pool_shape(op):
|
||||
# in the kernel implementation.
|
||||
padding = op.get_attr("padding")
|
||||
|
||||
out_rows, out_cols = _Get2DOutputSize(
|
||||
out_rows, out_cols = get2d_conv_output_size(
|
||||
in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
|
||||
|
||||
return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth])]
|
||||
@ -346,7 +346,7 @@ def max_pool_shape(op):
|
||||
# in the kernel implementation.
|
||||
if ksize_d == 1:
|
||||
padding = op.get_attr("padding")
|
||||
out_rows, out_cols = _Get2DOutputSize(
|
||||
out_rows, out_cols = get2d_conv_output_size(
|
||||
in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
|
||||
return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth])]
|
||||
else:
|
||||
|
@ -42,7 +42,8 @@ def Print(input_, data, message=None, first_n=None, summarize=None,
message: A string, prefix of the error message.
first_n: Only log `first_n` number of times. Negative numbers log always;
this is the default.
summarize: Only print this many entries of each tensor.
summarize: Only print this many entries of each tensor. If None, then a
maximum of 3 elements are printed per input tensor.
name: A name for the operation (optional).

Returns:
@ -38,31 +38,54 @@ strided according to the `strides` argument. `strides = [1, 1, 1, 1]` applies
the filter to a patch at every offset, `strides = [1, 2, 2, 1]` applies the
filter to every other image patch in each dimension, etc.

Ignoring channels for the moment, the spatial semantics of the convolution ops
are as follows. If the 4-D `input` has shape
Ignoring channels for the moment, and assume that the the 4-D `input` has shape
`[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
`[filter_height, filter_width, ...]`, then
`[filter_height, filter_width, ...]`, then the spatial semantics of the
convolution ops are as follows: first, according to the padding scheme chosen
as `'SAME'` or `'VALID'`, the output size and the padding pixels are computed.
For the `'SAME'` padding, the output height and width are computed as:

shape(output) = [batch,
(in_height - filter_height + 1) / strides[1],
(in_width - filter_width + 1) / strides[2],
...]
out_height = ceil(float(in_height) / float(strides[1]))
out_width = ceil(float(in_width) / float(stides[2]))

and the padding on the top and left are computed as:

pad_along_height = ((out_height - 1) * strides[1] +
filter_height - in_height)
pad_along_width = ((out_width - 1) * strides[2] +
filter_width - in_width)
pad_top = pad_along_height / 2
pad_left = pad_along_width / 2

Note that the division by 2 means that there might be cases when the padding on
both sides (top vs bottom, right vs left) are off by one. In this case, the
bottom and right sides always get the one additional padded pixel. For example,
when `pad_along_height` is 5, we pad 2 pixels at the top and 3 pixels at the
bottom. Note that this is different from existing libraries such as cuDNN and
Caffe, which explicitly specify the number of padded pixels and always pad the
same number of pixels on both sides.

For the `'VALID`' padding, the output height and width are computed as:

out_height = ceil(float(in_height - filter_height + 1) / float(strides[1]))
out_width = ceil(float(in_width - filter_width + 1) / float(stides[2]))

and the padding values are always zero. The output is then computed as

output[b, i, j, :] =
sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
sum_{di, dj} input[b, strides[1] * i + di - pad_top,
strides[2] * j + dj - pad_left, ...] *
filter[di, dj, ...]

where any value outside the original input image region are considered zero (
i.e. we pad zero values around the border of the image).

Since `input` is 4-D, each `input[b, i, j, :]` is a vector. For `conv2d`, these
vectors are multiplied by the `filter[di, dj, :, :]` matrices to produce new
vectors. For `depthwise_conv_2d`, each scalar component `input[b, i, j, k]`
is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
concatenated.

In the formula for `shape(output)`, the rounding direction depends on padding:

* `padding = 'SAME'`: Round down (only full size windows are considered).
* `padding = 'VALID'`: Round up (partial windows are included).

@@conv2d
@@depthwise_conv2d
@@separable_conv2d
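To make the arithmetic above concrete, here is a small plain-Python sketch (helper names are invented; the real shape inference lives in the Python shape functions, not here):

    import math

    def same_output_and_padding(in_height, in_width, filter_height, filter_width, strides):
      out_height = int(math.ceil(float(in_height) / float(strides[1])))
      out_width = int(math.ceil(float(in_width) / float(strides[2])))
      pad_along_height = (out_height - 1) * strides[1] + filter_height - in_height
      pad_along_width = (out_width - 1) * strides[2] + filter_width - in_width
      # Integer division: the bottom and right sides get the extra pixel when the total is odd.
      pad_top, pad_left = pad_along_height // 2, pad_along_width // 2
      return (out_height, out_width), (pad_top, pad_along_height - pad_top,
                                       pad_left, pad_along_width - pad_left)

    def valid_output_size(in_height, in_width, filter_height, filter_width, strides):
      out_height = int(math.ceil(float(in_height - filter_height + 1) / float(strides[1])))
      out_width = int(math.ceil(float(in_width - filter_width + 1) / float(strides[2])))
      return out_height, out_width

    # A 5x5 input with a 2x2 filter and strides = [1, 2, 2, 1]:
    # same_output_and_padding(5, 5, 2, 2, [1, 2, 2, 1]) -> ((3, 3), (0, 1, 0, 1))
    # valid_output_size(5, 5, 2, 2, [1, 2, 2, 1])       -> (2, 2)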
@ -79,14 +102,8 @@ In detail, the output is

output[i] = reduce(value[strides * i:strides * i + ksize])

for each tuple of indices `i`. The output shape is

shape(output) = (shape(value) - ksize + 1) / strides

where the rounding direction depends on padding:

* `padding = 'SAME'`: Round down (only full size windows are considered).
* `padding = 'VALID'`: Round up (partial windows are included).
where the indices also take into consideration the padding values. Please refer
to the `Convolution` section for details about the padding calculation.

@@avg_pool
@@max_pool
@ -143,6 +143,8 @@ class EventAccumulator(object):
self._is_autoupdating = False
self._activated = False
self._compression_bps = compression_bps
self.most_recent_step = -1
self.most_recent_wall_time = -1

def Reload(self):
"""Loads all events added since the last call to `Reload`.
@ -156,6 +158,31 @@ class EventAccumulator(object):
self._activated = True
with self._generator_mutex:
for event in self._generator.Load():
## Check if the event happened after a crash
if event.step < self.most_recent_step:

## Keep data in reservoirs that has a step less than event.step
_NotExpired = lambda x: x.step < event.step
num_expired_scalars = self._scalars.FilterItems(_NotExpired)
num_expired_histograms = self._histograms.FilterItems(_NotExpired)
num_expired_compressed_histograms = self._compressed_histograms.FilterItems(
_NotExpired)
num_expired_images = self._images.FilterItems(_NotExpired)

purge_msg = (
'Detected out of order event.step likely caused by a Tensorflow '
'restart. Purging expired events from Tensorboard display '
'between the previous step: {} (timestamp: {}) and current step:'
' {} (timestamp: {}). Removing {} scalars, {} histograms, {} '
'compressed histograms, and {} images.').format(
self.most_recent_step, self.most_recent_wall_time, event.step,
event.wall_time, num_expired_scalars, num_expired_histograms,
num_expired_compressed_histograms, num_expired_images)
logging.warn(purge_msg)
else:
self.most_recent_step = event.step
self.most_recent_wall_time = event.wall_time
## Process the event
if event.HasField('graph_def'):
if self._graph is not None:
logging.warn(('Found more than one graph event per run.'
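A stand-alone restatement of the restart-detection rule above (function name invented for illustration):

    def should_purge(most_recent_step, incoming_step):
      # A step lower than the last step seen implies the writer restarted and is
      # re-logging earlier steps, so reservoir items at or beyond the incoming
      # step are stale and get filtered out before the new event is processed.
      return incoming_step < most_recent_step

    assert should_purge(most_recent_step=300, incoming_step=101)
    assert not should_purge(most_recent_step=300, incoming_step=301)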
@ -102,8 +102,8 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
|
||||
|
||||
def testTags(self):
|
||||
gen = _EventGenerator()
|
||||
gen.AddScalar('sv1')
|
||||
gen.AddScalar('sv2')
|
||||
gen.AddScalar('s1')
|
||||
gen.AddScalar('s2')
|
||||
gen.AddHistogram('hst1')
|
||||
gen.AddHistogram('hst2')
|
||||
gen.AddImage('im1')
|
||||
@ -113,7 +113,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
|
||||
self.assertTagsEqual(
|
||||
acc.Tags(), {
|
||||
ea.IMAGES: ['im1', 'im2'],
|
||||
ea.SCALARS: ['sv1', 'sv2'],
|
||||
ea.SCALARS: ['s1', 's2'],
|
||||
ea.HISTOGRAMS: ['hst1', 'hst2'],
|
||||
ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
|
||||
ea.GRAPH: False})
|
||||
@ -123,8 +123,8 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
|
||||
acc = ea.EventAccumulator(gen)
|
||||
acc.Reload()
|
||||
self.assertEqual(acc.Tags(), self.empty)
|
||||
gen.AddScalar('sv1')
|
||||
gen.AddScalar('sv2')
|
||||
gen.AddScalar('s1')
|
||||
gen.AddScalar('s2')
|
||||
gen.AddHistogram('hst1')
|
||||
gen.AddHistogram('hst2')
|
||||
gen.AddImage('im1')
|
||||
@ -133,7 +133,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
|
||||
acc.Reload()
|
||||
self.assertTagsEqual(acc.Tags(), {
|
||||
ea.IMAGES: ['im1', 'im2'],
|
||||
ea.SCALARS: ['sv1', 'sv2'],
|
||||
ea.SCALARS: ['s1', 's2'],
|
||||
ea.HISTOGRAMS: ['hst1', 'hst2'],
|
||||
ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
|
||||
ea.GRAPH: False})
|
||||
@ -141,13 +141,13 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
|
||||
def testScalars(self):
|
||||
gen = _EventGenerator()
|
||||
acc = ea.EventAccumulator(gen)
|
||||
sv1 = ea.ScalarEvent(wall_time=1, step=10, value=32)
|
||||
sv2 = ea.ScalarEvent(wall_time=2, step=12, value=64)
|
||||
gen.AddScalar('sv1', wall_time=1, step=10, value=32)
|
||||
gen.AddScalar('sv2', wall_time=2, step=12, value=64)
|
||||
s1 = ea.ScalarEvent(wall_time=1, step=10, value=32)
|
||||
s2 = ea.ScalarEvent(wall_time=2, step=12, value=64)
|
||||
gen.AddScalar('s1', wall_time=1, step=10, value=32)
|
||||
gen.AddScalar('s2', wall_time=2, step=12, value=64)
|
||||
acc.Reload()
|
||||
self.assertEqual(acc.Scalars('sv1'), [sv1])
|
||||
self.assertEqual(acc.Scalars('sv2'), [sv2])
|
||||
self.assertEqual(acc.Scalars('s1'), [s1])
|
||||
self.assertEqual(acc.Scalars('s2'), [s2])
|
||||
|
||||
def testHistograms(self):
|
||||
gen = _EventGenerator()
|
||||
@ -311,7 +311,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
|
||||
with self.assertRaises(RuntimeError):
|
||||
acc.Tags()
|
||||
with self.assertRaises(RuntimeError):
|
||||
acc.Scalars('sv1')
|
||||
acc.Scalars('s1')
|
||||
acc.Reload()
|
||||
self.assertTrue(acc._activated)
|
||||
acc._activated = False
|
||||
@ -321,17 +321,17 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
|
||||
acc = ea.EventAccumulator(gen)
|
||||
acc.Reload()
|
||||
with self.assertRaises(KeyError):
|
||||
acc.Scalars('sv1')
|
||||
acc.Scalars('s1')
|
||||
with self.assertRaises(KeyError):
|
||||
acc.Scalars('hst1')
|
||||
with self.assertRaises(KeyError):
|
||||
acc.Scalars('im1')
|
||||
with self.assertRaises(KeyError):
|
||||
acc.Histograms('sv1')
|
||||
acc.Histograms('s1')
|
||||
with self.assertRaises(KeyError):
|
||||
acc.Histograms('im1')
|
||||
with self.assertRaises(KeyError):
|
||||
acc.Images('sv1')
|
||||
acc.Images('s1')
|
||||
with self.assertRaises(KeyError):
|
||||
acc.Images('hst1')
|
||||
|
||||
@ -339,21 +339,43 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
|
||||
"""Tests that non-value events in the generator don't cause early exits."""
|
||||
gen = _EventGenerator()
|
||||
acc = ea.EventAccumulator(gen)
|
||||
gen.AddScalar('sv1', wall_time=1, step=10, value=20)
|
||||
gen.AddScalar('s1', wall_time=1, step=10, value=20)
|
||||
gen.AddEvent(tf.Event(
|
||||
wall_time=2, step=20, file_version='notsv2'))
|
||||
gen.AddScalar('sv3', wall_time=3, step=100, value=1)
|
||||
wall_time=2, step=20, file_version='nots2'))
|
||||
gen.AddScalar('s3', wall_time=3, step=100, value=1)
|
||||
gen.AddHistogram('hst1')
|
||||
gen.AddImage('im1')
|
||||
|
||||
acc.Reload()
|
||||
self.assertTagsEqual(acc.Tags(), {
|
||||
ea.IMAGES: ['im1'],
|
||||
ea.SCALARS: ['sv1', 'sv3'],
|
||||
ea.SCALARS: ['s1', 's3'],
|
||||
ea.HISTOGRAMS: ['hst1'],
|
||||
ea.COMPRESSED_HISTOGRAMS: ['hst1'],
|
||||
ea.GRAPH: False})
|
||||
|
||||
def testExpiredDataDiscardedAfterRestart(self):
|
||||
"""Tests that events are discarded after a restart is detected.
|
||||
|
||||
If a step value is observed to be lower than what was previously seen,
|
||||
this should force a discard of all previous items that are outdated.
|
||||
"""
|
||||
gen = _EventGenerator()
|
||||
acc = ea.EventAccumulator(gen)
|
||||
gen.AddScalar('s1', wall_time=1, step=100, value=20)
|
||||
gen.AddScalar('s1', wall_time=1, step=200, value=20)
|
||||
gen.AddScalar('s1', wall_time=1, step=300, value=20)
|
||||
acc.Reload()
|
||||
## Check that number of items are what they should be
|
||||
self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300])
|
||||
|
||||
gen.AddScalar('s1', wall_time=1, step=101, value=20)
|
||||
gen.AddScalar('s1', wall_time=1, step=201, value=20)
|
||||
gen.AddScalar('s1', wall_time=1, step=301, value=20)
|
||||
acc.Reload()
|
||||
## Check that we have discarded 200 and 300
|
||||
self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])
|
||||
|
||||
|
||||
class RealisticEventAccumulatorTest(EventAccumulatorTest):
|
||||
|
||||
|
@ -128,13 +128,15 @@ class EventMultiplexer(object):
return self

def AddRunsFromDirectory(self, path, name=None):
"""Load runs from a directory, assuming each subdirectory is a run.
"""Load runs from a directory; recursively walks subdirectories.

If path doesn't exist, no-op. This ensures that it is safe to call
`AddRunsFromDirectory` multiple times, even before the directory is made.

If the directory contains TensorFlow event files, it is itself treated as a
run.
If path is a directory, load event files in the directory (if any exist) and
recursively call AddRunsFromDirectory on any subdirectories. This mean you
can call AddRunsFromDirectory at the root of a tree of event logs and
TensorBoard will load them all.

If the `EventMultiplexer` is already loaded or autoupdating, this will cause
the newly created accumulators to also `Reload()` or `AutoUpdate()`.
@ -156,25 +158,16 @@ class EventMultiplexer(object):
if not gfile.Exists(path):
return # Maybe it hasn't been created yet, fail silently to retry later
if not gfile.IsDirectory(path):
raise ValueError('Path exists and is not a directory, %s' % path)
paths = gfile.ListDirectory(path)
is_directory = lambda x: gfile.IsDirectory(os.path.join(path, x))
subdirectories = filter(is_directory, paths)
for s in subdirectories:
if name:
subname = '/'.join([name, s])
else:
subname = s
self.AddRun(os.path.join(path, s), subname)
raise ValueError('AddRunsFromDirectory: path exists and is not a '
'directory, %s' % path)

for (subdir, _, files) in os.walk(path):
if list(filter(event_accumulator.IsTensorFlowEventsFile, files)):
logging.info('Adding events from directory %s', subdir)
rpath = os.path.relpath(subdir, path)
subname = os.path.join(name, rpath) if name else rpath
self.AddRun(subdir, name=subname)

if list(filter(event_accumulator.IsTensorFlowEventsFile, paths)):
directory_name = os.path.split(path)[1]
logging.info('Directory %s has event files; loading', directory_name)
if name:
dname = name
else:
dname = directory_name
self.AddRun(path, dname)
return self

def Reload(self):
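An illustrative-only walk-through of how run names fall out of the os.walk() logic above (the directory layout is invented):

    import os

    logdir = '/tmp/logs'
    event_dirs = ['/tmp/logs', '/tmp/logs/mnist_1', '/tmp/logs/mnist_2/eval']
    for subdir in event_dirs:
      rpath = os.path.relpath(subdir, logdir)   # '.', 'mnist_1', 'mnist_2/eval'
      print(os.path.join('foo', rpath))         # with name='foo': 'foo/.', 'foo/mnist_1', ...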
@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import os.path
|
||||
|
||||
import tensorflow.python.platform
|
||||
|
||||
@ -13,6 +14,20 @@ from tensorflow.python.summary import event_accumulator
|
||||
from tensorflow.python.summary import event_multiplexer
|
||||
|
||||
|
||||
def _AddEvents(path):
|
||||
if not gfile.IsDirectory(path):
|
||||
gfile.MakeDirs(path)
|
||||
fpath = os.path.join(path, 'hypothetical.tfevents.out')
|
||||
with gfile.GFile(fpath, 'w'):
|
||||
return fpath
|
||||
|
||||
|
||||
def _CreateCleanDirectory(path):
|
||||
if gfile.IsDirectory(path):
|
||||
gfile.DeleteRecursively(path)
|
||||
gfile.MkDir(path)
|
||||
|
||||
|
||||
class _FakeAccumulator(object):
|
||||
|
||||
def __init__(self, path):
|
||||
@ -122,34 +137,33 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase):
|
||||
x.AddRunsFromDirectory(fakedir)
|
||||
self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect')
|
||||
|
||||
if gfile.IsDirectory(realdir):
|
||||
gfile.DeleteRecursively(realdir)
|
||||
gfile.MkDir(realdir)
|
||||
_CreateCleanDirectory(realdir)
|
||||
x.AddRunsFromDirectory(realdir)
|
||||
self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect')
|
||||
|
||||
path1 = join(realdir, 'path1')
|
||||
gfile.MkDir(path1)
|
||||
x.AddRunsFromDirectory(realdir)
|
||||
self.assertEqual(sorted(x.Runs().keys()), ['path1'], 'loaded run: path1')
|
||||
self.assertEqual(x.Runs(), {}, 'creating empty subdirectory had no effect')
|
||||
|
||||
_AddEvents(path1)
|
||||
x.AddRunsFromDirectory(realdir)
|
||||
self.assertItemsEqual(x.Runs(), ['path1'], 'loaded run: path1')
|
||||
loader1 = x._GetAccumulator('path1')
|
||||
self.assertEqual(loader1._path, path1, 'has the correct path')
|
||||
|
||||
path2 = join(realdir, 'path2')
|
||||
gfile.MkDir(path2)
|
||||
_AddEvents(path2)
|
||||
x.AddRunsFromDirectory(realdir)
|
||||
self.assertItemsEqual(sorted(x.Runs().keys()), ['path1', 'path2'])
|
||||
self.assertItemsEqual(x.Runs(), ['path1', 'path2'])
|
||||
self.assertEqual(x._GetAccumulator('path1'), loader1,
|
||||
'loader1 not regenerated')
|
||||
loader2 = x._GetAccumulator('path2')
|
||||
|
||||
path2_2 = join(path2, 'path2')
|
||||
gfile.MkDir(path2_2)
|
||||
x.AddRunsFromDirectory(path2)
|
||||
self.assertItemsEqual(sorted(x.Runs().keys()), ['path1', 'path2'])
|
||||
self.assertNotEqual(loader2, x._GetAccumulator('path2'),
|
||||
'loader2 regenerated')
|
||||
self.assertEqual(x._GetAccumulator('path2')._path, path2_2,
|
||||
_AddEvents(path2_2)
|
||||
x.AddRunsFromDirectory(realdir)
|
||||
self.assertItemsEqual(x.Runs(), ['path1', 'path2', 'path2/path2'])
|
||||
self.assertEqual(x._GetAccumulator('path2/path2')._path, path2_2,
|
||||
'loader2 path correct')
|
||||
|
||||
def testAddRunsFromDirectoryThatContainsEvents(self):
|
||||
@ -158,21 +172,18 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase):
|
||||
join = os.path.join
|
||||
realdir = join(tmpdir, 'event_containing_directory')
|
||||
|
||||
if gfile.IsDirectory(realdir):
|
||||
gfile.DeleteRecursively(realdir)
|
||||
gfile.MkDir(realdir)
|
||||
_CreateCleanDirectory(realdir)
|
||||
|
||||
self.assertEqual(x.Runs(), {})
|
||||
|
||||
with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'):
|
||||
pass
|
||||
_AddEvents(realdir)
|
||||
x.AddRunsFromDirectory(realdir)
|
||||
self.assertItemsEqual(x.Runs(), ['event_containing_directory'])
|
||||
self.assertItemsEqual(x.Runs(), ['.'])
|
||||
|
||||
subdir = join(realdir, 'subdir')
|
||||
gfile.MkDir(subdir)
|
||||
_AddEvents(subdir)
|
||||
x.AddRunsFromDirectory(realdir)
|
||||
self.assertItemsEqual(x.Runs(), ['event_containing_directory', 'subdir'])
|
||||
self.assertItemsEqual(x.Runs(), ['.', 'subdir'])
|
||||
|
||||
def testAddRunsFromDirectoryWithRunNames(self):
|
||||
x = event_multiplexer.EventMultiplexer()
|
||||
@ -180,30 +191,45 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase):
|
||||
join = os.path.join
|
||||
realdir = join(tmpdir, 'event_containing_directory')
|
||||
|
||||
if gfile.IsDirectory(realdir):
|
||||
gfile.DeleteRecursively(realdir)
|
||||
gfile.MkDir(realdir)
|
||||
_CreateCleanDirectory(realdir)
|
||||
|
||||
self.assertEqual(x.Runs(), {})
|
||||
|
||||
with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'):
|
||||
pass
|
||||
_AddEvents(realdir)
|
||||
x.AddRunsFromDirectory(realdir, 'foo')
|
||||
self.assertItemsEqual(x.Runs(), ['foo'])
|
||||
self.assertItemsEqual(x.Runs(), ['foo/.'])
|
||||
|
||||
subdir = join(realdir, 'subdir')
|
||||
gfile.MkDir(subdir)
|
||||
_AddEvents(subdir)
|
||||
x.AddRunsFromDirectory(realdir, 'foo')
|
||||
self.assertItemsEqual(x.Runs(), ['foo', 'foo/subdir'])
|
||||
self.assertItemsEqual(x.Runs(), ['foo/.', 'foo/subdir'])
|
||||
|
||||
def testAddRunsFromDirectoryWalksTree(self):
|
||||
x = event_multiplexer.EventMultiplexer()
|
||||
tmpdir = self.get_temp_dir()
|
||||
join = os.path.join
|
||||
realdir = join(tmpdir, 'event_containing_directory')
|
||||
|
||||
_CreateCleanDirectory(realdir)
|
||||
_AddEvents(realdir)
|
||||
sub = join(realdir, 'subdirectory')
|
||||
sub1 = join(sub, '1')
|
||||
sub2 = join(sub, '2')
|
||||
sub1_1 = join(sub1, '1')
|
||||
_AddEvents(sub1)
|
||||
_AddEvents(sub2)
|
||||
_AddEvents(sub1_1)
|
||||
x.AddRunsFromDirectory(realdir)
|
||||
|
||||
self.assertItemsEqual(x.Runs(), ['.',
|
||||
'subdirectory/1', 'subdirectory/2',
|
||||
'subdirectory/1/1'])
|
||||
|
||||
def testAddRunsFromDirectoryThrowsException(self):
|
||||
x = event_multiplexer.EventMultiplexer()
|
||||
tmpdir = self.get_temp_dir()
|
||||
|
||||
filepath = os.path.join(tmpdir, 'bad_file')
|
||||
with gfile.GFile(filepath, 'w'):
|
||||
pass
|
||||
|
||||
filepath = _AddEvents(tmpdir)
|
||||
with self.assertRaises(ValueError):
|
||||
x.AddRunsFromDirectory(filepath)
|
||||
|
||||
|
@ -77,7 +77,7 @@ class Reservoir(object):
|
||||
key: The key for which we are finding associated items.
|
||||
|
||||
Raises:
|
||||
KeyError: If the key is not ofund in the reservoir.
|
||||
KeyError: If the key is not found in the reservoir.
|
||||
|
||||
Returns:
|
||||
[list, of, items] associated with that key.
|
||||
@ -102,6 +102,19 @@ class Reservoir(object):
|
||||
bucket = self._buckets[key]
|
||||
bucket.AddItem(item)
|
||||
|
||||
def FilterItems(self, filterFn):
|
||||
"""Filter items within a Reservoir, using a filtering function.
|
||||
|
||||
Args:
|
||||
filterFn: A function that returns True for the items to be kept.
|
||||
|
||||
Returns:
|
||||
The number of items removed.
|
||||
"""
|
||||
with self._mutex:
|
||||
return sum(bucket.FilterItems(filterFn)
|
||||
for bucket in self._buckets.values())
|
||||
|
||||
|
||||
class _ReservoirBucket(object):
|
||||
"""A container for items from a stream, that implements reservoir sampling.
|
||||
@ -128,7 +141,7 @@ class _ReservoirBucket(object):
|
||||
# AddItem are thread-safe
|
||||
self._mutex = threading.Lock()
|
||||
self._max_size = _max_size
|
||||
self._count = 0
|
||||
self._num_items_seen = 0
|
||||
if _random is not None:
|
||||
self._random = _random
|
||||
else:
|
||||
@ -139,13 +152,13 @@ class _ReservoirBucket(object):

The new item is guaranteed to be added to the bucket, and to be the last
element in the bucket. If the bucket has reached capacity, then an old item
will be replaced. With probability (_max_size/_count) a random item in the
bucket will be popped out and the new item will be appended to the end. With
probability (1 - _max_size/_count) the last item in the bucket will be
replaced.
will be replaced. With probability (_max_size/_num_items_seen) a random item
in the bucket will be popped out and the new item will be appended
to the end. With probability (1 - _max_size/_num_items_seen)
the last item in the bucket will be replaced.

Since the O(n) replacements occur with O(1/_count) liklihood, the amortized
runtime is O(1).
Since the O(n) replacements occur with O(1/_num_items_seen) likelihood,
the amortized runtime is O(1).

Args:
item: The item to add to the bucket.
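For readers unfamiliar with reservoir sampling, here is a stripped-down sketch of the rule the docstring describes (illustrative only; the class and method names are invented, not the TensorBoard implementation):

    import random

    class TinyReservoir(object):
      """Illustrative reservoir bucket, not the TensorBoard class."""

      def __init__(self, max_size):
        self.items = []
        self.max_size = max_size
        self.num_items_seen = 0

      def add(self, item):
        if len(self.items) < self.max_size or self.max_size == 0:
          self.items.append(item)
        else:
          r = random.randint(0, self.num_items_seen)
          if r < self.max_size:        # probability max_size / (num_items_seen + 1)
            self.items.pop(r)          # evict a random survivor ...
            self.items.append(item)    # ... and keep the new item as the last element
          else:
            self.items[-1] = item      # otherwise only the last slot is overwritten
        self.num_items_seen += 1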
@ -154,13 +167,43 @@ class _ReservoirBucket(object):
|
||||
if len(self.items) < self._max_size or self._max_size == 0:
|
||||
self.items.append(item)
|
||||
else:
|
||||
r = self._random.randint(0, self._count)
|
||||
r = self._random.randint(0, self._num_items_seen)
|
||||
if r < self._max_size:
|
||||
self.items.pop(r)
|
||||
self.items.append(item)
|
||||
else:
|
||||
self.items[-1] = item
|
||||
self._count += 1
|
||||
self._num_items_seen += 1
|
||||
|
||||
def FilterItems(self, filterFn):
|
||||
"""Filter items in a ReservoirBucket, using a filtering function.
|
||||
|
||||
Filtering items from the reservoir bucket must update the
|
||||
internal state variable self._num_items_seen, which is used for determining
|
||||
the rate of replacement in reservoir sampling. Ideally, self._num_items_seen
|
||||
would contain the exact number of items that have ever seen by the
|
||||
ReservoirBucket and satisfy filterFn. However, the ReservoirBucket does not
|
||||
have access to all items seen -- it only has access to the subset of items
|
||||
that have survived sampling (self.items). Therefore, we estimate
|
||||
self._num_items_seen by scaling it by the same ratio as the ratio of items
|
||||
not removed from self.items.
|
||||
|
||||
Args:
|
||||
filterFn: A function that returns True for items to be kept.
|
||||
|
||||
Returns:
|
||||
The number of items removed from the bucket.
|
||||
"""
|
||||
with self._mutex:
|
||||
size_before = len(self.items)
|
||||
self.items = filter(filterFn, self.items)
|
||||
size_diff = size_before - len(self.items)
|
||||
|
||||
# Estimate a correction the the number of items seen
|
||||
prop_remaining = len(self.items) / float(
|
||||
size_before) if size_before > 0 else 0
|
||||
self._num_items_seen = int(round(self._num_items_seen * prop_remaining))
|
||||
return size_diff
|
||||
|
||||
def Items(self):
|
||||
"""Get all the items in the bucket."""
|
||||
|
@ -90,12 +90,14 @@ class ReservoirBucketTest(googletest.TestCase):
|
||||
for i in xrange(100):
|
||||
b.AddItem(i)
|
||||
self.assertEqual(b.Items(), list(xrange(100)))
|
||||
self.assertEqual(b._num_items_seen, 100)
|
||||
|
||||
def testDoesntOverfill(self):
|
||||
b = reservoir._ReservoirBucket(10)
|
||||
for i in xrange(1000):
|
||||
b.AddItem(i)
|
||||
self.assertEqual(len(b.Items()), 10)
|
||||
self.assertEqual(b._num_items_seen, 1000)
|
||||
|
||||
def testMaintainsOrder(self):
|
||||
b = reservoir._ReservoirBucket(100)
|
||||
@ -119,12 +121,14 @@ class ReservoirBucketTest(googletest.TestCase):
|
||||
for i in xrange(20):
|
||||
b.AddItem(i)
|
||||
self.assertEqual(b.Items(), [i])
|
||||
self.assertEqual(b._num_items_seen, 20)
|
||||
|
||||
def testSizeZeroBucket(self):
|
||||
b = reservoir._ReservoirBucket(0)
|
||||
for i in xrange(20):
|
||||
b.AddItem(i)
|
||||
self.assertEqual(b.Items(), list(range(i + 1)))
|
||||
self.assertEqual(b._num_items_seen, 20)
|
||||
|
||||
def testSizeRequirement(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@ -132,6 +136,29 @@ class ReservoirBucketTest(googletest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
reservoir._ReservoirBucket(10.3)
|
||||
|
||||
def testRemovesItems(self):
|
||||
b = reservoir._ReservoirBucket(100)
|
||||
for i in xrange(10):
|
||||
b.AddItem(i)
|
||||
self.assertEqual(len(b.Items()), 10)
|
||||
self.assertEqual(b._num_items_seen, 10)
|
||||
self.assertEqual(b.FilterItems(lambda x: x <= 7), 2)
|
||||
self.assertEqual(len(b.Items()), 8)
|
||||
self.assertEqual(b._num_items_seen, 8)
|
||||
|
||||
def testRemovesItemsWhenItemsAreReplaced(self):
|
||||
b = reservoir._ReservoirBucket(100)
|
||||
for i in xrange(10000):
|
||||
b.AddItem(i)
|
||||
self.assertEqual(b._num_items_seen, 10000)
|
||||
|
||||
# Remove items
|
||||
num_removed = b.FilterItems(lambda x: x <= 7)
|
||||
self.assertGreater(num_removed, 92)
|
||||
self.assertEqual([], [item for item in b.Items() if item > 7])
|
||||
self.assertEqual(b._num_items_seen,
|
||||
int(round(10000 * (1 - float(num_removed) / 100))))
|
||||
|
||||
|
||||
class ReservoirBucketStatisticalDistributionTest(googletest.TestCase):
|
||||
|
||||
|
@ -445,7 +445,8 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
dtypes=dtypes, shapes=shapes)
_enqueue(queue, tensor_list, num_threads, enqueue_many)
full = (math_ops.cast(queue.size() - min_after_dequeue, types.float32) *
full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
types.float32) *
(1. / (capacity - min_after_dequeue)))
# Note that name contains a '/' at the end so we intentionally do not place
# a '/' after %s below.
@ -513,7 +514,8 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity,
capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
dtypes=dtypes, shapes=shapes)
_enqueue_join(queue, tensor_list_list, enqueue_many)
full = (math_ops.cast(queue.size() - min_after_dequeue, types.float32) *
full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
types.float32) *
(1. / (capacity - min_after_dequeue)))
# Note that name contains a '/' at the end so we intentionally do not place
# a '/' after %s below.
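A plain-number sketch of why the maximum(0, ...) matters in the fraction-of-queue-full summary above (values invented):

    capacity, min_after_dequeue, queue_size = 1000, 400, 250
    old_full = (queue_size - min_after_dequeue) / float(capacity - min_after_dequeue)
    new_full = max(0, queue_size - min_after_dequeue) / float(capacity - min_after_dequeue)
    # old_full == -0.25: a negative "fraction full" while the queue is still warming up.
    # new_full == 0.0, which is what the summary should report.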
@ -4,7 +4,7 @@
|
||||
module TF {
|
||||
type TFDatum = [number, number, number];
|
||||
type tooltipMap = {[run: string]: string};
|
||||
type TooltipUpdater = (tooltipMap, xValue, closestRun) => void;
|
||||
export type TooltipUpdater = (tooltipMap, xValue, closestRun) => void;
|
||||
|
||||
let Y_TOOLTIP_FORMATTER_PRECISION = 4;
|
||||
let STEP_AXIS_FORMATTER_PRECISION = 4;
|
||||
|
@ -39,13 +39,13 @@ export class SlimGraph {
|
||||
}
|
||||
}
|
||||
|
||||
interface NormalizedInput {
|
||||
export interface NormalizedInput {
|
||||
name: string;
|
||||
hasNumberPart: boolean;
|
||||
isControlDependency: boolean;
|
||||
}
|
||||
|
||||
interface BuildParams {
|
||||
export interface BuildParams {
|
||||
enableEmbedding: boolean;
|
||||
inEmbeddingTypes: string[];
|
||||
outEmbeddingTypes: string[];
|
||||
@ -352,7 +352,7 @@ export function joinStatsInfoWithGraph(graph: SlimGraph,
|
||||
/**
|
||||
* Execution stats for the node.
|
||||
*/
|
||||
class NodeStats {
|
||||
export class NodeStats {
|
||||
constructor(totalBytes: number, totalMicros: number, outputSize: number[][]) {
|
||||
this.totalBytes = totalBytes;
|
||||
this.totalMicros = totalMicros;
|
||||
|
@ -11,7 +11,7 @@ const LOG_PREFIX_MSG = "Graph hierarchy: ";
|
||||
/**
|
||||
* Class used as output for getPredecessors and getSuccessors methods
|
||||
*/
|
||||
interface Edges {
|
||||
export interface Edges {
|
||||
control: string[];
|
||||
regular: string[];
|
||||
}
|
||||
@ -370,7 +370,7 @@ function findEdgeTargetsInGraph(
|
||||
});
|
||||
}
|
||||
|
||||
interface HierarchyParams {
|
||||
export interface HierarchyParams {
|
||||
verifyTemplate: boolean;
|
||||
seriesNodeMinSize: number;
|
||||
}
|
||||
@ -640,7 +640,7 @@ function detectSeries(clusters: {[clusterId: string]: string[]},
|
||||
* which is an array that contains objects with name, id, prefix, suffix,
|
||||
* and parent properties.
|
||||
*/
|
||||
let candidatesDict = {};
|
||||
let candidatesDict: {[seriesName: string]: SeriesNode[]} = {};
|
||||
|
||||
// Group all nodes that have the same name, with the exception of a
|
||||
// number at the end of the name after an underscore, which is allowed to
|
||||
|
@ -65,7 +65,7 @@ export let SeriesNodeColors = {
|
||||
/**
|
||||
* Parameters that affect how the graph is rendered on the screen.
|
||||
*/
|
||||
interface RenderGraphParams {
|
||||
export interface RenderGraphParams {
|
||||
/**
|
||||
* Whether to extract high degree nodes from the core part of the graph.
|
||||
*/
|
||||
|
@ -28,7 +28,7 @@ export function detect(h, verifyTemplate): {[templateId: string]: string[]} {
|
||||
// Sort the templates by minimum level in the graph at which they appear,
|
||||
// as this leads to optimal setting of the colors of each template for
|
||||
// maximum differentiation.
|
||||
return _(templates).pairs()
|
||||
return <{[templateId: string]: string[]}> _(templates).pairs()
|
||||
.sortBy(function(pair) {
|
||||
return pair[1].level;
|
||||
})
|
||||
@ -101,6 +101,7 @@ function clusterSimilarSubgraphs(h: hierarchy.Hierarchy) {
|
||||
function groupTemplateAndAssignId(nnGroups, verifyTemplate) {
|
||||
// For each metanode, compare its subgraph (starting from shallower groups)
|
||||
// and assign template id.
|
||||
let result: {[templateId: string]: {level: number, nodes: string[]}} = {};
|
||||
return _.reduce(nnGroups, function(templates, nnGroupPair) {
|
||||
let signature = nnGroupPair[0],
|
||||
nnGroup = nnGroupPair[1].nodes,
|
||||
@ -137,7 +138,7 @@ function groupTemplateAndAssignId(nnGroups, verifyTemplate) {
|
||||
};
|
||||
});
|
||||
return templates;
|
||||
}, {});
|
||||
}, result);
|
||||
}
|
||||
|
||||
function sortNodes(names: string[], graph: graphlib.Graph<Metanode|OpNode, Metaedge>,
|
||||
|
@ -1,6 +1,7 @@
|
||||
<script src="../../bower_components/d3/d3.js"></script>
|
||||
<script src="../../bower_components/lodash/lodash.js"></script>
|
||||
<script src="../../bower_components/graphlib/dist/graphlib.core.js"></script>
|
||||
<script src="../../bower_components/dagre/dist/dagre.core.js"></script>
|
||||
|
||||
<script src="lib/common.js"></script>
|
||||
<script src="lib/graph.js"></script>
|
||||
|
@ -1,8 +1,7 @@
|
||||
<link rel="import" href="../../bower_components/polymer/polymer.html">
|
||||
<link rel="import" href="tf-graph-style.html">
|
||||
<link rel="import" href="tf-graph-minimap.html">
|
||||
<script src="../tf-graph-common/lib/layout.js"></script>
|
||||
<script src="../../bower_components/dagre/dist/dagre.core.js"></script>
|
||||
|
||||
<!--
|
||||
A module that takes a render hierarchy as input and produces an SVG DOM using
|
||||
dagre and d3.
|
||||
|
@ -23,27 +23,25 @@ from tensorflow.python.summary import event_accumulator
from tensorflow.python.summary import event_multiplexer
from tensorflow.tensorboard import tensorboard_handler

flags.DEFINE_string('logdir', None, """
logdir specifies where TensorBoard will look to find TensorFlow event files
that it can display. In the simplest case, logdir is a directory containing
tfevents files. TensorBoard also supports comparing multiple TensorFlow
executions: to do this, you can use directory whose subdirectories contain
tfevents files, as in the following example:
flags.DEFINE_string('logdir', None, """logdir specifies the directory where
TensorBoard will look to find TensorFlow event files that it can display.
TensorBoard will recursively walk the directory structure rooted at logdir,
looking for .*tfevents.* files.

foo/bar/logdir/
foo/bar/logdir/mnist_1/events.out.tfevents.1444088766
foo/bar/logdir/mnist_2/events.out.tfevents.1444090064

You may also pass a comma seperated list of log directories, and you can
assign names to individual log directories by putting a colon between the name
and the path, as in
You may also pass a comma seperated list of log directories, and TensorBoard
will watch each directory. You can also assign names to individual log
directories by putting a colon between the name and the path, as in

tensorboard --logdir=name1:/path/to/logs/1,name2:/path/to/logs/2
""")

flags.DEFINE_boolean('debug', False, 'Whether to run the app in debug mode. '
'This increases log verbosity to DEBUG.')

flags.DEFINE_string('host', '127.0.0.1', 'What host to listen to. Defaults to '
'serving on localhost, set to 0.0.0.0 for remote access.')

flags.DEFINE_integer('port', 6006, 'What port to serve TensorBoard on.')

FLAGS = flags.FLAGS
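A hypothetical sketch of how a comma separated name:path --logdir value can be split into runs (this is not the actual TensorBoard parser):

    def parse_logdir_spec(spec):
      runs = {}
      for entry in spec.split(','):
        if ':' in entry:
          name, path = entry.split(':', 1)
        else:
          name, path = entry, entry   # unnamed entries are keyed by their own path
        runs[name] = path
      return runs

    # parse_logdir_spec('name1:/path/to/logs/1,name2:/path/to/logs/2')
    # -> {'name1': '/path/to/logs/1', 'name2': '/path/to/logs/2'}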
@ -10,8 +10,10 @@ from setuptools.dist import Distribution
_VERSION = '0.5.0'

REQUIRED_PACKAGES = [
'numpy >= 1.9.2',
'numpy >= 1.8.2',
'six >= 1.10.0',
'protobuf == 3.0.0a3',
'wheel',
]

# pylint: disable=line-too-long