From 4213ac97be449d0e40631a314d2b7bd3901d4967 Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Mon, 16 Nov 2015 23:42:32 -0800 Subject: [PATCH] TensorFlow: conv improvements, label_image example, and a few other changes. Changes: - Some improvements to convolution by using 32-bit indices by @benoitsteiner. Not all calls converted yet. Also some improvements to pooling as well by @benoitsteiner. - Improvements to sparse matmul CPU implementation by Ashish - Some fixes to warnings by @vrv - Doc fixes to padding by @Yangqing - Some improvements to Tensor wrappers by Eider - Speed up of matrix inverse on CPU by Rasmus - Add an example of doing image inference from a pre-trained model by @petewarden. - fixed formula in mnist example by nodir - Updates to event accumulator by Cassandra - Slight changes to tensor c api by @mrry - Handling of strings in listdiff by Phil - Fix negative fraction-of-queue-full stats by Frank - Type-checking improvement to importer by Yaroslav - logdir recursive search for Tensorboard by @danmane - Session.run() checks for empty graph by Manoj Base CL: 108013706 --- tensorflow/core/client/tensor_c_api.cc | 6 +- .../gpu/gpu_bfc_allocator_test.cc | 10 +- tensorflow/core/framework/tensor_types.h | 77 +- tensorflow/core/kernels/conv_2d.h | 109 +- tensorflow/core/kernels/conv_grad_ops.cc | 158 +-- tensorflow/core/kernels/conv_ops.cc | 33 +- tensorflow/core/kernels/conv_ops_gpu_2.cu.cc | 6 +- tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 13 +- tensorflow/core/kernels/linalg_ops_common.h | 9 +- tensorflow/core/kernels/listdiff_op.cc | 2 + tensorflow/core/kernels/matrix_inverse_op.cc | 33 +- tensorflow/core/kernels/pooling_ops_common.cc | 22 +- tensorflow/core/kernels/sparse_matmul_op.cc | 888 +++++++++++++-- .../core/kernels/sparse_matmul_op_test.cc | 60 +- .../core/kernels/string_to_hash_bucket_op.cc | 2 +- .../core/lib/core/command_line_flags.cc | 14 +- tensorflow/core/lib/core/command_line_flags.h | 5 +- tensorflow/core/ops/array_ops.cc | 4 +- tensorflow/core/ops/linalg_ops.cc | 18 +- tensorflow/core/public/tensor_c_api.h | 4 +- tensorflow/examples/android/jni/jni_utils.cc | 8 +- tensorflow/examples/label_image/BUILD | 30 + tensorflow/examples/label_image/README.md | 49 + .../label_image/data/googlenet_labels.txt | 1001 +++++++++++++++++ .../label_image/data/grace_hopper.jpg | Bin 0 -> 61306 bytes tensorflow/examples/label_image/main.cc | 295 +++++ tensorflow/g3doc/api_docs/index.md | 2 + tensorflow/g3doc/api_docs/python/nn.md | 59 +- .../g3doc/tutorials/mnist/beginners/index.md | 2 +- tensorflow/python/BUILD | 1 + tensorflow/python/client/session.py | 3 + tensorflow/python/client/session_test.py | 6 + tensorflow/python/client/tf_session.i | 2 +- tensorflow/python/framework/importer.py | 9 +- tensorflow/python/framework/ops.py | 1 + tensorflow/python/framework/ops_test.py | 14 + .../python/kernel_tests/conv_ops_test.py | 8 + .../python/kernel_tests/listdiff_op_test.py | 56 +- .../kernel_tests/matrix_inverse_op_test.py | 28 +- tensorflow/python/ops/clip_ops.py | 3 +- tensorflow/python/ops/common_shapes.py | 12 +- tensorflow/python/ops/logging_ops.py | 3 +- tensorflow/python/ops/nn.py | 59 +- .../python/summary/event_accumulator.py | 27 + .../python/summary/event_accumulator_test.py | 62 +- .../python/summary/event_multiplexer.py | 35 +- .../python/summary/event_multiplexer_test.py | 92 +- tensorflow/python/summary/impl/reservoir.py | 63 +- .../python/summary/impl/reservoir_test.py | 27 + tensorflow/python/training/input.py | 6 +- 
.../components/tf-event-dashboard/tf-chart.ts | 2 +- .../components/tf-graph-common/lib/graph.ts | 6 +- .../tf-graph-common/lib/hierarchy.ts | 6 +- .../components/tf-graph-common/lib/render.ts | 2 +- .../tf-graph-common/lib/template.ts | 5 +- .../tf-graph-common/tf-graph-common.html | 1 + .../components/tf-graph/tf-graph-scene.html | 3 +- tensorflow/tensorboard/tensorboard.py | 24 +- tensorflow/tools/pip_package/setup.py | 4 +- 59 files changed, 2976 insertions(+), 513 deletions(-) create mode 100644 tensorflow/examples/label_image/BUILD create mode 100644 tensorflow/examples/label_image/README.md create mode 100644 tensorflow/examples/label_image/data/googlenet_labels.txt create mode 100644 tensorflow/examples/label_image/data/grace_hopper.jpg create mode 100644 tensorflow/examples/label_image/main.cc diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc index 59cf0ed8f98..5dc27f73608 100644 --- a/tensorflow/core/client/tensor_c_api.cc +++ b/tensorflow/core/client/tensor_c_api.cc @@ -141,9 +141,9 @@ void TF_SetTarget(TF_SessionOptions* options, const char* target) { options->options.target = target; } -void TF_SetConfig(TF_SessionOptions* options, const char* config, - size_t config_len, TF_Status* status) { - if (!options->options.config.ParseFromArray(config, config_len)) { +void TF_SetConfig(TF_SessionOptions* options, const void* proto, + size_t proto_len, TF_Status* status) { + if (!options->options.config.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument("Unparseable ConfigProto"); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index 7b5e8aec1dc..cdfb06f72c2 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -30,7 +30,7 @@ TEST(GPUBFCAllocatorTest, NoDups) { std::sort(ptrs.begin(), ptrs.end()); // Make sure none of them are equal, and that none of them overlap. - for (int i = 0; i < ptrs.size(); i++) { + for (size_t i = 0; i < ptrs.size(); i++) { if (i > 0) { ASSERT_NE(ptrs[i], ptrs[i - 1]); // No dups size_t req_size = a.RequestedSize(ptrs[i - 1]); @@ -40,7 +40,7 @@ TEST(GPUBFCAllocatorTest, NoDups) { } } - for (int i = 0; i < ptrs.size(); i++) { + for (size_t i = 0; i < ptrs.size(); i++) { a.DeallocateRaw(ptrs[i]); } } @@ -63,7 +63,7 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) { // Deallocate half of the memory, and keep track of the others. 
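// --- Illustrative usage sketch (not part of this patch) ---
// The tensor_c_api.cc hunk above changes TF_SetConfig to take an opaque
// (proto, proto_len) pair instead of a char buffer. A minimal caller might
// look like the sketch below; ApplyConfig is a hypothetical helper, and the
// serialized bytes are assumed to come from an already-encoded
// tensorflow.ConfigProto.
#include <string>

#include "tensorflow/core/public/tensor_c_api.h"

void ApplyConfig(TF_SessionOptions* opts, const std::string& serialized_config) {
  TF_Status* status = TF_NewStatus();
  TF_SetConfig(opts, serialized_config.data(), serialized_config.size(), status);
  if (TF_GetCode(status) != TF_OK) {
    // An unparseable ConfigProto surfaces here as an InvalidArgument status.
  }
  TF_DeleteStatus(status);
}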
std::vector existing_ptrs; - for (int i = 0; i < initial_ptrs.size(); i++) { + for (size_t i = 0; i < initial_ptrs.size(); i++) { if (i % 2 == 1) { a.DeallocateRaw(initial_ptrs[i]); } else { @@ -81,7 +81,7 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) { std::sort(existing_ptrs.begin(), existing_ptrs.end()); // Make sure none of them are equal - for (int i = 0; i < existing_ptrs.size(); i++) { + for (size_t i = 0; i < existing_ptrs.size(); i++) { if (i > 0) { CHECK_NE(existing_ptrs[i], existing_ptrs[i - 1]); // No dups @@ -95,7 +95,7 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) { } } - for (int i = 0; i < existing_ptrs.size(); i++) { + for (size_t i = 0; i < existing_ptrs.size(); i++) { a.DeallocateRaw(existing_ptrs[i]); } } diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h index 077d86d4420..c0221f9b529 100644 --- a/tensorflow/core/framework/tensor_types.h +++ b/tensorflow/core/framework/tensor_types.h @@ -6,66 +6,73 @@ namespace tensorflow { // Helper to define Tensor types given that the scalar is of type T. -template +template struct TTypes { // Rank- tensor of scalar type T. - typedef Eigen::TensorMap, + typedef Eigen::TensorMap, Eigen::Aligned> Tensor; - typedef Eigen::TensorMap, - Eigen::Aligned> ConstTensor; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstTensor; // Unaligned Rank- tensor of scalar type T. - typedef Eigen::TensorMap > + typedef Eigen::TensorMap > UnalignedTensor; - typedef Eigen::TensorMap > - UnalignedConstTensor; + typedef Eigen::TensorMap > UnalignedConstTensor; typedef Eigen::TensorMap, Eigen::Aligned> Tensor32Bit; // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. typedef Eigen::TensorMap< - Eigen::TensorFixedSize, Eigen::RowMajor>, + Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, Eigen::Aligned> Scalar; - typedef Eigen::TensorMap< - Eigen::TensorFixedSize, Eigen::RowMajor>, - Eigen::Aligned> ConstScalar; + typedef Eigen::TensorMap, + Eigen::RowMajor, IndexType>, + Eigen::Aligned> ConstScalar; // Unaligned Scalar tensor of scalar type T. typedef Eigen::TensorMap, Eigen::RowMajor> > UnalignedScalar; - typedef Eigen::TensorMap, Eigen::RowMajor> > UnalignedConstScalar; + T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> > UnalignedScalar; + typedef Eigen::TensorMap, + Eigen::RowMajor, IndexType> > + UnalignedConstScalar; // Rank-1 tensor (vector) of scalar type T. - typedef Eigen::TensorMap, Eigen::Aligned> - Flat; - typedef Eigen::TensorMap, - Eigen::Aligned> ConstFlat; - typedef Eigen::TensorMap, Eigen::Aligned> - Vec; - typedef Eigen::TensorMap, - Eigen::Aligned> ConstVec; + typedef Eigen::TensorMap, + Eigen::Aligned> Flat; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstFlat; + typedef Eigen::TensorMap, + Eigen::Aligned> Vec; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstVec; // Unaligned Rank-1 tensor (vector) of scalar type T. - typedef Eigen::TensorMap > UnalignedFlat; - typedef Eigen::TensorMap > - UnalignedConstFlat; - typedef Eigen::TensorMap > UnalignedVec; - typedef Eigen::TensorMap > - UnalignedConstVec; + typedef Eigen::TensorMap > + UnalignedFlat; + typedef Eigen::TensorMap > UnalignedConstFlat; + typedef Eigen::TensorMap > + UnalignedVec; + typedef Eigen::TensorMap< + Eigen::Tensor > UnalignedConstVec; // Rank-2 tensor (matrix) of scalar type T. 
- typedef Eigen::TensorMap, Eigen::Aligned> - Matrix; - typedef Eigen::TensorMap, - Eigen::Aligned> ConstMatrix; + typedef Eigen::TensorMap, + Eigen::Aligned> Matrix; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstMatrix; // Unaligned Rank-2 tensor (matrix) of scalar type T. - typedef Eigen::TensorMap > + typedef Eigen::TensorMap > UnalignedMatrix; - typedef Eigen::TensorMap > - UnalignedConstMatrix; + typedef Eigen::TensorMap > UnalignedConstMatrix; }; typedef typename TTypes::Tensor32Bit::Index Index32; diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h index e4ea02b7bc7..8313e6e6703 100644 --- a/tensorflow/core/kernels/conv_2d.h +++ b/tensorflow/core/kernels/conv_2d.h @@ -11,24 +11,25 @@ namespace functor { // TODO(yangke): revisit these operations and in particular, see if we can // combine all of them into just one operation without causing nvcc to // timeout. -template +template struct ShuffleAndReverse { - void operator()(const Device& d, typename TTypes::ConstTensor input, - const Eigen::DSizes& order, + void operator()(const Device& d, + typename TTypes::ConstTensor input, + const Eigen::DSizes& order, const Eigen::array& reverse_dims, - typename TTypes::Tensor output) { + typename TTypes::Tensor output) { output.device(d) = input.shuffle(order).reverse(reverse_dims); } }; -template +template struct InflatePadAndShuffle { void operator()( - const Device& d, typename TTypes::ConstTensor input, - const Eigen::DSizes& strides, - const Eigen::array, Dims>& pad_dims, - const Eigen::DSizes& order, - typename TTypes::Tensor output) { + const Device& d, typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + const Eigen::array, Dims>& pad_dims, + const Eigen::DSizes& order, + typename TTypes::Tensor output) { output.device(d) = input.inflate(strides).pad(pad_dims).shuffle(order); } }; @@ -89,30 +90,92 @@ struct MatMulConvFunctor { } }; -template +template struct TransformFilter { - void operator()(const Device& d, typename TTypes::ConstTensor in, - typename TTypes::Tensor out) { - out.device(d) = in.shuffle(Eigen::DSizes(3, 2, 0, 1)); + void operator()(const Device& d, + typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { + // We want a 3, 2, 0, 1 shuffle. We can merge dimensions 0 and 1 together + // to help speedup the shuffle operation. + Eigen::DSizes merged_dims; + merged_dims[0] = in.dimension(0) * in.dimension(1); + merged_dims[1] = in.dimension(2); + merged_dims[2] = in.dimension(3); + + Eigen::DSizes expanded_dims; + expanded_dims[0] = in.dimension(3); + expanded_dims[1] = in.dimension(2); + expanded_dims[2] = in.dimension(0); + expanded_dims[3] = in.dimension(1); + + out.device(d) = in.reshape(merged_dims) + .shuffle(Eigen::DSizes(2, 1, 0)) + .reshape(expanded_dims); } }; -template +template struct TransformDepth { - void operator()(const Device& d, typename TTypes::ConstTensor in, - const Eigen::DSizes& shuffle, - typename TTypes::Tensor out) { - out.device(d) = in.shuffle(shuffle); + void operator()(const Device& d, + typename TTypes::ConstTensor in, + const Eigen::DSizes& shuffle, + typename TTypes::Tensor out) { + Eigen::DSizes merged_dims; + Eigen::DSizes expanded_dims; + Eigen::DSizes new_shuffle; + + // Merge dimensions that won't be shuffled together to speed things up. 
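// --- Illustrative check (not part of this patch) ---
// A standalone sketch of the reshape+shuffle trick used by TransformFilter
// above and by TransformDepth below: merging dimensions that stay adjacent
// lets the 4-D {3, 2, 0, 1} shuffle run as a cheaper 3-D {2, 1, 0} shuffle.
// ShuffleTrickMatches and its sizes are hypothetical; only Eigen's Tensor
// module is assumed.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

inline bool ShuffleTrickMatches() {
  Eigen::Tensor<float, 4, Eigen::RowMajor> in(3, 5, 2, 4);
  in.setRandom();

  // Direct 4-D shuffle.
  Eigen::DSizes<Eigen::DenseIndex, 4> shuffle4(3, 2, 0, 1);
  Eigen::Tensor<float, 4, Eigen::RowMajor> direct = in.shuffle(shuffle4);

  // Merge dims 0 and 1, shuffle in 3-D, then expand back.
  Eigen::DSizes<Eigen::DenseIndex, 3> merged(3 * 5, 2, 4);
  Eigen::DSizes<Eigen::DenseIndex, 3> shuffle3(2, 1, 0);
  Eigen::DSizes<Eigen::DenseIndex, 4> expanded(4, 2, 3, 5);
  Eigen::Tensor<float, 4, Eigen::RowMajor> via_merge =
      in.reshape(merged).shuffle(shuffle3).reshape(expanded);

  for (Eigen::DenseIndex i = 0; i < direct.size(); ++i) {
    if (direct.data()[i] != via_merge.data()[i]) return false;
  }
  return true;
}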
+ if (shuffle[1] == 2 && shuffle[2] == 3) { + merged_dims[0] = in.dimension(0); + merged_dims[1] = in.dimension(1); + merged_dims[2] = in.dimension(2) * in.dimension(3); + new_shuffle[0] = shuffle[0]; + new_shuffle[1] = 2; + new_shuffle[2] = shuffle[3]; + expanded_dims[0] = in.dimension(shuffle[0]); + expanded_dims[1] = in.dimension(2); + expanded_dims[2] = in.dimension(3); + expanded_dims[3] = in.dimension(shuffle[3]); + } else if (shuffle[0] == 2 && shuffle[1] == 3) { + merged_dims[0] = in.dimension(0); + merged_dims[1] = in.dimension(1); + merged_dims[2] = in.dimension(2) * in.dimension(3); + new_shuffle[0] = 2; + new_shuffle[1] = shuffle[2]; + new_shuffle[2] = shuffle[3]; + expanded_dims[0] = in.dimension(2); + expanded_dims[1] = in.dimension(3); + expanded_dims[2] = in.dimension(shuffle[2]); + expanded_dims[3] = in.dimension(shuffle[3]); + } else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 && + shuffle[3] == 2) { + merged_dims[0] = in.dimension(0); + merged_dims[1] = in.dimension(1) * in.dimension(2); + merged_dims[2] = in.dimension(3); + new_shuffle[0] = 0; + new_shuffle[1] = 2; + new_shuffle[2] = 1; + expanded_dims[0] = in.dimension(0); + expanded_dims[1] = in.dimension(3); + expanded_dims[2] = in.dimension(1); + expanded_dims[3] = in.dimension(2); + } else { + assert(false && "unexpected shuffle"); + } + + out.device(d) = + in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims); } }; -template +template struct PadInput { - void operator()(const Device& d, typename TTypes::ConstTensor in, + void operator()(const Device& d, + typename TTypes::ConstTensor in, int padding_rows_left, int padding_rows_right, int padding_cols_left, int padding_cols_right, - typename TTypes::Tensor out) { - Eigen::array, 4> padding; + typename TTypes::Tensor out) { + Eigen::array, 4> padding; padding[0] = std::make_pair(0, 0); padding[1] = std::make_pair(padding_rows_left, padding_rows_right); padding[2] = std::make_pair(padding_cols_left, padding_cols_right); diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index 16f4d55477c..98772e4e63d 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -783,9 +783,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel { TensorShape({out_depth, in_depth, filter_rows, filter_cols}), &transformed_filter)); - functor::TransformFilter()(context->eigen_device(), - filter.tensor(), - transformed_filter.tensor()); + functor::TransformFilter()( + context->eigen_device(), To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); Tensor transformed_out_backprop; OP_REQUIRES_OK( @@ -795,10 +795,10 @@ class Conv2DSlowBackpropInputOp : public OpKernel { TensorShape({batch, out_depth, output_rows, output_cols}), &transformed_out_backprop)); - functor::TransformDepth()( - context->eigen_device(), out_backprop.tensor(), - Eigen::DSizes(0, 3, 1, 2), - transformed_out_backprop.tensor()); + functor::TransformDepth()( + context->eigen_device(), To32Bit(out_backprop.tensor()), + Eigen::DSizes(0, 3, 1, 2), + To32Bit(transformed_out_backprop.tensor())); Tensor pre_transformed_in_backprop; OP_REQUIRES_OK(context, @@ -831,11 +831,12 @@ class Conv2DSlowBackpropInputOp : public OpKernel { } auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::TransformDepth()( + functor::TransformDepth()( context->eigen_device(), - toConstTensor(pre_transformed_in_backprop).template tensor(), - Eigen::DSizes(0, 2, 3, 1), - in_backprop->tensor()); + 
To32Bit(toConstTensor(pre_transformed_in_backprop) + .template tensor()), + Eigen::DSizes(0, 2, 3, 1), + To32Bit(in_backprop->tensor())); } else { // We fill out a padded out_backprop TensorShape padded_out_shape( @@ -852,7 +853,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel { {left_pad_cols, right_pad_cols}, {0, 0}}}; - functor::InflatePadAndShuffle()( + functor::InflatePadAndShuffle()( context->eigen_device(), out_backprop.tensor(), strides, pad_dims, trivial_order, padded_output.tensor()); const Tensor& padded_output_cref = padded_output; @@ -869,7 +870,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel { Eigen::DSizes filter_order{0, 1, 3, 2}; Eigen::array filter_rev_dims{true, true, false, false}; - functor::ShuffleAndReverse()( + functor::ShuffleAndReverse()( context->eigen_device(), filter.tensor(), filter_order, filter_rev_dims, r_filter.tensor()); const Tensor& r_filter_cref = r_filter; @@ -1033,10 +1034,10 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { TensorShape({batch, out_depth, output_rows, output_cols}), &transformed_out_backprop)); - functor::TransformDepth()( - context->eigen_device(), out_backprop.tensor(), - Eigen::DSizes(0, 3, 1, 2), - transformed_out_backprop.tensor()); + functor::TransformDepth()( + context->eigen_device(), To32Bit(out_backprop.tensor()), + Eigen::DSizes(0, 3, 1, 2), + To32Bit(transformed_out_backprop.tensor())); Tensor transformed_input; OP_REQUIRES_OK(context, @@ -1045,10 +1046,10 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { TensorShape({batch, in_depth, input_rows, input_cols}), &transformed_input)); - functor::TransformDepth()( - context->eigen_device(), input.tensor(), - Eigen::DSizes(0, 3, 1, 2), - transformed_input.tensor()); + functor::TransformDepth()( + context->eigen_device(), To32Bit(input.tensor()), + Eigen::DSizes(0, 3, 1, 2), + To32Bit(transformed_input.tensor())); auto out_backprop_ptr = AsDeviceMemory(transformed_out_backprop.template flat().data(), @@ -1074,12 +1075,12 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { } auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::TransformDepth()( + functor::TransformDepth()( context->eigen_device(), - toConstTensor(pre_transformed_filter_backprop) - .template tensor(), - Eigen::DSizes(2, 3, 1, 0), - filter_backprop->tensor()); + To32Bit(toConstTensor(pre_transformed_filter_backprop) + .template tensor()), + Eigen::DSizes(2, 3, 1, 0), + To32Bit(filter_backprop->tensor())); } else { // Fall back to the non-cudnn code path @@ -1102,7 +1103,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { {top_pad_rows, bottom_pad_rows}, {left_pad_cols, right_pad_cols}, {0, 0}}}; - functor::InflatePadAndShuffle()( + functor::InflatePadAndShuffle()( context->eigen_device(), out_backprop.tensor(), strides, pad_dims, out_order, padded_output.tensor()); const Tensor& padded_output_cref = padded_output; @@ -1121,7 +1122,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // No need for reversing this time. 
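// --- Illustrative sketch (not part of this patch) ---
// The To32Bit(...) wrappers in the hunks above re-express tensors with 32-bit
// indices so the GPU code can use cheaper index arithmetic. The helper below
// is only a sketch of that idea (ViewWith32BitIndices is a hypothetical name,
// not the To32Bit helper these kernels use): it rebuilds a TensorMap over the
// same buffer with `int` as the index type, assuming every dimension fits in
// 32 bits.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

template <typename T, int NDIMS>
Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>, Eigen::Aligned>
ViewWith32BitIndices(
    Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor>, Eigen::Aligned> in) {
  Eigen::DSizes<int, NDIMS> dims32;
  for (int i = 0; i < NDIMS; ++i) {
    dims32[i] = static_cast<int>(in.dimension(i));  // Assumed to fit in int32.
  }
  return Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
                          Eigen::Aligned>(in.data(), dims32);
}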
Eigen::array trivial_dims{false, false, false, false}; - functor::ShuffleAndReverse()( + functor::ShuffleAndReverse()( context->eigen_device(), input.tensor(), in_order, trivial_dims, in_shuffle.tensor()); const Tensor& in_shuffle_cref = in_shuffle; @@ -1149,7 +1150,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { Eigen::DSizes filter_order{1, 2, 3, 0}; Eigen::array filter_rev_dims{true, true, false, false}; const Tensor& filter_shuffle_cref = filter_shuffle; - functor::ShuffleAndReverse()( + functor::ShuffleAndReverse()( context->eigen_device(), filter_shuffle_cref.tensor(), filter_order, filter_rev_dims, filter_backprop->tensor()); } @@ -1165,46 +1166,65 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void ShuffleAndReverse::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor input, \ - const Eigen::DSizes& order, \ - const Eigen::array& reverse_dims, \ - typename TTypes::Tensor output); \ - extern template struct ShuffleAndReverse; \ - template <> \ - void InflatePadAndShuffle::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor input, \ - const Eigen::DSizes& strides, \ - const Eigen::array, 4>& pad_dims, \ - const Eigen::DSizes& order, \ - typename TTypes::Tensor output); \ - extern template struct InflatePadAndShuffle; \ - template <> \ - void TransformFilter::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor in, \ - typename TTypes::Tensor out); \ - extern template struct TransformFilter; \ - template <> \ - void TransformDepth::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor in, \ - const Eigen::DSizes& shuffle, \ - typename TTypes::Tensor out); \ - extern template struct TransformDepth; \ - template <> \ - void SpatialConvolution::operator()( \ - const GPUDevice& d, typename TTypes::Tensor output, \ - typename TTypes::ConstTensor input, \ - typename TTypes::ConstTensor filter, int stride, \ - const Eigen::PaddingType& padding); \ - extern template struct SpatialConvolution; \ - template <> \ - void SpatialConvolutionBackwardInput::operator()( \ - const GPUDevice& d, typename TTypes::Tensor in_backprop, \ - typename TTypes::ConstTensor filter, \ - typename TTypes::ConstTensor output_backprop, int input_rows, \ - int input_cols, int stride); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ShuffleAndReverse::operator()( \ + const GPUDevice& d, \ + typename TTypes::ConstTensor input, \ + const Eigen::DSizes& order, \ + const Eigen::array& reverse_dims, \ + typename TTypes::Tensor output); \ + extern template struct ShuffleAndReverse; \ + template <> \ + void InflatePadAndShuffle::operator()( \ + const GPUDevice& d, \ + typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + const Eigen::array, 4>& pad_dims, \ + const Eigen::DSizes& order, \ + typename TTypes::Tensor output); \ + extern template struct InflatePadAndShuffle; \ + template <> \ + void ShuffleAndReverse::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& order, \ + const Eigen::array& reverse_dims, \ + typename TTypes::Tensor output); \ + extern template struct ShuffleAndReverse; \ + template <> \ + void InflatePadAndShuffle::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + const Eigen::array, 4>& pad_dims, \ + const Eigen::DSizes& order, \ + typename TTypes::Tensor output); \ + extern 
template struct InflatePadAndShuffle; \ + template <> \ + void TransformFilter::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + extern template struct TransformFilter; \ + template <> \ + void TransformDepth::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const Eigen::DSizes& shuffle, \ + typename TTypes::Tensor out); \ + extern template struct TransformDepth; \ + template <> \ + void SpatialConvolution::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + typename TTypes::ConstTensor filter, int stride, \ + const Eigen::PaddingType& padding); \ + extern template struct SpatialConvolution; \ + template <> \ + void SpatialConvolutionBackwardInput::operator()( \ + const GPUDevice& d, typename TTypes::Tensor in_backprop, \ + typename TTypes::ConstTensor filter, \ + typename TTypes::ConstTensor output_backprop, int input_rows, \ + int input_cols, int stride); \ extern template struct SpatialConvolutionBackwardInput DECLARE_GPU_SPEC(float); diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 0ca5afd9437..a98bcc367d9 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -167,6 +167,10 @@ class Conv2DOp : public BinaryOp { << ", filter_rows = " << filter_rows << ", stride = " << stride << ", out_depth = " << out_depth; + // If there is nothing to compute, return. + if (out_shape.num_elements() == 0) { + return; + } LaunchConvOp::launch(context, use_cudnn_, input, filter, stride, BrainPadding2EigenPadding(padding_), output); @@ -260,10 +264,11 @@ struct LaunchConvOp { input.dim_size(2) + padding_cols, input.dim_size(3)}), &transformed_input)); - functor::PadInput()( - ctx->eigen_device(), input_param.tensor(), + functor::PadInput()( + ctx->eigen_device(), To32Bit(input_param.tensor()), padding_rows / 2, padding_rows - padding_rows / 2, padding_cols / 2, - padding_cols - padding_cols / 2, transformed_input.tensor()); + padding_cols - padding_cols / 2, + To32Bit(transformed_input.tensor())); input = transformed_input; } @@ -296,9 +301,9 @@ struct LaunchConvOp { filter.dim_size(0), filter.dim_size(1)}), &transformed_filter)); - functor::TransformFilter()( - ctx->eigen_device(), filter.tensor(), - transformed_filter.tensor()); + functor::TransformFilter()( + ctx->eigen_device(), To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); auto input_ptr = AsDeviceMemory(input.template flat().data(), input.template flat().size()); @@ -346,16 +351,16 @@ namespace functor { const Eigen::array, 1>& dim_pair); \ extern template struct MatMulConvFunctor; \ template <> \ - void TransformFilter::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor in, \ - typename TTypes::Tensor out); \ - extern template struct TransformFilter; \ + void TransformFilter::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + extern template struct TransformFilter; \ template <> \ - void PadInput::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor in, \ + void PadInput::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ int padding_rows_left, int padding_rows_right, int padding_cols_left, \ - int padding_cols_right, typename TTypes::Tensor out); \ - extern template struct PadInput + int padding_cols_right, typename TTypes::Tensor out); \ + extern template struct PadInput DECLARE_GPU_SPEC(float); #undef 
DECLARE_GPU_SPEC diff --git a/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc index e2e9d25d839..6ea20278997 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc @@ -5,12 +5,14 @@ #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/public/tensor.h" namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -template struct functor::InflatePadAndShuffle; - +template struct functor::InflatePadAndShuffle; +template struct functor::InflatePadAndShuffle; } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index dbbe08ef9c0..77d3a68f340 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -9,13 +9,18 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -template struct functor::ShuffleAndReverse; +template struct functor::ShuffleAndReverse; +template struct functor::ShuffleAndReverse; -template struct functor::TransformFilter; +template struct functor::TransformFilter; -template struct functor::PadInput; +template struct functor::PadInput; -template struct functor::TransformDepth; +template struct functor::TransformDepth; +// TODO(jiayq): currently pooling ops still use DenseIndex, so I am keeping it +// here. +template struct functor::TransformDepth; } // namespace tensorflow diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h index adc4734523e..80d485a8aa7 100644 --- a/tensorflow/core/kernels/linalg_ops_common.h +++ b/tensorflow/core/kernels/linalg_ops_common.h @@ -86,11 +86,10 @@ class LinearAlgebraOp : public LinearAlgebraOpBase { explicit LinearAlgebraOp(OpKernelConstruction* context) : LinearAlgebraOpBase(context) {} - using ConstMatrixMap = - Eigen::Map>; - using MatrixMap = Eigen::Map< - Eigen::Matrix>; + using Matrix = + Eigen::Matrix; + using ConstMatrixMap = Eigen::Map; + using MatrixMap = Eigen::Map; // Perform the actual computation on the input matrix, and store the results // in the output. This will be called repeatedly for a single call to diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc index 2534d3ce662..bc3d6c49128 100644 --- a/tensorflow/core/kernels/listdiff_op.cc +++ b/tensorflow/core/kernels/listdiff_op.cc @@ -1,3 +1,4 @@ +#include #include #include @@ -70,6 +71,7 @@ class ListDiffOp : public OpKernel { ListDiffOp) TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF); +REGISTER_LISTDIFF(string); #undef REGISTER_LISTDIFF } // namespace tensorflow diff --git a/tensorflow/core/kernels/matrix_inverse_op.cc b/tensorflow/core/kernels/matrix_inverse_op.cc index 7af4aaa3e6e..6fe763aae97 100644 --- a/tensorflow/core/kernels/matrix_inverse_op.cc +++ b/tensorflow/core/kernels/matrix_inverse_op.cc @@ -1,6 +1,7 @@ // See docs in ../ops/linalg_ops.cc. 
#include +#include "third_party/eigen3/Eigen/Cholesky" #include "third_party/eigen3/Eigen/LU" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -35,6 +36,7 @@ class MatrixInverseOp } } + using typename LinearAlgebraOp::Matrix; using typename LinearAlgebraOp::MatrixMap; using typename LinearAlgebraOp::ConstMatrixMap; @@ -44,15 +46,36 @@ class MatrixInverseOp OP_REQUIRES(context, input.rows() == input.cols(), errors::InvalidArgument("Input matrix must be square.")); if (input.rows() == 0) { - // By definition, an empty matrix's inverse is an emptry matrix. + // By definition, an empty matrix's inverse is an empty matrix. return; } - Eigen::FullPivLU> lu_decomposition(input); - OP_REQUIRES(context, lu_decomposition.isInvertible(), + if (input.isApprox(input.transpose())) { + // Matrix is symmetric, compute Cholesky factorization + // input = L * L^T. + Eigen::LLT cholesky_decomposition(input); + if (cholesky_decomposition.info() == Eigen::Success) { + // Cholesky succeeded => Matrix was SPD. + output->noalias() = cholesky_decomposition.solve( + Matrix::Identity(input.rows(), input.cols())); + return; + } + } + Eigen::PartialPivLU lu_decomposition(input); + // While PartialPivLU cannot give strong guarantees on invertability, + // we can at least guard against exact zero pivots. This can occur as + // a result of basic user mistakes, such as providing integer valued + // matrices that are exacly singular, or due to underflow if this + // code is run with denormals being flushed to zero. + // TODO(rmlarsen): Add check based on condition number estimation. + const Scalar min_abs_pivot = + lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff(); + OP_REQUIRES(context, min_abs_pivot > Scalar(0), errors::InvalidArgument("Input is not invertible.")); - *output = lu_decomposition.inverse(); + output->noalias() = lu_decomposition.inverse(); } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(MatrixInverseOp); }; REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp), float); diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index edb1b3c288d..2f4ecb48912 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -99,13 +99,13 @@ perftools::gputools::DeviceMemory AsDeviceMemory(const T* cuda_memory, // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void TransformDepth::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor in, \ - const Eigen::DSizes& shuffle, \ - typename TTypes::Tensor out); \ - extern template struct TransformDepth; +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void TransformDepth::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const Eigen::DSizes& shuffle, \ + typename TTypes::Tensor out); \ + extern template struct TransformDepth; DECLARE_GPU_SPEC(float); #undef DECLARE_GPU_SPEC @@ -172,7 +172,7 @@ void DnnPoolingGradOp::Compute( // For AvgPoolGrad, the original input tensor is not necessary. However, // cudnn still requires them to run, although they do not affect the // results. - functor::TransformDepth()( + functor::TransformDepth()( context->eigen_device(), tensor_in->tensor(), nhwc_to_nchw, transformed_input.tensor()); } @@ -180,11 +180,11 @@ void DnnPoolingGradOp::Compute( // For AvgPoolGrad, the original output tensor is not necessary. 
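// --- Illustrative distillation (not part of this patch) ---
// A standalone version of the inversion strategy MatrixInverseOp now uses:
// try a Cholesky (LLT) solve when the input is symmetric, otherwise fall back
// to PartialPivLU and reject inputs whose smallest absolute pivot is exactly
// zero. InvertSquareMatrix is a hypothetical free function; only Eigen's
// Cholesky and LU modules are assumed.
#include "third_party/eigen3/Eigen/Cholesky"
#include "third_party/eigen3/Eigen/LU"

template <typename Scalar>
bool InvertSquareMatrix(
    const Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>& input,
    Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>* output) {
  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Matrix;
  if (input.rows() != input.cols()) return false;
  if (input.rows() == 0) {
    // The inverse of an empty matrix is an empty matrix.
    output->resize(0, 0);
    return true;
  }
  if (input.isApprox(input.transpose())) {
    Eigen::LLT<Matrix> llt(input);
    if (llt.info() == Eigen::Success) {
      // Symmetric positive definite: solve L * L^T * X = I.
      output->noalias() = llt.solve(Matrix::Identity(input.rows(), input.cols()));
      return true;
    }
  }
  Eigen::PartialPivLU<Matrix> lu(input);
  // Guard against exactly singular inputs via the smallest absolute pivot.
  const Scalar min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
  if (!(min_abs_pivot > Scalar(0))) return false;
  output->noalias() = lu.inverse();
  return true;
}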
However, // cudnn still requires them to run, although they do not affect the // results. - functor::TransformDepth()( + functor::TransformDepth()( context->eigen_device(), tensor_out->tensor(), nhwc_to_nchw, transformed_output.tensor()); } - functor::TransformDepth()( + functor::TransformDepth()( context->eigen_device(), out_backprop.tensor(), nhwc_to_nchw, transformed_output_backprop.tensor()); @@ -239,7 +239,7 @@ void DnnPoolingGradOp::Compute( /// Transform the output data from NCHW back to NHWC auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; auto nchw_to_nhwc = Eigen::DSizes(0, 2, 3, 1); - functor::TransformDepth()( + functor::TransformDepth()( context->eigen_device(), toConstTensor(transformed_input_backprop).template tensor(), nchw_to_nhwc, output->tensor()); diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc index 919e129ff8c..de3e0eea763 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -3,71 +3,427 @@ #define EIGEN_USE_THREADS #include "tensorflow/core/common_runtime/device.h" +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/port.h" - +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/work_sharder.h" +#include "tensorflow/core/platform/port.h" namespace tensorflow { +namespace { + +typedef Eigen::Tensor Matrix; +typedef Eigen::DSizes DSizes; +typedef Eigen::TensorMap, + Eigen::Aligned> MatrixMap; +typedef Eigen::TensorMap, + Eigen::Aligned> ConstMatrixMap; typedef Eigen::ThreadPoolDevice CPUDevice; -template -void PrefetchBlockNTA(const T& tensor, int si, int ei, int sj, int ej) { - for (int i = si; i < ei; ++i) { - for (int j = sj; j < ej; j = j + 16) { - port::prefetch(&tensor(i, j)); - } +// Blocksizes +// TODO(agarwal): compute these sizes based on cache sizes. +static const int K = 64; +static const int M = 64; +static const int N = 128; + +// This stores a sparse representation of a slice of a matrix with size +// (num_rows, num_cols). The slice is represented as a series of blocks of size +// (num_rows, b), where b = block_size for all but the last block, which may +// have +// fewer columns. +// +// num_rows and block_size are assumed to be <= 256. This allows storing +// different indices as uint8. +// +// For each block, we store all the non zero entries in data/data3 vector and +// the corresponding coordinates of the element in index/index3 vectors. index3 +// vector stores index of 3 elements in the same row so that these elements can +// share the same row coordinate. Each entry in Index3 corresponds to 3 entries +// in data3. +// +// Note that all the data/indices of all the blocks are stored in the same +// vectors respectively. To identify block boundaries, we store the block +// offsets using index3_offset/index_offset. If there are n blocks in the slice, +// index3_offset and index_offset have n entires. The indices for the ith block +// are the values in the following range: +// [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for +// index_offset. 
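// --- Illustrative decoder (not part of this patch) ---
// To make the layout described above concrete, the sketch below scatters an
// encoded slice back into a dense row-major buffer: each Index3 entry
// contributes three consecutive values from data3 on one row, each Index
// entry contributes one value from data, and the dense column of an entry is
// block * block_size + k. DecodedSparseSlice / DecodeSlice are hypothetical
// stand-ins that only mirror the fields used here; the kernel itself never
// materializes this dense form.
#include <vector>

struct DecodedSparseSlice {
  struct Index3 { unsigned char m, k1, k2, k3; };
  struct Index  { unsigned char m, k; };
  std::vector<int> index3_offset, index_offset;  // Cumulative entry counts per block.
  std::vector<Index3> index3;
  std::vector<Index> index;
  std::vector<float> data3, data;
  int num_rows, num_cols, block_size;
};

inline std::vector<float> DecodeSlice(const DecodedSparseSlice& s) {
  std::vector<float> dense(s.num_rows * s.num_cols, 0.0f);
  int i3 = 0, i1 = 0, d3 = 0, d1 = 0;
  const int num_blocks = static_cast<int>(s.index3_offset.size());
  for (int b = 0; b < num_blocks; ++b) {
    const int col_base = b * s.block_size;
    for (; i3 < s.index3_offset[b]; ++i3) {
      const DecodedSparseSlice::Index3& idx = s.index3[i3];
      float* row = &dense[idx.m * s.num_cols];
      row[col_base + idx.k1] = s.data3[d3++];
      row[col_base + idx.k2] = s.data3[d3++];
      row[col_base + idx.k3] = s.data3[d3++];
    }
    for (; i1 < s.index_offset[b]; ++i1) {
      const DecodedSparseSlice::Index& idx = s.index[i1];
      dense[idx.m * s.num_cols + col_base + idx.k] = s.data[d1++];
    }
  }
  return dense;
}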
+struct SparseSlice { + public: + // Indices of three elements on the same row. + struct Index3 { + uint8 m; // row + // columns + uint8 k1; + uint8 k2; + uint8 k3; + }; + + // Index of one element. + struct Index { + uint8 m; + uint8 k; + }; + + SparseSlice(int nrows, int ncols, int bsize) + : num_rows(nrows), num_cols(ncols), block_size(bsize) { + DCHECK_LE(nrows, 256); + DCHECK_LE(block_size, 256); } -} -template -void PrefetchBlockT1(const T& tensor, int si, int ei, int sj, int ej) { - for (int i = si; i < ei; ++i) { - for (int j = sj; j < ej; j = j + 16) { - port::prefetch(&tensor(i, j)); - } - } -} + // Initializes the slice with data starting at mat(0, col_offset) and with + // size (num_rows, num_cols). + // If Transpose is true, implicitly transposes mat. + template + void Initialize(const ConstMatrixMap& mat, int col_offset); -struct Block { - Block(int sm, int em, int sk, int ek, int sn, int en) - : startm(sm), endm(em), startk(sk), endk(ek), startn(sn), endn(en) {} + void Clear(); - int startm; - int endm; - int startk; - int endk; - int startn; - int endn; + // See comments above. + std::vector index3_offset; + std::vector index3; + std::vector data3; + + // See comments above. Similar to "index3" except that each element in "index" + // corresponds to one element in data. + std::vector index_offset; + std::vector index; + std::vector data; + + // Number of rows and columns for the slice. + const int num_rows; + const int num_cols; + + // Block size used to initialize from a matrix. + const int block_size; }; -bool NextBlock(const int Bm, const int Bk, const int Bn, const int m_start, - const int m, const int k, const int n, const Block& b, - Block* next) { - *next = b; - if (b.endk < k) { - next->startk = b.endk; - next->endk = std::min(b.endk + Bk, k); - } else { - next->startk = 0; - next->endk = std::min(Bk, k); - if (b.endm < m) { - next->startm = b.endm; - next->endm = std::min(b.endm + Bm, m); - } else { - next->startm = m_start; - next->endm = std::min(m_start + Bm, m); - next->startn = b.endn; - next->endn = std::min(b.endn + Bn, n); +template +void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { + const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0); + const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1); + DCHECK_LE(num_rows, mat_rows); + DCHECK_LE(num_cols + col_offset, mat_cols); + + int num_blocks = (num_cols + block_size - 1) / block_size; + int mat_size = num_rows * num_cols; + + index3_offset.reserve(num_blocks); + data3.reserve(mat_size); + index3.reserve(mat_size / 3); + + index_offset.reserve(num_blocks); + data.reserve(num_blocks * num_rows * 2); + index.reserve(num_blocks * num_rows * 2); + + Index3 idx3; + Index idx; + int data3_size = 0; + for (int i = 0; i < num_blocks; ++i) { + int num_block_cols = + std::min(block_size, num_cols - block_size * (num_blocks - 1)); + for (int row = 0; row < num_rows; ++row) { + idx3.m = static_cast(row); + const float* start = + Transpose ? &mat(col_offset, row) : &mat(row, col_offset); + const float* curr = start; + const int stride = Transpose ? 
mat.dimension(1) : 1; + const float* end = start + stride * num_block_cols; + uint8 k = 0; +#define NEXT_ELEM \ + curr += stride; \ + ++k; + while (true) { + while (curr < end && (*curr == 0)) { + NEXT_ELEM; + } + if (curr >= end) break; + idx3.k1 = k; + data3.push_back(*curr); + NEXT_ELEM; + + while (curr < end && (*curr == 0)) { + NEXT_ELEM; + } + if (curr >= end) break; + idx3.k2 = k; + data3.push_back(*curr); + NEXT_ELEM; + + while (curr < end && (*curr == 0)) { + NEXT_ELEM; + } + if (curr >= end) break; + idx3.k3 = k; + data3.push_back(*curr); + NEXT_ELEM; + index3.push_back(idx3); +#undef NEXT_ELEM + } + int num_inserted_mod = data3.size() % 3; + // Move some elements to index and data if needed. + data3_size = data3.size() - num_inserted_mod; + idx.m = idx3.m; + switch (num_inserted_mod) { + case 2: + idx.k = idx3.k2; + data.push_back(data3[data3_size + 1]); + index.push_back(idx); + TF_FALLTHROUGH_INTENDED; + case 1: + idx.k = idx3.k1; + data.push_back(data3[data3_size]); + index.push_back(idx); + data3.resize(data3_size); + } + } + col_offset += block_size; + index3_offset.push_back(index3.size()); + index_offset.push_back(index.size()); + } + DCHECK_EQ(index3_offset.size(), num_blocks); + DCHECK_EQ(index_offset.size(), num_blocks); + DCHECK_EQ(3 * index3.size(), data3.size()); + DCHECK_EQ(index.size(), data.size()); +} + +void SparseSlice::Clear() { + index3_offset.clear(); + index3.clear(); + data3.clear(); + index_offset.clear(); + index.clear(); + data.clear(); +} + +#define SCALAR_MULADD(a, inp, out) *out++ += *a * *inp++; + +#define SCALAR_MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \ + *out++ += *a1 * *inp1++ + *a2 * *inp2++ + *a3 * *inp3++; + +typedef Eigen::internal::packet_traits::type Packet; +static const int kNumOperands = (sizeof(Packet) / sizeof(float)); +#define LOAD(x) Eigen::internal::pload(x); +#define STORE(x, y) Eigen::internal::pstore(x, y); +#define LOAD_SCALAR(x, y) const auto y = Eigen::internal::pload1(x); +#define FMA(a, b, c, d) d = Eigen::internal::pmadd(a, b, c); + +// Vectorized version of SCALAR_MULADD. +#define MULADD(a, inp, out) \ + do { \ + const auto b = LOAD(inp); \ + inp += kNumOperands; \ + auto c = LOAD(out); \ + FMA(a, b, c, c); \ + STORE(out, c); \ + out += kNumOperands; \ + } while (false) + +// Vectorized version of SCALAR_MULADD3WAY. +#define MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \ + do { \ + auto c = LOAD(out); \ + const auto b1 = LOAD(inp1); \ + inp1 += kNumOperands; \ + const auto b2 = LOAD(inp2); \ + inp2 += kNumOperands; \ + const auto b3 = LOAD(inp3); \ + inp3 += kNumOperands; \ + FMA(a1, b1, c, c); \ + FMA(a2, b2, c, c); \ + FMA(a3, b3, c, c); \ + STORE(out, c); \ + out += kNumOperands; \ + } while (false) + +#ifdef EIGEN_VECTORIZE_AVX2 +// Unroll MULADD3WAY for two iterations +#define MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out) \ + do { \ + auto c1 = LOAD(out); \ + const auto b1 = LOAD(inp1); \ + const auto b2 = LOAD(inp2); \ + const auto b3 = LOAD(inp3); \ + \ + auto c2 = LOAD(out + kNumOperands); \ + const auto b4 = LOAD(inp1 + kNumOperands); \ + const auto b5 = LOAD(inp2 + kNumOperands); \ + const auto b6 = LOAD(inp3 + kNumOperands); \ + \ + FMA(a1, b1, c1, c1); \ + FMA(a1, b4, c2, c2); \ + FMA(a2, b2, c1, c1); \ + FMA(a2, b5, c2, c2); \ + FMA(a3, b3, c1, c1); \ + FMA(a3, b6, c2, c2); \ + STORE(out, c1); \ + STORE(out + kNumOperands, c2); \ + out += 2 * kNumOperands; \ + inp1 += 2 * kNumOperands; \ + inp2 += 2 * kNumOperands; \ + inp3 += 2 * kNumOperands; \ + } while (false) +// Further unroll MULADD3WAY. 
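// --- Illustrative function form (not part of this patch) ---
// What one step of the MULADD macro above does, written as a function so the
// packet primitives are visible: broadcast the scalar, do an aligned load of
// one packet of input and output, fuse-multiply-add, store, and advance both
// pointers by the packet width. PacketMulAdd is a hypothetical name; pointers
// are assumed packet-aligned, as the kernel arranges for its buffers.
#include "third_party/eigen3/Eigen/Core"

inline void PacketMulAdd(const float* scalar_a, const float** inp, float** out) {
  typedef Eigen::internal::packet_traits<float>::type Packet;
  const int kWidth = sizeof(Packet) / sizeof(float);
  const Packet a = Eigen::internal::pload1<Packet>(scalar_a);  // Broadcast *scalar_a.
  const Packet b = Eigen::internal::pload<Packet>(*inp);       // Aligned load.
  Packet c = Eigen::internal::pload<Packet>(*out);
  c = Eigen::internal::pmadd(a, b, c);                         // c = a * b + c.
  Eigen::internal::pstore(*out, c);                            // Aligned store.
  *inp += kWidth;
  *out += kWidth;
}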
+#define MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out) \ + MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out); \ + MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out); +#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \ + MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \ + MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \ + MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \ + MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); +#else +#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \ + for (int __i = 0; __i < 128 / (4 * kNumOperands); ++__i) { \ + MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ + MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ + MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ + MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ + } +#endif + +// Computes product of "left_slices" with "num_cols" columns of "right", and +// stores the output in *"output". +// Note that left_slices is a list of SparseSlices, which are conceptually +// assumed to be concatenated along the column dimension. Also each SparseSlice +// is encoded as a list of blocks with upto N columns. See SparseSlice for more +// details. +template +inline void GEPP(const std::vector& left_slices, + const ConstMatrixMap& right, const int num_cols, + Matrix* output) { + const int cols = (Cols == -1) ? num_cols : Cols; + DCHECK_EQ(num_cols, cols); + const int right_num_cols = right.dimension(1); + const int output_num_cols = output->dimension(1); + const int cols_mod = cols % kNumOperands; + int k_offset = 0; + // Pre-compute pointers for output matrix. + float* out_ptrs[M]; + float* const out_start = &(*output)(0, 0); + for (int j = 0; j < M; ++j) { + out_ptrs[j] = out_start + output_num_cols * j; + } + for (const auto* left_slice : left_slices) { + const auto& left = *left_slice; + const float* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr; + const float* data = (left.data.size() > 0) ? &left.data[0] : nullptr; + const int num_blocks = left.index3_offset.size(); + int begin3 = 0; + int begin = 0; + for (int i = 0; i < num_blocks; ++i) { + // Pre-compute pointers for right matrix + const float* right_ptrs[K]; + const float* const right_start = &right(k_offset, 0); + DCHECK_LT(k_offset, right.dimension(0)); + for (int j = 0; j < K; ++j) { + right_ptrs[j] = right_start + right_num_cols * j; + } + + const int end3 = left.index3_offset[i]; + int j = begin3; + // Loop unrolled for 2 iterations. 
+ for (; j + 1 < end3; j += 2) { + const float* sl1 = data3++; + LOAD_SCALAR(sl1, l1); + const float* sl2 = data3++; + LOAD_SCALAR(sl2, l2); + const float* sl3 = data3++; + LOAD_SCALAR(sl3, l3); + const float* nsl1 = data3++; + LOAD_SCALAR(nsl1, nl1); + const float* nsl2 = data3++; + LOAD_SCALAR(nsl2, nl2); + const float* nsl3 = data3++; + LOAD_SCALAR(nsl3, nl3); + const SparseSlice::Index3& index = left.index3[j]; + const SparseSlice::Index3& nindex = left.index3[j + 1]; + float* out = out_ptrs[index.m]; + float* nout = out_ptrs[nindex.m]; + const float* r1 = right_ptrs[index.k1]; + const float* r2 = right_ptrs[index.k2]; + const float* r3 = right_ptrs[index.k3]; + const float* nr1 = right_ptrs[nindex.k1]; + const float* nr2 = right_ptrs[nindex.k2]; + const float* nr3 = right_ptrs[nindex.k3]; + if (cols == 128) { + MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out); + MULADD3WAY_128(nl1, nl2, nl3, nr1, nr2, nr3, nout); + } else { + for (int n = 0; n < cols / kNumOperands; ++n) { + MULADD3WAY(l1, l2, l3, r1, r2, r3, out); + MULADD3WAY(nl1, nl2, nl3, nr1, nr2, nr3, nout); + } + for (int k = 0; k < cols_mod; ++k) { + SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out); + SCALAR_MULADD3WAY(nsl1, nsl2, nsl3, nr1, nr2, nr3, nout); + } + } + } + if (j < end3) { + const float* sl1 = data3++; + LOAD_SCALAR(sl1, l1); + const float* sl2 = data3++; + LOAD_SCALAR(sl2, l2); + const float* sl3 = data3++; + LOAD_SCALAR(sl3, l3); + const SparseSlice::Index3& index = left.index3[j]; + float* out = out_ptrs[index.m]; + const float* r1 = right_ptrs[index.k1]; + const float* r2 = right_ptrs[index.k2]; + const float* r3 = right_ptrs[index.k3]; + if (cols == 128) { + MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out); + } else { + for (int n = 0; n < cols / kNumOperands; ++n) { + MULADD3WAY(l1, l2, l3, r1, r2, r3, out); + } + for (int k = 0; k < cols_mod; ++k) { + SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out); + } + } + } + begin3 = end3; + int end = left.index_offset[i]; + for (int j = begin; j < end; ++j) { + const float* sl = data++; + LOAD_SCALAR(sl, l); + const SparseSlice::Index& index = left.index[j]; + const float* r = right_ptrs[index.k]; + float* out = out_ptrs[index.m]; + for (int n = 0; n < cols / kNumOperands; ++n) { + MULADD(l, r, out); + } + for (int k = 0; k < cols_mod; ++k) { + SCALAR_MULADD(sl, r, out); + } + } + k_offset += left.block_size; + begin = end; } } - return next->startn == next->endn; } +#undef SCALAR_MULADD +#undef SCALAR_MULADD3WAY +#undef LOAD +#undef STORE +#undef LOAD_SCALAR +#undef FMA +#undef MULADD +#undef MULADD3WAY +#undef MULADD3WAY_16 +#undef MULADD3WAY_32 +#undef MULADD3WAY_128 + +} // namespace + class SparseMatMulOp : public OpKernel { public: explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -80,7 +436,6 @@ class SparseMatMulOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& a = ctx->input(0); const Tensor& b = ctx->input(1); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()), errors::InvalidArgument("a is not a matrix")); OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()), @@ -115,10 +470,11 @@ class SparseMatMulOp : public OpKernel { left.contract(right_mat, dim_pair); return; } - typedef Eigen::Tensor Matrix; std::unique_ptr right_tr_mat; std::unique_ptr::ConstMatrix> right_tr_map; if (transpose_b_) { + // TODO(agarwal): avoid transposing the matrix here and directly handle + // transpose in CreateDenseSlices. 
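// --- Illustrative sketch (not part of this patch) ---
// The dense fallback above expresses the matrix product as an Eigen tensor
// contraction. The sketch below shows the same idea in isolation: dim_pair
// names which axes are summed, so transposes are handled by contracting a
// different axis instead of materializing A^T or B^T. ContractMatMul is a
// hypothetical helper; only Eigen's Tensor module is assumed.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

inline Eigen::Tensor<float, 2, Eigen::RowMajor> ContractMatMul(
    const Eigen::Tensor<float, 2, Eigen::RowMajor>& a,
    const Eigen::Tensor<float, 2, Eigen::RowMajor>& b,
    bool transpose_a, bool transpose_b) {
  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
  dim_pair[0].first = transpose_a ? 0 : 1;   // Sum over a's rows if transposed.
  dim_pair[0].second = transpose_b ? 1 : 0;  // Sum over b's columns if transposed.
  return a.contract(b, dim_pair);
}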
right_tr_mat.reset(new Matrix(k, n)); Eigen::array perm({1, 0}); right_tr_mat->device(ctx->template eigen_device()) = @@ -129,56 +485,75 @@ class SparseMatMulOp : public OpKernel { TTypes::ConstMatrix& right = transpose_b_ ? *right_tr_map : right_mat; - const bool transpose_a = transpose_a_; - - typedef Eigen::TensorMap, - Eigen::Unaligned> TensorMap; - typedef Eigen::TensorMap, - Eigen::Unaligned> ConstTensorMap; - typedef Eigen::DSizes DSizes; - const int Bm = 16; - const int Bk = 16; - const int Bn = 1024; - - auto work_shard = [m, n, k, transpose_a, Bm, Bk, Bn, &left, &right, &out]( - int64 start64, int64 end64) { - const int start = static_cast(start64); - const int end = static_cast(end64); - Block curr(start, std::min(start + Bm, end), 0, std::min(Bk, k), 0, - std::min(Bn, n)); - Block next(curr); - bool done = false; - for (int i = start; i < end; ++i) { - out.chip<0>(i).setZero(); - } - while (true) { - done = NextBlock(Bm, Bk, Bn, start, end, k, n, curr, &next); - - PrefetchBlockT1(right, curr.startk, curr.endk, curr.startn, curr.endn); - - // Process current block - for (int i = curr.startm; i < curr.endm; ++i) { - PrefetchBlockNTA(left, i, i + 1, curr.startk, curr.endk); - PrefetchBlockNTA(out, i, i + 1, curr.startn, curr.endn); - DSizes out_slice_shape(curr.endn - curr.startn); - TensorMap out_i(&out(i, curr.startn), out_slice_shape); - for (int j = curr.startk; j < curr.endk; ++j) { - const float l = transpose_a ? left(j, i) : left(i, j); - if (l == 0) continue; - ConstTensorMap right_j(&right(j, curr.startn), out_slice_shape); - out_i += right_j * l; - } - } - if (done) break; - curr = next; - } - }; - auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); - Shard(worker_threads.num_threads, worker_threads.workers, m, 2 * k * n, - work_shard); + SparseMatMul(left, right, transpose_a_, + ctx->device()->tensorflow_cpu_worker_threads(), &out); } private: + // Perform matrix multiplication of "left" and "right", and store the result + // in *"ouptut". + static inline void SparseMatMul( + const ConstMatrixMap& left, const ConstMatrixMap& right, + bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, + MatrixMap* output); + + // Computes multiplication of left and num_cols columns of right, and stores + // the output block in *"output" at offsets "output_row_offset" and + // "output_col_offset". If assign is true, assigns the value to that block, + // else adds the values to the existing values. + static inline void ComputeOutputBlock(const std::vector& left, + const ConstMatrixMap& right, + const int num_cols, + int output_row_offset, + int output_col_offset, bool assign, + MatrixMap* output); + + // Encodes "mat" using a sparse representation and stores that in + // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and + // "slice_num_cols", each grid element is converted into a SparseSlice and + // stored in mat_slices. "slice_block_size" is used to perform futher column + // blocking of each slice. + static inline BlockingCounter* CreateSparseSlices( + const ConstMatrixMap& mat, bool transpose, int slice_num_rows, + int slice_block_size, int slice_num_cols, + std::vector>* mat_slices, + const DeviceBase::CpuWorkerThreads* thread_pool); + + // This function chops "mat" along column dimension into pieces with at most N + // columns, and concatenates the pieces one after the other in "buffer". It + // returns the list of the pieces in "slices". 
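// --- Illustrative sketch (not part of this patch) ---
// The slicing helpers declared here all follow the same scheduling pattern:
// fan the per-slice work out onto the CpuWorkerThreads pool and return a
// BlockingCounter that the caller waits on before using the results. A
// stripped-down version of that pattern (ScheduleShards is a hypothetical
// name) could look like this:
#include <functional>

#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"

inline tensorflow::BlockingCounter* ScheduleShards(
    tensorflow::thread::ThreadPool* workers, int num_shards,
    std::function<void(int)> work) {
  tensorflow::BlockingCounter* counter = new tensorflow::BlockingCounter(num_shards);
  for (int shard = 0; shard < num_shards; ++shard) {
    workers->Schedule([counter, shard, work]() {
      work(shard);
      counter->DecrementCount();
    });
  }
  // Caller: counter->Wait(); then delete counter (or hold it in a unique_ptr).
  return counter;
}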
It returns a BlockingCounter + // which should be used to wait for the shuffle operations to complete. + static inline BlockingCounter* CreateDenseSlices( + const ConstMatrixMap& mat, int row_start, int num_rows, int col_start, + int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool, + Matrix* buffer, std::vector* slices); + + // Helper function for CreateDenseSlices to move the data around. It returns a + // BlockingCounter which should be used to wait for the shuffle operations to + // complete. + static inline BlockingCounter* ShuffleMatrix( + const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows, + int slice_col_start, int slice_num_cols, const int N, + const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer); + + // Helper function for CreateDenseSlices to create slices. + static inline void SliceMatrix(const Matrix& mat, const int num_rows, + const int num_slices, + std::vector* slices); + + // Heuristics to compute various block sizes. + // KR, NR: block sizes for "right". We run blocking iterations that operate on + // matrices with at most this size. + // KL: grid size along the column dimension used while encoding left. + // IB, JB: number of left and right slices to multiply together. This is used + // for ordering different ComputeBlockOutput operations inside each blocking + // iteration so as to potentially reduce the working set size. + static inline void ComputeBlockSizes(const ConstMatrixMap& left, + const ConstMatrixMap& right, + bool transpose_left, int num_threads, + int* KR, int* NR, int* KL, int* JB, + int* IB); + bool transpose_a_; bool transpose_b_; bool a_is_sparse_; @@ -186,6 +561,329 @@ class SparseMatMulOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp); }; +inline void SparseMatMulOp::ComputeOutputBlock( + const std::vector& left, const ConstMatrixMap& right, + const int num_cols, int output_row_offset, int output_col_offset, + bool assign, MatrixMap* output) { + const int num_rows = left[0]->num_rows; + const int rhs_num_cols = right.dimension(1); + DCHECK_LE(num_cols, rhs_num_cols); + Matrix out(num_rows, rhs_num_cols); + out.setZero(); + if (num_cols == N) { + GEPP(left, right, num_cols, &out); + } else { + GEPP<-1>(left, right, num_cols, &out); + } + if (!assign) { + const Eigen::array begin = {output_row_offset, output_col_offset}; + const Eigen::array sizes = {num_rows, num_cols}; + if (num_cols == rhs_num_cols) { + output->slice(begin, sizes) += out; + } else { + static const Eigen::array zero = {0, 0}; + output->slice(begin, sizes) += out.slice(zero, sizes); + } + } else { + // output->slice(begin, sizes) = out.slice(zero, sizes), implemented + // using memcpy. + for (int i = 0; i < num_rows; ++i) { + memcpy(&(*output)(output_row_offset + i, output_col_offset), &out(i, 0), + num_cols * sizeof(float)); + } + } +} + +inline BlockingCounter* SparseMatMulOp::CreateSparseSlices( + const ConstMatrixMap& mat, bool transpose, int slice_num_rows, + int slice_block_size, int slice_num_cols, + std::vector>* mat_slices, + const DeviceBase::CpuWorkerThreads* thread_pool) { + const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0); + const int mat_num_cols = transpose ? 
mat.dimension(0) : mat.dimension(1); + const int num_slices_dim0 = + std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows); + const int num_slices_dim1 = + std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols); + mat_slices->resize(num_slices_dim0); + BlockingCounter* counter = + new BlockingCounter(num_slices_dim0 * num_slices_dim1); + auto work = [counter, transpose](SparseSlice* sparse_slice, + ConstMatrixMap* slice, int col_offset) { + if (transpose) { + sparse_slice->Initialize(*slice, col_offset); + } else { + sparse_slice->Initialize(*slice, col_offset); + } + delete slice; + counter->DecrementCount(); + }; + for (int i = 0; i < num_slices_dim0; ++i) { + (*mat_slices)[i].resize(num_slices_dim1); + int num_rows = + std::min(slice_num_rows, mat_num_rows - i * slice_num_rows); + for (int j = 0; j < num_slices_dim1; ++j) { + int num_cols = + std::min(slice_num_cols, mat_num_cols - j * slice_num_cols); + ConstMatrixMap* slice = nullptr; + if (transpose) { + slice = + new ConstMatrixMap(&mat(0, i * slice_num_rows), mat.dimensions()); + } else { + DSizes d(num_rows, mat_num_cols); + slice = new ConstMatrixMap(&mat(i * slice_num_rows, 0), d); + } + SparseSlice* sparse_slice = + new SparseSlice(num_rows, num_cols, slice_block_size); + (*mat_slices)[i][j] = sparse_slice; + thread_pool->workers->Schedule( + std::bind(work, sparse_slice, slice, slice_num_cols * j)); + } + } + return counter; +} + +inline BlockingCounter* SparseMatMulOp::ShuffleMatrix( + const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows, + int slice_col_start, int slice_num_cols, const int N, + const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer) { + int num_threads = std::min(thread_pool->num_threads, 16); + BlockingCounter* counter = new BlockingCounter(num_threads); + DCHECK_EQ(N, buffer->dimension(1)); + auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start, + slice_num_cols, N, buffer, counter](int s, int e) { + const int row_start = s % slice_num_rows + slice_row_start; + const int col_start = s / slice_num_rows * N + slice_col_start; + float* out_start = &(*buffer)(s, 0); + const float* input_start = &mat(row_start, col_start); + const float* input_end = &mat(slice_row_start + slice_num_rows - 1, + slice_col_start + slice_num_cols - 1); + const int mat_num_cols = mat.dimension(1); + const int row_slice_size = slice_num_rows * mat_num_cols; + + const int aligned_end = slice_num_cols / N * slice_num_rows; + const int e1 = std::min(e, aligned_end); + while (s < e1) { + memcpy(out_start, input_start, N * sizeof(float)); + out_start += N; + input_start += mat_num_cols; + if (input_start > input_end) { + input_start = input_start - row_slice_size + N; + } + ++s; + } + int s1 = std::max(s, aligned_end); + const int copy_num_cols = slice_num_cols % N; + while (s1 < e) { + memcpy(out_start, input_start, copy_num_cols * sizeof(float)); + out_start += N; + input_start += mat_num_cols; + ++s1; + } + if (counter) counter->DecrementCount(); + }; + + int start = 0; + int end = 0; + int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows; + DCHECK_LE(num_out_rows, buffer->dimension(0)); + for (int i = std::max(1, num_threads); i > 0; --i) { + end = start + num_out_rows / i; + thread_pool->workers->Schedule(std::bind(shuffle_work, start, end)); + num_out_rows -= (end - start); + start = end; + } + return counter; +} + +inline void SparseMatMulOp::SliceMatrix(const Matrix& mat, const int num_rows, + const int num_slices, + std::vector* slices) { + 
slices->resize(num_slices); + DSizes d(num_rows, mat.dimension(1)); + DCHECK_LE(num_rows * num_slices, mat.dimension(0)); + for (int i = 0; i < num_slices; ++i) { + (*slices)[i] = new ConstMatrixMap(&mat(i * num_rows, 0), d); + } +} + +inline BlockingCounter* SparseMatMulOp::CreateDenseSlices( + const ConstMatrixMap& mat, int row_start, int num_rows, int col_start, + int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool, + Matrix* buffer, std::vector* slices) { + BlockingCounter* shuffle_counter = ShuffleMatrix( + mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer); + const int num_slices = (num_cols + N - 1) / N; + SliceMatrix(*buffer, num_rows, num_slices, slices); + return shuffle_counter; +} + +inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left, + const ConstMatrixMap& right, + bool transpose_left, + int num_threads, int* KR, int* NR, + int* KL, int* JB, int* IB) { + // Heuristics for calculating block sizes + // Assume two hyperthreads per core. + const int est_num_cores = std::max(1, (num_threads + 1) / 2); + // Use block of rhs with at most 128K floats per core. + const int mem = est_num_cores * 128 * 1024; + *KR = std::min(static_cast(right.dimension(0)), mem / 256); + *NR = right.dimension(1); + if (*KR * *NR > mem) { + // 4096 may be enough to ammortize the cost of writes. + *KR = std::min(*KR, 4096); + } + // Use sizes that are multiples of K and 256. + *KR = std::max(1, *KR / K) * K; + *NR = std::max(1, *NR / 256) * 256; + if (*KR * *NR > mem) { + *NR = mem / *KR; + } + *NR = std::max(1, *NR / 256) * 256; + + const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0); + const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1); + for (*KL = 1024; *KL > K; *KL /= 2) { + if (*KR % *KL == 0 && + std::max(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) { + break; + } + } + DCHECK_EQ(*KL % K, 0); + DCHECK_GE(*KR, *KL); + if (*KR < right.dimension(0)) { + CHECK_EQ(*KR % *KL, 0); + } + + *JB = std::max(1, static_cast(sqrt(num_threads) / 2.0)); + *IB = 8 * *JB; + DCHECK_EQ(N * sizeof(float) % 64, 0); +} + +// Here is a an overview of the SparseMatMul code. Note that we assume that the +// left matrix is sparse. +// +// The matrix "left" is divided into a grid with blocksize of (M, KL). Each +// block is encoded as a SparseSlice. These grid elements are stored as +// std::vector>. Each element of the outer vector +// represents M rows of the left matrix. Lets call these elements l_i and lets +// call each element of the inner vector L_mk. +// +// The matrix "right" is divided into a grid with block size KR * NR. Lets +// denote the blocks on the right as R_kn. Note that we ensure that KL divides +// KR so that for each element R_kn, we don't need to multiply it with any +// partial L_mk blocks. +// +// We then multiply each right side block R_kn with the full "left" matrix and +// update the output. These iterations are run sequentially since R_kn are +// packed into the same underlying temporary buffer. +// +// In each iteration we do the following: +// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N +// (=128) columns and then concatenating these slices into a buffer. This is +// done so that each slice r_j of R_kn is stored contiguously in memory. Note +// that if R_kj has dimensions (KR, NR), we create NR / N slices, and the +// buffer has dimensions (KR * NR / N, N) (assuming N divides NR). +// 2. 
For each (l_i, r_j), we compute the inner product using the GEPP function +// and update the output block o_ij. These calls are further blocked to +// reduce the working set size. In each iteration we take IB elements from +// {l_i} and JB elements from {r_j} and compute the IB * JB inner products. +inline void SparseMatMulOp::SparseMatMul( + const ConstMatrixMap& left, const ConstMatrixMap& right, + bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, + MatrixMap* output) { + const int num_threads = thread_pool->num_threads; + int KR, NR, KL, JB, IB; + ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL, + &JB, &IB); + + // Slice the left matrix + std::vector> left_slices; + std::unique_ptr sparse_slice_counter; + sparse_slice_counter.reset( + CreateSparseSlices(ConstMatrixMap(left.data(), left.dimensions()), + transpose_left, M, K, KL, &left_slices, thread_pool)); + const int num_left_slices = left_slices.size(); + + const int right_dim0 = right.dimension(0); + const int right_dim1 = right.dimension(1); + // Allocate buffer for storing slices of right matrix. + // Note buffer needs enough space to hold atmost a KR * NR matrix since that + // is the block size per iteration. + const int buffer_num_rows = + std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N; + Matrix buffer(buffer_num_rows, N); + std::vector right_slices; + + std::vector block_left_slices; + std::vector> tasks; + // Number of blocks based on block sizes of KR * NR. + const int num_k_blocks = (right_dim0 + KR - 1) / KR; + const int num_n_blocks = (right_dim1 + NR - 1) / NR; + std::unique_ptr dense_slice_counter; + + for (int nb = 0; nb < num_n_blocks; ++nb) { + const int right_num_cols = + std::min(NR, static_cast(right_dim1 - NR * nb)); + for (int kb = 0; kb < num_k_blocks; ++kb) { + const int right_num_rows = + std::min(KR, static_cast(right_dim0 - KR * kb)); + dense_slice_counter.reset(CreateDenseSlices( + right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool, + &buffer, &right_slices)); + const int num_right_slices = right_slices.size(); + tasks.reserve(num_left_slices * num_right_slices); + for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) { + for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) { + for (int j_inner = j_outer; + j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) { + const int num_cols = std::min(N, right_num_cols - N * j_inner); + for (int i_inner = i_outer; + i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) { + // Figure out which left slices to use. 
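+ // The current block of "right" covers its rows [kb * KR, kb * KR +
+ // right_num_rows), which correspond to the left-matrix column blocks
+ // [kb * KR / KL, (kb + 1) * KR / KL). Since KL divides KR (see
+ // ComputeBlockSizes), no partial L_mk blocks are needed here.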
+ block_left_slices.clear(); + int begin = kb * KR / KL; + int end = std::min((kb + 1) * KR / KL, + (right.dimension(0) + KL - 1) / KL); + DCHECK_LT(begin, end); + block_left_slices.insert(block_left_slices.begin(), + left_slices[i_inner].begin() + begin, + left_slices[i_inner].begin() + end); + tasks.push_back(std::bind( + &SparseMatMulOp::ComputeOutputBlock, block_left_slices, + std::ref(*right_slices[j_inner]), num_cols, M * i_inner, + N * j_inner + nb * NR, kb == 0, output)); + } + } + } + } + if (sparse_slice_counter) { + sparse_slice_counter->Wait(); + sparse_slice_counter.reset(nullptr); + } + if (dense_slice_counter) { + dense_slice_counter->Wait(); + dense_slice_counter.reset(nullptr); + } + BlockingCounter bc(tasks.size()); + for (const auto& t : tasks) { + thread_pool->workers->Schedule([&bc, &t]() { + t(); + bc.DecrementCount(); + }); + } + bc.Wait(); + tasks.clear(); + gtl::STLDeleteElements(&right_slices); + right_slices.clear(); + } + } + for (auto& left_slice : left_slices) { + gtl::STLDeleteElements(&left_slice); + } +} + REGISTER_KERNEL_BUILDER(Name("SparseMatMul").Device(DEVICE_CPU), SparseMatMulOp); diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc index 883d0d12248..061a5be4ce7 100644 --- a/tensorflow/core/kernels/sparse_matmul_op_test.cc +++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc @@ -20,6 +20,8 @@ void Sparsify(Tensor* t, float sparsity) { for (int64 i = 0; i < N; ++i) { if (rnd.Uniform(K) < sparsity * K) { flat(i) = 0; + } else if (flat(i) == 0) { + flat(i) = 0.1; } } } @@ -86,19 +88,29 @@ static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_a, return g; } -#define BM_SPARSE(M, K, N, S) \ - static void BM_Sparse##_##M##_##K##_##N##_##S(int iters) { \ - testing::ItemsProcessed(static_cast(iters) * M * K * N * 2); \ - std::string label = strings::Printf("%d_%d_%d_%0.2f", M, K, N, S / 100.0); \ - testing::SetLabel(label); \ - test::Benchmark("cpu", SparseMatMul(M, N, K, S / 100.0, false, false)) \ - .Run(iters); \ - } \ +#define BM_SPARSE(M, K, N, S) \ + static void BM_Sparse##_##M##_##K##_##N##_##S(int iters) { \ + testing::StopTiming(); \ + testing::ItemsProcessed(static_cast(iters) * M * K * N * 2); \ + std::string label; \ + if (S == 0) { \ + label = strings::Printf("%d*%d*%d_Eigen", M, K, N); \ + } else { \ + label = strings::Printf("%d*%d*%d_sparsity:%0.2f", M, K, N, S / 100.0); \ + } \ + testing::SetLabel(label); \ + testing::UseRealTime(); \ + auto g = SparseMatMul(M, N, K, S / 100.0, false, false); \ + testing::StartTiming(); \ + test::Benchmark("cpu", g).Run(iters); \ + } \ BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S); BM_SPARSE(2048, 2048, 2048, 0); BM_SPARSE(2048, 2048, 2048, 1); +BM_SPARSE(2048, 2048, 2048, 50); BM_SPARSE(2048, 2048, 2048, 85); +BM_SPARSE(2048, 2048, 2048, 99); BM_SPARSE(1024, 1024, 1024, 0); BM_SPARSE(1024, 1024, 1024, 1); @@ -107,28 +119,34 @@ BM_SPARSE(1024, 1024, 1024, 85); BM_SPARSE(256, 256, 256, 1); BM_SPARSE(512, 512, 512, 1); -#define BM_SPARSE_MULTI(M, K, N, S1, S2) \ - static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2(int iters) { \ - testing::ItemsProcessed(static_cast(iters) * M * K * N * 2 * 3); \ - std::string label = strings::Printf("%d_%d_%d_%0.2f_%0.2f", M, K, N, \ - S1 / 100.0, S2 / 100.0); \ - testing::SetLabel(label); \ - test::Benchmark("cpu", MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0)) \ - .Run(iters); \ - } \ +#define BM_SPARSE_MULTI(M, K, N, S1, S2) \ + static void 
BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2(int iters) { \ + testing::StopTiming(); \ + testing::ItemsProcessed(static_cast(iters) * M * K * N * 2 * 3); \ + std::string label = strings::Printf("%d_%d_%d_%0.2f_%0.2f", M, K, N, \ + S1 / 100.0, S2 / 100.0); \ + testing::SetLabel(label); \ + testing::UseRealTime(); \ + auto g = MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0); \ + testing::StartTiming(); \ + test::Benchmark("cpu", g).Run(iters); \ + } \ BENCHMARK(BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2); -BM_SPARSE_MULTI(512, 2140, 4096, 0, 82); -BM_SPARSE_MULTI(512, 4096, 2048, 83, 83); +BM_SPARSE_MULTI(1024, 2140, 4096, 0, 82); +BM_SPARSE_MULTI(1024, 4096, 2048, 83, 83); #define BM_SPARSE_TR(M, K, N, S, TA, TB) \ static void BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB(int iters) { \ + testing::StopTiming(); \ testing::ItemsProcessed(static_cast(iters) * M * K * N * 2); \ std::string label = \ strings::Printf("%d_%d_%d_%d_%d_%0.2f", M, K, N, TA, TB, S / 100.0); \ testing::SetLabel(label); \ - test::Benchmark("cpu", SparseMatMul(M, N, K, S / 100.0, TA, TB)) \ - .Run(iters); \ + testing::UseRealTime(); \ + auto g = SparseMatMul(M, N, K, S / 100.0, TA, TB); \ + testing::StartTiming(); \ + test::Benchmark("cpu", g).Run(iters); \ } \ BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB); diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.cc b/tensorflow/core/kernels/string_to_hash_bucket_op.cc index 4df6bdd510b..457342899c1 100644 --- a/tensorflow/core/kernels/string_to_hash_bucket_op.cc +++ b/tensorflow/core/kernels/string_to_hash_bucket_op.cc @@ -25,7 +25,7 @@ class StringToHashBucketOp : public OpKernel { &output_tensor)); auto output_flat = output_tensor->flat(); - for (std::size_t i = 0; i < input_flat.size(); ++i) { + for (size_t i = 0; i < input_flat.size(); ++i) { const uint64 input_hash = Hash64(input_flat(i)); const uint64 bucket_id = input_hash % num_buckets_; // The number of buckets is always in the positive range of int64 so is diff --git a/tensorflow/core/lib/core/command_line_flags.cc b/tensorflow/core/lib/core/command_line_flags.cc index 0f1072ffaa1..9531134bb75 100644 --- a/tensorflow/core/lib/core/command_line_flags.cc +++ b/tensorflow/core/lib/core/command_line_flags.cc @@ -17,6 +17,12 @@ bool StringToValue(const string& content, int* value) { return str_util::NumericParse32(content, value); } +template <> +bool StringToValue(const string& content, string* value) { + *value = content; + return true; +} + // Parse a single argument by linearly searching through the command table. // The input format is: --argument=value. // Return OK if the argument is used. It store the extracted value into the @@ -27,7 +33,7 @@ bool StringToValue(const string& content, int* value) { template Status ParseArgument(const string& argument) { for (auto& command : - internal::CommandLineFlagRegistry::Instance()->commands) { + internal::CommandLineFlagRegistry::Instance()->commands) { string prefix = strings::StrCat("--", command.name, "="); if (tensorflow::StringPiece(argument).starts_with(prefix)) { string content = argument.substr(prefix.length()); @@ -62,6 +68,7 @@ Status ParseArgument(const string& argument) { return Status(error::NOT_FOUND, strings::StrCat("Unknown command: ", argument)); } + } // namespace Status ParseCommandLineFlags(int* argc, char* argv[]) { @@ -81,6 +88,11 @@ Status ParseCommandLineFlags(int* argc, char* argv[]) { if (s.ok()) { continue; } + // Search string commands. 
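+    // (Handled by the StringToValue specialization for string added above,
+    // which takes the value verbatim; see also TF_DEFINE_string in
+    // command_line_flags.h.)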
+ s = ParseArgument(argv[index]); + if (s.ok()) { + continue; + } if (s.code() != error::NOT_FOUND) { return s; } diff --git a/tensorflow/core/lib/core/command_line_flags.h b/tensorflow/core/lib/core/command_line_flags.h index f1a94c11f9c..8c7c00cd713 100644 --- a/tensorflow/core/lib/core/command_line_flags.h +++ b/tensorflow/core/lib/core/command_line_flags.h @@ -43,11 +43,14 @@ struct CommandLineFlagRegister { } // namespace internal #define TF_DEFINE_int32(name, default_value, text) \ - TF_DEFINE_variable(int32, name, default_value, text); + TF_DEFINE_variable(tensorflow::int32, name, default_value, text); #define TF_DEFINE_bool(name, default_value, text) \ TF_DEFINE_variable(bool, name, default_value, text); +#define TF_DEFINE_string(name, default_value, text) \ + TF_DEFINE_variable(string, name, default_value, text); + // Parse argv[1]..argv[*argc-1] to options. Remove used arguments from the argv. // Returned the number of unused arguments in *argc. // Return error Status if the parsing encounters errors. diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 321d9c02768..5356301d707 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -859,10 +859,10 @@ REGISTER_OP("ListDiff") .Output("idx: int32") .Attr("T: type") .Doc(R"doc( -Computes the difference between two lists of numbers. +Computes the difference between two lists of numbers or strings. Given a list `x` and a list `y`, this operation returns a list `out` that -represents all numbers that are in `x` but not in `y`. The returned list `out` +represents all values that are in `x` but not in `y`. The returned list `out` is sorted in the same order that the numbers appear in `x` (duplicates are preserved). This operation also returns a list `idx` that represents the position of each `out` element in `x`. In other words: diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index a9b940295e7..67bf28bceee 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -35,7 +35,14 @@ REGISTER_OP("MatrixInverse") .Output("output: T") .Attr("T: {float, double}") .Doc(R"doc( -Calculates the inverse of a square invertible matrix. Checks for invertibility. +Calculates the inverse of a square invertible matrix. + +The op uses the Cholesky decomposition if the matrix is symmetric positive +definite and LU decomposition with partial pivoting otherwise. + +If the matrix is not invertible there is no guarantee what the op does. It +may detect the condition and raise an exception or it may simply return a +garbage result. input: Shape is `[M, M]`. output: Shape is `[M, M]` containing the matrix inverse of the input. @@ -47,12 +54,19 @@ REGISTER_OP("BatchMatrixInverse") .Output("output: T") .Attr("T: {float, double}") .Doc(R"doc( -Calculates the inverse of square invertible matrices. Checks for invertibility. +Calculates the inverse of square invertible matrices. The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form square matrices. The output is a tensor of the same shape as the input containing the inverse for all input submatrices `[..., :, :]`. +The op uses the Cholesky decomposition if the matrices are symmetric positive +definite and LU decomposition with partial pivoting otherwise. + +If a matrix is not invertible there is no guarantee what the op does. It +may detect the condition and raise an exception or it may simply return a +garbage result. + input: Shape is `[..., M, M]`. 
output: Shape is `[..., M, M]`. T: The type of values in the input and output. diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h index fe1846319e9..1285c729b2c 100644 --- a/tensorflow/core/public/tensor_c_api.h +++ b/tensorflow/core/public/tensor_c_api.h @@ -186,8 +186,8 @@ extern void TF_SetTarget(TF_SessionOptions* options, const char* target); // config should be a serialized brain.ConfigProto proto. // If config was not parsed successfully as a ConfigProto, record the // error information in *status. -extern void TF_SetConfig(TF_SessionOptions* options, const char* config, - size_t config_len, TF_Status* status); +extern void TF_SetConfig(TF_SessionOptions* options, const void* proto, + size_t proto_len, TF_Status* status); // Destroy an options object. extern void TF_DeleteSessionOptions(TF_SessionOptions*); diff --git a/tensorflow/examples/android/jni/jni_utils.cc b/tensorflow/examples/android/jni/jni_utils.cc index 3fffc19cb62..76ca6d9da94 100644 --- a/tensorflow/examples/android/jni/jni_utils.cc +++ b/tensorflow/examples/android/jni/jni_utils.cc @@ -10,11 +10,11 @@ #include #include +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/message_lite.h" #include "tensorflow/core/platform/logging.h" -#include "google/protobuf/src/google/protobuf/io/zero_copy_stream_impl.h" -#include "google/protobuf/src/google/protobuf/io/zero_copy_stream_impl_lite.h" -#include "google/protobuf/src/google/protobuf/io/coded_stream.h" -#include "google/protobuf/src/google/protobuf/message_lite.h" static const char* const ASSET_PREFIX = "file:///android_asset/"; diff --git a/tensorflow/examples/label_image/BUILD b/tensorflow/examples/label_image/BUILD new file mode 100644 index 00000000000..a2704deb3a0 --- /dev/null +++ b/tensorflow/examples/label_image/BUILD @@ -0,0 +1,30 @@ +# Description: +# Tensorflow C++ inference example for labeling images. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_binary( + name = "label_image", + srcs = ["main.cc"], + linkopts = ["-lm"], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/core:tensorflow", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "bin/**", + "gen/**", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/examples/label_image/README.md b/tensorflow/examples/label_image/README.md new file mode 100644 index 00000000000..fbad61a22fa --- /dev/null +++ b/tensorflow/examples/label_image/README.md @@ -0,0 +1,49 @@ +# Tensorflow C++ Image Recognition Demo + +This example shows how you can load a pre-trained TensorFlow network and use it +to recognize objects in images. + +## Description + +This demo uses a Google Inception model to classify image files that are passed +in on the command line. See +[`googlenet_labels.txt`](data/googlenet_labels.txt) +for the possible classifications, which are the 1,000 categories used in the +Imagenet competition. + +## To build/install/run + +As long as you've managed to build the main TensorFlow framework, you should +have everything you need to run this example installed already. + +To build it, run this command: + +```bash +$ bazel build tensorflow/examples/label_image/... 
+``` + +That should build a binary executable that you can then run like this: + +```bash +$ bazel-bin/tensorflow/examples/label_image/label_image +``` + +This uses the default example image that ships with the framework, and should +output something similar to this: + +``` +I tensorflow/examples/label_image/main.cc:200] military uniform (866): 0.902268 +I tensorflow/examples/label_image/main.cc:200] bow tie (817): 0.05407 +I tensorflow/examples/label_image/main.cc:200] suit (794): 0.0113195 +I tensorflow/examples/label_image/main.cc:200] bulletproof vest (833): 0.0100269 +I tensorflow/examples/label_image/main.cc:200] bearskin (849): 0.00649746 +``` +In this case, we're using the default image of Admiral Grace Hopper, and you can +see the network correctly spots she's wearing a military uniform, with a high +score of 0.9. + +Next, try it out on your own images by supplying the --image= argument, e.g. + +```bash +$ bazel-bin/tensorflow/examples/label_image/label_image --image=my_image.png +``` diff --git a/tensorflow/examples/label_image/data/googlenet_labels.txt b/tensorflow/examples/label_image/data/googlenet_labels.txt new file mode 100644 index 00000000000..0ac5a169d92 --- /dev/null +++ b/tensorflow/examples/label_image/data/googlenet_labels.txt @@ -0,0 +1,1001 @@ +dummy +kit fox +English setter +Siberian husky +Australian terrier +English springer +grey whale +lesser panda +Egyptian cat +ibex +Persian cat +cougar +gazelle +porcupine +sea lion +malamute +badger +Great Dane +Walker hound +Welsh springer spaniel +whippet +Scottish deerhound +killer whale +mink +African elephant +Weimaraner +soft-coated wheaten terrier +Dandie Dinmont +red wolf +Old English sheepdog +jaguar +otterhound +bloodhound +Airedale +hyena +meerkat +giant schnauzer +titi +three-toed sloth +sorrel +black-footed ferret +dalmatian +black-and-tan coonhound +papillon +skunk +Staffordshire bullterrier +Mexican hairless +Bouvier des Flandres +weasel +miniature poodle +Cardigan +malinois +bighorn +fox squirrel +colobus +tiger cat +Lhasa +impala +coyote +Yorkshire terrier +Newfoundland +brown bear +red fox +Norwegian elkhound +Rottweiler +hartebeest +Saluki +grey fox +schipperke +Pekinese +Brabancon griffon +West Highland white terrier +Sealyham terrier +guenon +mongoose +indri +tiger +Irish wolfhound +wild boar +EntleBucher +zebra +ram +French bulldog +orangutan +basenji +leopard +Bernese mountain dog +Maltese dog +Norfolk terrier +toy terrier +vizsla +cairn +squirrel monkey +groenendael +clumber +Siamese cat +chimpanzee +komondor +Afghan hound +Japanese spaniel +proboscis monkey +guinea pig +white wolf +ice bear +gorilla +borzoi +toy poodle +Kerry blue terrier +ox +Scotch terrier +Tibetan mastiff +spider monkey +Doberman +Boston bull +Greater Swiss Mountain dog +Appenzeller +Shih-Tzu +Irish water spaniel +Pomeranian +Bedlington terrier +warthog +Arabian camel +siamang +miniature schnauzer +collie +golden retriever +Irish terrier +affenpinscher +Border collie +hare +boxer +silky terrier +beagle +Leonberg +German short-haired pointer +patas +dhole +baboon +macaque +Chesapeake Bay retriever +bull mastiff +kuvasz +capuchin +pug +curly-coated retriever +Norwich terrier +flat-coated retriever +hog +keeshond +Eskimo dog +Brittany spaniel +standard poodle +Lakeland terrier +snow leopard +Gordon setter +dingo +standard schnauzer +hamster +Tibetan terrier +Arctic fox +wire-haired fox terrier +basset +water buffalo +American black bear +Angora +bison +howler monkey +hippopotamus +chow +giant panda +American Staffordshire terrier 
+Shetland sheepdog +Great Pyrenees +Chihuahua +tabby +marmoset +Labrador retriever +Saint Bernard +armadillo +Samoyed +bluetick +redbone +polecat +marmot +kelpie +gibbon +llama +miniature pinscher +wood rabbit +Italian greyhound +lion +cocker spaniel +Irish setter +dugong +Indian elephant +beaver +Sussex spaniel +Pembroke +Blenheim spaniel +Madagascar cat +Rhodesian ridgeback +lynx +African hunting dog +langur +Ibizan hound +timber wolf +cheetah +English foxhound +briard +sloth bear +Border terrier +German shepherd +otter +koala +tusker +echidna +wallaby +platypus +wombat +revolver +umbrella +schooner +soccer ball +accordion +ant +starfish +chambered nautilus +grand piano +laptop +strawberry +airliner +warplane +airship +balloon +space shuttle +fireboat +gondola +speedboat +lifeboat +canoe +yawl +catamaran +trimaran +container ship +liner +pirate +aircraft carrier +submarine +wreck +half track +tank +missile +bobsled +dogsled +bicycle-built-for-two +mountain bike +freight car +passenger car +barrow +shopping cart +motor scooter +forklift +electric locomotive +steam locomotive +amphibian +ambulance +beach wagon +cab +convertible +jeep +limousine +minivan +Model T +racer +sports car +go-kart +golfcart +moped +snowplow +fire engine +garbage truck +pickup +tow truck +trailer truck +moving van +police van +recreational vehicle +streetcar +snowmobile +tractor +mobile home +tricycle +unicycle +horse cart +jinrikisha +oxcart +bassinet +cradle +crib +four-poster +bookcase +china cabinet +medicine chest +chiffonier +table lamp +file +park bench +barber chair +throne +folding chair +rocking chair +studio couch +toilet seat +desk +pool table +dining table +entertainment center +wardrobe +Granny Smith +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +acorn +hip +ear +rapeseed +corn +buckeye +organ +upright +chime +drum +gong +maraca +marimba +steel drum +banjo +cello +violin +harp +acoustic guitar +electric guitar +cornet +French horn +trombone +harmonica +ocarina +panpipe +bassoon +oboe +sax +flute +daisy +yellow lady's slipper +cliff +valley +alp +volcano +promontory +sandbar +coral reef +lakeside +seashore +geyser +hatchet +cleaver +letter opener +plane +power drill +lawn mower +hammer +corkscrew +can opener +plunger +screwdriver +shovel +plow +chain saw +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +white stork +black stork +spoonbill +flamingo +American egret +little blue heron +bittern +crane +limpkin +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +European gallinule +pelican +king penguin +albatross +great white shark +tiger shark +hammerhead +electric ray +stingray +barracouta +coho +tench +goldfish +eel +rock beauty +anemone fish +lionfish +puffer +sturgeon +gar +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +triceratops +African crocodile +American alligator +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake 
+night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +whistle +wing +paintbrush +hand blower +oxygen mask +snorkel +loudspeaker +microphone +screen +mouse +electric fan +oil filter +strainer +space heater +stove +guillotine +barometer +rule +odometer +scale +analog clock +digital clock +wall clock +hourglass +sundial +parking meter +stopwatch +digital watch +stethoscope +syringe +magnetic compass +binoculars +projector +sunglasses +loupe +radio telescope +bow +cannon [ground] +assault rifle +rifle +projectile +computer keyboard +typewriter keyboard +crane +lighter +abacus +cash machine +slide rule +desktop computer +hand-held computer +notebook +web site +harvester +thresher +printer +slot +vending machine +sewing machine +joystick +switch +hook +car wheel +paddlewheel +pinwheel +potter's wheel +gas pump +carousel +swing +reel +radiator +puck +hard disc +sunglass +pick +car mirror +solar dish +remote control +disk brake +buckle +hair slide +knot +combination lock +padlock +nail +safety pin +screw +muzzle +seat belt +ski +candle +jack-o'-lantern +spotlight +torch +neck brace +pier +tripod +maypole +mousetrap +spider web +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +isopod +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +sea urchin +sea cucumber +iron +espresso maker +microwave +Dutch oven +rotisserie +toaster +waffle iron +vacuum +dishwasher +refrigerator +washer +Crock Pot +frying pan +wok +caldron +coffeepot +teapot +spatula +altar +triumphal arch +patio +steel arch bridge +suspension bridge +viaduct +barn +greenhouse +palace +monastery +library +apiary +boathouse +church +mosque +stupa +planetarium +restaurant +cinema +home theater +lumbermill +coil +obelisk +totem pole +castle +prison +grocery store +bakery +barbershop +bookshop +butcher shop +confectionery +shoe shop +tobacco shop +toyshop +fountain +cliff dwelling +yurt +dock +brass +megalith +bannister +breakwater +dam +chainlink fence +picket fence +worm fence +stone wall +grille +sliding door +turnstile +mountain tent +scoreboard +honeycomb +plate rack +pedestal +beacon +mashed potato +bell pepper +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +cardoon +mushroom +shower curtain +jean +carton +handkerchief +sandal +ashcan +safe +plate +necklace +croquet ball +fur coat +thimble +pajama +running shoe +cocktail shaker +chest +manhole cover +modem +tub +tray +balance beam +bagel +prayer rug +kimono +hot pot +whiskey jug +knee pad +book jacket +spindle +ski mask +beer bottle +crash helmet +bottlecap +tile roof +mask +maillot +Petri dish +football helmet +bathing cap +teddy bear +holster +pop bottle +photocopier +vestment +crossword puzzle +golf ball +trifle +suit +water tower +feather boa +cloak +red wine +drumstick +shield +Christmas stocking 
+hoopskirt +menu +stage +bonnet +meat loaf +baseball +face powder +scabbard +sunscreen +beer glass +hen-of-the-woods +guacamole +lampshade +wool +hay +bow tie +mailbag +water jug +bucket +dishrag +soup bowl +eggnog +mortar +trench coat +paddle +chain +swab +mixing bowl +potpie +wine bottle +shoji +bulletproof vest +drilling platform +binder +cardigan +sweatshirt +pot +birdhouse +hamper +ping-pong ball +pencil box +pay-phone +consomme +apron +punching bag +backpack +groom +bearskin +pencil sharpener +broom +mosquito net +abaya +mortarboard +poncho +crutch +Polaroid camera +space bar +cup +racket +traffic light +quill +radio +dough +cuirass +military uniform +lipstick +shower cap +monitor +oscilloscope +mitten +brassiere +French loaf +vase +milk can +rugby ball +paper towel +earthstar +envelope +miniskirt +cowboy hat +trolleybus +perfume +bathtub +hotdog +coral fungus +bullet train +pillow +toilet tissue +cassette +carpenter's kit +ladle +stinkhorn +lotion +hair spray +academic gown +dome +crate +wig +burrito +pill bottle +chain mail +theater curtain +window shade +barrel +washbasin +ballpoint +basketball +bath towel +cowboy boot +gown +window screen +agaric +cellular telephone +nipple +barbell +mailbox +lab coat +fire screen +minibus +packet +maze +pole +horizontal bar +sombrero +pickelhaube +rain barrel +wallet +cassette player +comic book +piggy bank +street sign +bell cote +fountain pen +Windsor tie +volleyball +overskirt +sarong +purse +bolo tie +bib +parachute +sleeping bag +television +swimming trunks +measuring cup +espresso +pizza +breastplate +shopping basket +wooden spoon +saltshaker +chocolate sauce +ballplayer +goblet +gyromitra +stretcher +water bottle +dial telephone +soap dispenser +jersey +school bus +jigsaw puzzle +plastic bag +reflex camera +diaper +Band Aid +ice lolly +velvet +tennis ball +gasmask +doormat +Loafer +ice cream +pretzel +quilt +maillot +tape player +clog +iPod +bolete +scuba diver +pitcher +matchstick +bikini +sock +CD player +lens cap +thatch +vault +beaker +bubble +cheeseburger +parallel bars +flagpole +coffee mug +rubber eraser +stole +carbonara +dumbbell \ No newline at end of file diff --git a/tensorflow/examples/label_image/data/grace_hopper.jpg b/tensorflow/examples/label_image/data/grace_hopper.jpg new file mode 100644 index 0000000000000000000000000000000000000000..478720d6694a56b8962630e6dde8a230fc937049 GIT binary patch literal 61306 zcmb4qXH=8V6Yd)by@%dg0)*b1h;$MC&b7lU_nZx~Qmh z1sjMcHoW}5+;hL)eLwBlIXk<1-kIHb=9zz+|F!{U6C|5dwco2%Z6fn<^I=``=7X`F~;4~+Z7QS z7UqqS4G8o7_X^MlD8Uqz6ktk93Q8&}N@`kWT3Q+!S~f-|dS*^GE-p?s4h|lEm>>_Y zC?5xhkhGAfm^fS#&Mhb-FC!rjlYmS7ZxRp{6%{QtEh{Z8s{{`RkHr6X`!@!#P?I^3 z0puVNfQ$u1&I0;31@Hp^2uyyp>;K!x01!FV)q<&Mu7(lJ0Ei5H)yS{*0HPwNxazV1 zRPjjD7uGJDv3gs zle+ah3hUTWnB{*j0s1S+D=8M>22jU!;9kaXKOxXa#aeov9!=i(3Q;DH*gVHP{aZF& z)yU3d({%aD!MJwrm_mZbQwDB_+E0U(Cky--jf}eZYimhwv)xLpa0+TQC0|@zygVh@ z^l=7Q3($l{(Hc-tiZgjcG z+^T5`n;LXd1NI-#XKenBT;3$p&Ab?MFn@hd@|L>=*O043pt46ht{<*yDt~=Xh+RZa z34*(1p2=8Qi2wxgsZRDiTq7H5YM-p z$tRO^%HaI=oYsZ>XNGzGV*}?d_09!_yUw_~4Hx|}>Fmf!|3g-hR6dg9hSBqv@B_FB zdb*jl+@9o|F*hsu&Tjo_Ym3cdRj;c1b&`83rpr_Vv|09$gFX9zz9_;6-R~${O;G`@ z(6N;0;`Ez+KpMj}Ej7QPam1DUO8b#}^1iWJEj`kKKSV1t=3KYbf-Xg3z*As&hZw|> zl?xyDvT1VkILc>a`Nd{!{o`xmK?%|D0v=yP5 z4C3mOFT88rtb4;zw5{>i{p~vK6+U}g>kQK|FH`3hjfh3Z@XD8P9$ySC2$bpD7nlEu zZUQmWm5@)s7r7c8n*tGyouUhCd80k^*B#4;GtU1L*xj-I#7Fuuycy*#RdQe+X80|y zXd<;}CQ>ATqic%a<;LeckFI|}$(PFTUSq=@g-WU1u4&|xhl&bg0I;o}`3KzCADQ|GG#3suLLcHG#!8^`Keflnb*3es`MNi`=w;JJSX#`oc#!8d 
ze3W=>&`W>6{)^k=vJx{#WUZ^BOeO=aij~2lkF_p#|MPLrx;Oo5JEbRHbpqm;Q$Q{1 zinr2MIdQ%e+=I{5px06M(oYNUWRBjwVy{94nz#QGv}ZVZEXJwe;n8x?NO4mlsXbV- z#jiD|ycj)Ksp?wWqG?RWSnq(*kPTW_d2GFRkM3kIBOt&wqctn)x^&;jA~*8QTZd=& zCFW2oXrf-ll+vG+IO8c&^!@Y;i=I{=WU!1QIpy7kr)(;FEDeUY+MF_lcyE2;f5xz8 zDNSfxezvh^o#T~BCqXEMZLpl4@Pt$52fHhGyP(%nY_Z7o!h*(gFGUdc^^a86h3l` z8lbx&%=Crgpr^4HF!G7+BjXY$saE*OU<3i|#usOYhnDtyBa z#p+7qObPy-1@tKdc6pLTDHB7Tgjz52vIbHEuW>H(g<5zSkArLvXR`tybEQ8QwJ;(A zC!5KRC50_0Z90BWW$|@_EIgzQ)Kq`-=dXYqd5}bn0!NXfw+;qS&i@{`5|A=NI31{h z+T&mPGEZn|#O>FdxYDG$ObHHWQZKoo*=YbT{{9h|5~Z&h1?nJsk-P2_jz97Pbh;7m z^^-gdcg8W8&oaTn8Bpz{p0BS{G*M)FPA}ign%9!)=%P%@O~m!?=z}%rab(#C$Un6@ zRB{?JTSLh@->T;|f=&P(xKx*id5%$|!wtphiJI3tThP3JfJ7M4H2MgjuLt**%&#%0 z$@@b&p5j}dJs?sDiK^bLo_q;q{0=44{X`Dxc_^um!)b)N{raDoBlr_WjrPK~R1@tiO z%z#`Ne~ls;UYw7hj$A0CJo3#JU231{u!1A5(FIL+5Of4QHmM>iei@nDh&fP`#sYYz zZs%A{gYH}DmwsTk{6W4{qRZeWcr*0xsD-R`>3{!#D$jX)d266)Hd^|2PAU$DLw}~M z`E8UDX^`Yzf*E!mVOaPFkb=U>??p=)O0=vD%^^Q@7CyWQOA_Mrh?G6;87s+*IpkOM z6LF{#YnE_Ky6cgXS0^(4;zOR)ZeN%sAQER+cpZ~6gt!eexOv3)T`L92(M00Do}D~>EwwL3eFlOD&F)AJQn9~`|?QKv>%T;18|aICL(>`|&C zf*GAa=_i-$YvgE?^DKE@{sXRViTh^|xraH6+=fBMH_`0fBT=YlSM1$!1ym#7koZTR^ugw#Gh*(c?y;O|J99N17 zcu4Q!HXYbRNU^bI?l%>Wk~5esK3MD2n`rDn-FME4Vp%Ht;MXr;cF)Ua%_EjwWX%yc zlM-Z7 zdwR0f3VarbtidQ6;a4%wxs%UpZQ4?q5w^CrY6I;Inuc@hfaaeLh7>Lk&01E93gk_CQyRS#?*6S!&wMLG z^Zv%UcSc{ajK*uI-}Xyp9a&d}&`CaXZ3SCB%gN`9l@FkAM~TnfM~vPv?mB7*G`pq^GwKQ6@B!7aHyG5YHIJ7ra%1$;;( znGGso5YjVHfoBrbxFA$;>uu&oVQ*n5hsS59w+`~OC~2l{pEN6gLM1t%!;MU9v<|Nt zeZ3AMOCSWg_aV@FWC$ULE_tGbF=;VGRUX8w=O7^j7qyCl?J%2S_X)3|{kyiB`1;!Y zY)%<^KuUHSW})FmrZ_D;7jgSe(sC_+Jo#&qhxva6N2cH*A%1?60vLz3uiyvtMP0@ivD^LcK73tZXNx(v{mii!^PQps4Tl!#ph4p|qrFy(v_iDAM z<3|#Z=~e)MHNcT+6mv4YeF_H@-^F7vAdJIDK7b2rHoctj>M}9!84UFsvy?Pklb!yNGb9pp(5Sf7QpDC(z}g}< zcBhi5O3$F)OH{Y~8|6~T?oEe-5sC9ey;V1h64k(0e#$Z?oOWs7e`FL%{3=%Nnkwn) z|EOt1cZebqT{_N}8)Q31M|u*w`xm>@#zsV5z0?@GY}L%JH}OHr^zqxaF#1@-GQ^|% z;55JM8)`*P+StH^19AKRy6%6f1aY2!Zf%(mtop zkl&_>Bm`6XB<3OVTb_|(nZy9f_>RGQipcF!%`fiEe2s2CFpJxNMdjcZI><1K8JXEv ztVhhX!&h%vz+q61n*-CdYVd&>X=nRDwc+QZ%%w*5HXU2{r9jSEf^NZcvT0f!BAx3S z1@AUVeN|shGQ|L6Ne6uX-43pvG^>Kzzxp58EPv3%K_FyXF9wd*jBMBC)XL?tzk*{X z@gOTza=RF!t`HRm>{Xv@PPTAG6uD}njsvwK(gdOeb8L^JrkO1J`Xf;O3-uw1(GA2g za-Q{I(i2)(_Hhu#X9AW#_1GvkMq`^S?A$bv-fFw~rXC>wllhQB zRIkyT~em%6*96v;5`{zvV%0OTnQ-BfLi(oQdC# z?K#cLWn%TbaMlft=XtKVt`lAfp0|>JYzJv2I6v9zgPSB-;p!WrJv<(m<4XXR&!FsU zv;OtI>c6DeqE7G=acN#@==`r3xL#XtG;=Ld9+{}^v3 za_b*3f=Hj}r){RWh zb473Ep-B82bh%EP%QwqMh)&dbJYzL)g)7-0nFU9yQj!=3k1euM#jLCU_p^%Xr=fr5 z#hqQ(^rBYd^v(8Pe^_FRWla<%18C%)%IIc2*}7jPnZP^vA-(%3(%F5Y&FMIs*jNlh zy`y`d95&b^5$`>D96rRwcPTJ?r0`-NZ1wB`$g2`-+GXmhiWO?)T;)mG4UW0bG!v6N zUO_Z7Dl?sx?0WJaF8}uiIRB0z#f~es>=E;BbMr1q`yu_fH_$KUKd;%`k1M>eS*x6F z$&((A7m11%*4M5^={^;PnmdlrcX)gjE!z`h8%#(JAgqQ{cQ)O>F5pEFDI=Y+Hr0V}Bq_^T>cZ<|dqoz!)nqj=E)yKB@a`>}mfTDXCF+5NQ#U-zU zfwBDiyY&-IZAFfzlJS{|O@RlzM-5S5cy2)aF4R7Uhfqv2P1-Ix`M8}CT9jHC8w;-MfE+)2qeIyzsOYISZ9yQ{cx zCd=YH0Y!||I{0|7{nd-IDMcLVw zi^FxpYU4Op-vN*ZgTb3`RN*g`RxQd%peOR{uj|H$I>Jclj}~19i7LTm3~tULqz-3@ zhhH}T8MMGfmb@s;L*<49l-+Og919c%cFXiBcH%&HsRvHR*dVOLddodcJzWn)67a=W z{^_E!?ARzY1)!ltT#MfZusztTml*&2-@2q>1S0xXj48Ks{kxN@jpyMsNn+zw*?~X` zY9~2CLH3niAUj{H?DRtKxMq|Uxf{~w3w^DTbDboh^!e@E4_ElvHD-r%KObpKIz6R3 zBMq25yj>khA%V&eb{Zo$r^qVN=O1UzLI{DotZsjyXS~V0^CkMOn;oXb1f1NhvDI)3EAyUedQC;IKp)@uhb(Q* zMHZQ6)9@GPOAHAEnFcecA+F{v5VH_uJ}i4P=}8PcRGHwQ{mY+0Qsjl%Mzc4KUGO~V z_Ui`ml;+i&Q-hJ2sUlAA>Le+gu#q_?^&1qE0S3M&IIMa;Ls`KzyW&Fi&y?_9slrXe z!o7s_`#ZLXsj#1>cAQoPsgd|5-XB?6T~$eM-ROQBFKG*n`CT+5AMOmM6!a|?jGt3< z2yjmi`f@bx#$JG19s5Ay4m+V+0Ah_WG`4{BjrX{;i)iXzS0YZ}e5WjZ!u 
zUz4I@-laJ#f+iQHyFJBBl@PaG_6itb4wNtU>`F|knQ;n1Egj{VYpm@-2)&j96nhyN zQYO^zIZT{%xTt^YboFG0%4JHWwQ;hj!o!<8caX<$@UXtp%dxKX7$Nz7>RV=E6c{ia z_~-#!P?Qu~s3?YE=1pp*p^o)l5ksQ#u+JWw>yqy?8Bo1+n~7tgT(}KXs`IU}qHyk_ zIcvLy-rz4>NC782<*>qXea;hWm8GwVNMsY!uS3=_gl|C`=ai;C2V2X zJ1foE#&=$>X@QVTTz+!&jr5Ye1A#yaL%so~upbDK>tfjc7>v2~nLxpC!F7GfJ0L!xbWMP)S5X-Im$#swz@*d?=xK2w!9hnwqBdtAH4@|<5K02RXPpFPv#n7bo;1DBmhZe`kPq=QGU#^6Rle?2#bv&5{qIC*nV? z7?bas@8&ZQyWe-iUJHLHepfgjqq`l_#jDO9kr+jZ5TcL4Nm=PM^GgENt`Huc;y&%bcwW4FS^(^^~Q+iy}- zX6>Ax9Xu-%(<3{qZ=@(MPyu;&sMa4|Huf{Y$5GU>0@=K5y-$8Qt9^@k(MFz6+G(CbV?s#r#Hk&lnSUoABA*LC-X;w{%5 z%A05ZfY$BkQAYMl*oFU-=^M#S31t|baAiV+FXjtBttFQ5Hd;X3(9{vpO7BA5{Gf`A z+dFGas5`huU5@Rb)s^y8QQZBjHmwm~AF{E;=Eoz!*iyH2Fn?6u{rn3|I{)O+yIu6U zO`wFGn2l^6VXiEd9ky0{JvHgS38@6`i?^&hd4jDCS_+O<{j06H>W{5W2>WGKtrzYGR_#MLOLCAxVGl2p?$l2&Nh_f|Ir z<2%lB9AmKFGNNmV)yUP-d)@pl=plR&Goj9uSa%3n97}vF@!@3I>EkaXOc%9MR|iZM zu4#GfM4=?RemQ>oVA5!g%f!ve9tkmf%1r4hlRh5Cv|&%~K}*b|v`q?+T8b9rs}KV(ODSx zoM(dVWVJX5KD#G0;SDthGOGK{_v`g9POnXl8*8pM3!Q1o}!_nz44_sVE47~M)WHE(- zT=!5alH$I7fOdS{W*R7a=!kwgFQw4q$F<20bNn9?Z3t2qC<@+5{0`bInhLb z%cCdMZ6)`-HVyQ(XR}hffL^rRH1K-n`(>U)yNI%zEhFDEA?XO%=MKvwdMh1`D(0XG zKvV@wb}UXV7-UwCzN&HwtPclrsj;`o*~XVUc5d$Ht3J>_I|}trl*nw$#^sb(_RyW z)uf*4#I>MT>AOCwz0|0BH15%`klP`70A+(TFdmfCArTPd84!m0S}p|1+cY=!~m?EJSuW`foW57VmHS zBlarOgGFu}!jb;oIr9{jVIlOl{sG?Ql3Jb)w*P=nV>evt9ovY7)Zx%YI+tv=*)R^|;IOm_G0k>$_%4@IMWd7(1~aWUNyuD&NGh=0y+-)1uy8_^ z7}*E$whMROv$j3Hh>g~}!`N-aAL*V8kIuAY{f|j=C;nR2F7CN%T1n(QHPGsYB+9ML zWMW9=M3W3D^E?+Bf!5P!8Si6E)BXWiH_SewwS`-eK&7W9-q=7aGIDISX8bJhYkFy3 zIy+(;g+)I`e6~{xynB65Hj`CcPPz}GqZ!DRr3U^|mu>U$EEO>m)7ollqIb}w(|e6G zTfi_=gC*y-a@Hcop*oH;RXRdk8^y;NSauX|GRDLho*FGcsKWY|i~XvzCILhZn@z}H zkDF_)2y^8loGHDMe01*@fs&p1OxH9&bInZMMrk|cgUjwi0&vWo{OHb2LY#3LN9G&e zMpFU%pB^03HP7=Eeih}!-2a{TB;|+7cdbobRs@O7Fmv;bqN^sK5R@06)aPkNu4zN; zCvWD}$Xr2MHo^u(u4@Yv&foWn$V(1$R$nqK-PSy8$dtr#*vO!CQ0VKEpqU+C95T#QP@Uq z=ZKj2yIiy6a|0$$KASSwU^p&{R6%HR0-^BrV<3#B&T^7L0lp4UL;*A*%Q`VHJ6HEeSts_ z_P-^n4seXc(S1!LMFa*K%1vpDyRUdZD>W3k8~<~zjJ?MH>%miDEcVU{_ttoH9nfHj* zmrw6+om)>0xY_A058B&;5r*Le@7 z=>ic)LsM;_tp^OFc8QIORbA#Q=r*b!A3Xhz=jI?3o}d|B^c4P1EZEo z7K{@#LjtXP&H=4$(-xc)VIM_?o9@!T&Afjdh!v>_U`s9SpSnFNXCg1HT}j_M_0(%? zRjYW7Sqv_VX1)Aa+xv~B-JWY3)ygUyRYX`Tjx|FFE!!9m_%lr_ngVqIID%vIb^Zz7Ci40CAAWzbUN8 za=e~vSh$UNwugT<8)M8}Jb8|ie6VJ1{@{UGOJcGzmPuaqZw`i4z4Za_Nq%NP;0ZTA zV3K(_H0|NxnOPqhTHopO%OidPy2%SpFp7r{(iAOdA^s<5UWl9MWJ{bXsE6$BFH`Y) zDqlRK+WQt05FiGRWapc2Dl065Ucvx+jCjd`)VcDwCQ_D(eVFYL$>Bu?MUG-u{t>%9+AqX}N+u|9bXwR7jQ5_p0BX3G$^WMw(|> zBU|f=q!CK+TLx!v(PEI!CS7pM;|tO0EvbyDI^ztRrnn6BPnBG0qG`7t^?Bvy=jC^u z!rNK&O5-G^EOv+P?B{D5crk1~v>1k?+`-jHW9Y*~N-)ibi{{Gf$R-ITU(*s3?;@{5 zr=Zdob1`!#6ZR6C-j$D`?Z1T6cE1L!AG9bz(VV}_Ybm;3r~4@M-aFlq6*#5K9K6-) z!<)lXrI!nPw`6}9U62u?lxZ-0J$9*NV;hS(Iy7(9p;2lPCx}EoxC;7jtnmcg z5`E`2Ib$36n#C!q==<;Qa_fm!ddY;%w-gxriRV~v?Xon1!0wn_Y9xN)Vq{3T5bqxP_1%+dSvdv@sO zW|Q|RFI0~6)unVXy5FDsq+5;1OOkyShBgBnirwEf53XfkstCA5qfWJayYK|4@kt_A zZ^05w?&w+Z%ZaxVxEG-7(vfMxq7nl>Zj7(D0Zf}x&rHV(#HL`lbN35@g=WBxsGn$b zvZFOspv1cB0Fa?bbW~~fk{JtSX~MPDI2{Vv&(s@yUk6GwwYvaW#Tt9BOI|H*;jgu? 
zZ`CrS&?ZAQRgYi0rzSy_?yhiauI-*op+rerS*463>XY zPcxl2%~|92_t9Gnyi<7@KDRCr+Xkn6*Y@F;Ka~+bEs1|RdzN8BsT#@ubOoujbxG&B zuBAI3*?&GI8@EcTTjNlm$`BBD*OJPapKN;vE^ih$4!u{nc(+Q*T69uDC1Ag_dg}<| z`SeoQ^TV#T(zF7*%z>mk1e^Ueo2ku+PcuPDG{4TP8snfXcN#)7#U*)3*6KVzaU}8F zMWm!i)@uniSE>+T9Z8^4NLR=;Sy61blYP&+XS|!IO0kYT(vZf9i*IMou13Jw*~18f zie4a3{3ALWljhY@{3+cUvfugP#>#cIo!PXD8}aue1S%rpBVLQ z#&Wlr`0V_v5bu0~?@r>WL9KkteA#FxfliPRab_y@_%FvA>SnB_sN6!E62w9JmP(aN zKI-I>#m3b|MZ%20t`*7xQ`*0~J-?j0SF(4H=fN%?<3qYm?t{e=EB3li+^Ei9(CQkw zy%3S+-Jp99*;-_MO7KPXR+<_6@8E~hRjnAuCmmJIeP1@*Ae&WmYu(`qId1RaLb*l8(AJU;8rKc z&W<&l=<0pnHu7k-o{)NC*uWm<-BW_OfiVi(MD_ErLo zu_6uSX^*>5HO&?&8MFD?NIMD6AY@3PQ8=MWESbApH*<^R<=uN zs>)l7^wDgGL((5*+M@)xPHj19PGz)V84H!FdO)yOvgLwNGMa4 zyNH{~#7gOZJ8zT%-f@eQ#nim9)H`*3sqG_--LQ63XE3CHF;>T36oAb_-8Fzh6*C3~ zVrm24bq;q&w9;kt3<%?APH1}2l z-fNX!HLloNn$itVxF-L{Lh#OvCFsPSVebup16@l|t`w-X!ajX&3SstB+WBdT;S!sv zQtzFzul*E$jnnTx1bUlJX>$aKwI~gxY*=YLSn&1%bSoC6>q8F$P?N%7=U4$_NqGvRV9iBBuV={wa^`pZ!g9ElLZB|Ahz2cY7~tj6%H8{1(-vM>tync^glT{J$YR+r{4?vWAQgI}>JIaiR*s?0 zVn?CX5U$lk9}3sebl@ceM=9M~z8D%`mMlv5BX|fE483tz)D@c&wf~_ozaH~yRw2rO3qu}jS#X43_X$zDRU4vR% z17+T$EG@*0dKs-SgRN5L%WNl zxP|N2oCPp&jrx^ks6ohE-_UVB?p&r(#GWTVbK*A98ObWcg25v%6lj+6J4eO2JGHORXiK zdE*E-?ZQ{Y-^K|qaSWkMopzi>(tENonmcqw}h-xBl&v>eC zPd#B8toLfy9^q;sm<@R}8DgweQH>U!M-GEl;6dU14~{WWGklSmWs7uU{pgDWww6_G zj%ZO0>KYVMW2Ly6buI0eKl_9RTkwB|{r`m){ih{5aegUcUOJ`zR3?5))5#`1>_y=T z{%d36ya{T=j3TW>+@V|kqgZlU*!`5rlVNEAtk=4EwHPXXsUP>JCGb6+N?L`X;cgzo z>}b>^*LCGL9Fs*E%VtW4w#0`e=$>h9C*`0;w#CIlwrNYY(^;ln^tYHA@$40QE4ino zV%-LCD4nCxQD6w;(rAlMD%?govhj@LEysobw1C(ts{5l+YdXVtNR#{^0LJ1zGKB+QlmbhKSAWAKpxTi>Cf#V0E=N1-OF%t4K$E7(6255u%o{s@?(prejtQ)( zLy7eu>Ar>ypc1ipl`h?5+$PEKAb0peS=#G#C7s?u#YNh!U@cfYaCg|~KJV_w1djS* z{cDgmBVH(rq)Af=;#!x?z3yT4((w147?j=be*mm|rg+mxaGI;VCF zRtUU_^0YC`SoR?E9Y@#um3o5<_AW&#oaf;(8iBWNcFR?$I#9~U(a-oFL!rbg1ilPk z?OpBq%Hi|QDUDczY~C#wC{M`slHKuMiNN?N%Cwbf9kRPl<5-TWIicjR2_pY3Q75BR zyj=gcyy+JUfAV(6-OG+XF3n1ajWL6y4VVYLK6C>^DWNiR(uS|B?u-FAy20Qi5q_it z0tuVl0lvTC*Y(z+lNb5rPfw9LkFa-eQ+2C1!DQ$C)?=h=mKb^!8gpHYtcm`vj9(5B zBm^lHwFW3i$3C=--~A{OnhMC;bypqmSY0d0ZY?*sS6c*4upWZ`-B*+DkXZ;ZBqc@V^NnJF2rV!d5Szh71eu8^SIuC?HwapDY%^j z><#X)Bi~Yn))rVcPQH3c(+eGv>%x{m?4RO!)j1PED6*g+ATYa-Y*s-MxL!7uDlXa; z($TmhJv@;}8<>7gpT%gs7f0=8RGHu~S1(CyXyzc7@bDBMkk#vN`U#%J)D$Biro9J` zUkr2Y8zwM+(#}pu8=O~X!m6Or#D$wyjiwg1s?J2=%R(FBS2CO85_Z0Kx2^gmyh!zl zx*oqi5X@L2n-+s+NWwd_18;}m*S6Od-O(t7xV=J=jHKe&+Ut=S<*k`+eKon;2}DI$ zY00N=EOu(7jBZoBBP&NsurhmNNip+u8Nn}3%jF;7BfsTq$4p)h@8;jlc@XrY$29GB zC{fO}msT&P7mVsN5iT_zb7?S zmfX+u7el&Vw*Tif?YJk-CLp;NM@(rE%PUV{g2{Q7u01TzqglccTH5*FoOHjhJ?;35 zi}gQ=5GY^2XGnz8KV4qrIpV!{@)ED{sANqnb*YD++97Y{=D@J}*CehLyir z@5IqUCQI}?$(2`}x;$5V0QQ*R;id5{?i2`NT* zt(>3w`}~%E@BWzX5Cv4W=d&kEC4;8&HT7285_d)n5@~l3w5b|2}1AsGPLrJ~Da`IF_9ESP| zXP0t+v1hK~;WgFUbDPtXKjYw4s>*R*{N7og!A&I%xv)s+_YYvq(9>`qDl2!g@zm=1 zgqIWyohalR#b2pbZrB1lTx`h3$fT4c9_e`7?QmZogQZObeg1(JSNPTpuI#z#r$Zf< z2&rFuu9|L=NU`GV7oB2VOk)M6(8M8@vz3i-v(Lg* z5o&cgBfyRA)6ZaE5A{n&#;z%`mr9Abg$$*zSa=m3rt&KBQjePut_%W{S1RY{mPu8d z-z85MURAFbVPZ2lsCj1!8^3n;L}sxhH_*f)N&gN+sBz7UFXg~`@Cw*<2rO*j1v$mf zIG_e>O2GUEmTAwQ)!rBBWvbRJ8K}PHmDzKycZIb+oGaUq9rC7*6jv&C0JW2~+C3nf zYvi1{HzVvt!o5Xo?VM*K`vP#qha}3vqCF~}i_obu zmWj6Rlno2`ng0jBqnw|3BADY&0H4zD*`g*hYinYghIc}o&d-!YwxH8Oe@Mx1=$ZJ7 z;pv1lo9F0%0NoXnCtGc+wS$Hgt=SAM`1OTgsKnoS$uGxf=xurOO`2pVR}oLBspBzf z(4!XnJ0;aM^lRn;J9;iQzYlrXVY1lI&Y$-d5oPW?&1ns-)P8%);!v5b{_H=hxfK{4a!|Cndbz-Fls*J%u-FJ2d>dqViy^xGIb7@@ZmgP=?P?)C)V4 zXa;b>feuW3UyhOGw#e>w!CMLC0v+T?XrGfEEe>l6Yd_?JNuERKo^)ZUBp7F2VV$Ub z53MD6Z&92nOT!lyiW*61VbE?>8L$G+O_;c1AVaObhoNHA$9Q|>26K3z%}qVFE7Bw 
zPz$&iBp#S?Z*lz8eOM`}muB~VYJxKYuQ3X1KLB&OkwO7b6ekS!vr6p5K%#Ye^0P%l zpV)GS3^DkpesUONY6i^*fbq+E`^W7-1`?K{3U4TRAf#G3hZYyn z?CO&HcsJ*d2qH1$Iq=A8`B<*^ag;38PPGNxu3HwxuHbEQjXe5`fU_3y`n6DgjCyNx zSY#EY3DF7256*Qxkv~fLQOR=M7>**}_$##g&_Xwh_}q&~zZWl#2zfotBO? z+jgYzx`?}^NxI+4Ge(fOl`i8I8{R{m>WXu(T6oVn0B%?>8S?Sv-bI#znMD-FtI4D( z0uGlZX6KK2OI^r>nV{LkAkM?KOVb?K_G!b#2ZUR1$jjf3eb5>!G#Sd0fXnLAU)2+Y z-LQJO)HmE0`3L?FP;76x_)wIia4+cXz0Y*aP8y}2l?rrn?q4=EZ!qIhPSd1hs>B+z zKh%y!WJy{hU@u-Q7Cz)1wfZ9}5V}>mudCG})9f)$(jG#g=VjX%UesmjFyR}al&?wD zdd|=S_S>3We%_wV!&MNr=oJp{lmd=g8cgXCgEx4ynNFTws329=&qTv6xU)Lpiqa0{ z3>mnl=TVO#+v}aA9O7KN(RXbyxZ)%c`KWnRN@B>O;ufoS7{ioGV${t$UjPMI86JBB zaX0-;pDhtnIQZO8B^_GAXp#78luY!_m|A`)0yPm-> zJoot9pB*sY07Z?qpsvzOllvece)WJQeiD21Q0l-pMptMCXXn zInd8oHkdC`jH?9s%1=i-nDOQ(YEFB#gFfo;+IuHQG;VF@Vq6quKof=0^URPq3XblV z!o10~$>eqs_yFOguoD5>lY6`+A34|?$jgpT|08z#!V&`=_bN|buTLmmJ=7MfxPBk; zBbdAjqvNj_gr85R|K=HapFNp}=)8LK+I^+KCa}~eMtN%_`K|~$`4FhY8SNp46kCNC z{xr(!Qbl8tE>)pnkywPa@W)6It~Zl<&qk(F$i}}Vhu?Mht#6LB2noDKsUR{F3yWd1PxuVl&YJDJXVKu zu~t_c^Mj1bN=Tg!P`JZ?M?X=A0_oWu1B2$gNDEG;N=X~=`kgZ}My6`gGi-Vs3XHo8 z{~j6n050y9W*DK4Z?ytT$x|*G6@ZM;EfaXk{rDXum9Ph00eJ#KGllIM*|s0!UpFfv z9;ty4J@7k7X#d2z(v4Vjxc+jS}J*VAN|w7dU4`{SAfJ6gPL z)5rvY3YjdCu;j;uZ3faP^)byr=siCIJVg8=E@ z)JwxcWJFXV;Q+Hu_boT*Yy7mvn<=A-8K;i^23IR?_wk>~K5LOkmM-X_74-eC;?XOj zVA?**1H{ydyNQ0O95;##i{?1!r_yCVx>wg@J3`?zn4`i{)*IEG(^5YL&i_-_z~d1i zIN9`&!xzF0X4fF@-Szl#Lj71)L^q1!Ot>%HL!>}g(!s{xMHlMR}<_C_BFX_4S|DPp(QX?(IJjt($i{sH-YC`Viyk772Eq0Ew@bTHUQ*DU{q# z`n)tB)PV0Ty{@mH9#4v+E~%;UHkhh+LN~Iy(Hjnnb7_R>h1MoaOd4gfzyiWxxM&%| z?JA|6Paz_#iY=7Rk2tPP_al>;Vq7Dvyv`dk$Fy-OE>3ij{9_HFi z%o0B3GhN1}`yVwKR?ay}^qKne!Hmbw`|lIzQsdabc~o}gF4sz87D~q`+>M-ubOvGd z{{RbeTpc}IK?t{8AHNY-rkXRMLkLtQcAPPK_Cv5Zn~<+%e)2V159JWl@XM(FK~fuz z&1(UM96}rk;0+gP&OaGIW`~$5WeG6m~A*I1x+;p6=K>HsX5; zF5ZK$#BTKuPcm6XuPBA{DYm9z8qswqyb7)KX`1107G6*2FAZEu{juKbxo}E{6#O;j zHzz{93QTW}(rGX?UagBf2MQvvh8=@{fErnQ0h$eprT`~}vEBc`-{aieK|huj-dO+S zBm-!_;$&d$!t_(EO<}{Q>P=k4&2BGIt_RV=)m2KQ`PkF3@{z@6&BK`z=A(K z2WCG116EjsNIaaNvGloODZKXLQ_A1kxd547x>69|?X)$ZxBTh6v`+d>Ob-o1;8s~s z0||WX#g@OHAE+Az0_&P(t!M*Mtp*J zQ<%L#yd@|g;pCEiYVsf+-T7{jqaR2AhoSR+XY*~tcx%On^JUPCNGO}WDTwb^3VyvFfgdvVivh#Z2D;1N?VKUyYKdO#-Hj?V(0m3_ik!BZv2ajugBjk^mzo zPJiLfq4bRYRP%SWl!MYQzmahWwEkMW;Im9`0<^zd^JaI13Z52z2>CpB_X^c=b^?t9 z(0@1ZNkn4u_g}W=FL;2(l_Hj<5sG!uhIsdAv$9iekyne)lQj#EuA6sr=#EoeCD}rI ze_1fu9pGlft4^xzzv0Y~j;kQVTcbGwRKFA?>Xb(T2g8<$z#vUy0D76d;pevzR=7_vUrI_#p7hTaRg(K=7#yFEu@pG{KlR_ z;d{Bg%fxbhdy@ku^Z8%0&7Rljc>c)rN-da7J*<1^lJWIzBO!?N)6#wTL8Diiqo)+K zcY&V@`}?W$)*GFMdzpfHk4G;;x4 zAYWs#R-KE-UY~Gccm4IGd8)!IA#5!&FF*Nt>L)L~qyx;PuE!RYT3v5hq^r&!gU z9KYyWs=pBE-2m7KQ;^7!pYm;(jIW)r8}`i!zs8mc5~K4hRJiLDbcKcn(B5f1AqEzH zGoSN#t>!<6DkcQ$EVJxvjC=Gcu>G#1`_$!+0P&^nRpcWU*K(jS9sY6zWrSN8wK2!n z_6`=)yaxapcR{?zkwmOA5yb`jTHZ7F_@AB?21J=;tVS>0!i)76Q?YU^j`-w%bAG{6 zJ+HBF9E37jL*LAGGry?0l<$c}Pj7lrYfNAXfte{Yvk8WWz+fDs%!`-sfIhXA3%!A$ z2Lc}ODkGWpHt0S*h#x<7ax$T#3PxHytVwEW#E>9n+??glXLV(}pAHeE2%AW*Uys(M z&FTt@yiFWx;mh>bU9ELHDM$&Nk*Eb{tP}}_z#I4Kz_HwwY{b(3LWoBbov|Cwx>U7= zOlGEngWD@Vu{Y31t>|P3?Z(DOR^`ZXvKJrJ0fQ%F?GG5vR5$8=aN_RPQz}iml*}`A zbB2f4EjSL~@)wB-a2r!zdn;k=e*lnv_f!>|hwRkzpn$SID(}L+6xuWqjj??JZy|lq z?28SwHnF$5jsv}H38+ILaHm=mmB|AnS&kw`e)(%WU3{2u zBJWWz*BGjpIA!a%thqG}{xDU|=ijHl*d>g@bZfg=5832@sKm=j8q=9)l{~vvuolt7 z2En|t>Px^ZP-RxLX&7p!9)=Hz2Cv_ByFkCd?J8r)tErjt^}~hkv{6=;9j@!}V1a#2S9cWOL8m zxWk1RjW+GiP=%s$%5nRzjftv@!J%F*vvz6$_d)R9^ri7ip^^r7VXe36cutrVBT!zX>zx#GT63<(Hbz%fyLX8o&X=&EX zKW9R`aE{lgow)UFGrO{OJ|8p5Sr{4fi4|Zf;798ED>)yfnJ7B+GoVC3sr0SUFB2fp zw6U?5i!(=6(!ztXr|j`YEP1exsp3-0)}Z%A7j)U&o>w~Tc$laX_y)|vbAHEdZpTrc 
zl$mXdr-V8LliQDOIN_?EvQ?U6E}MPBH*ytFkG+7~9t!ZX8&@uwh`N$4-sy2D zSK9!ze!b=N5p0p}Box3{u^j2Qaxp_seYm}O z&5@_T|4K8;S*vdSyL6_)?$>46RNJ9;&`iqcz6i~Mcb@O<`~ zD8{!WRX}OvTV0H!$w29WbtmT|s@fmJWBIbe$>yEDT(;DxAh-&fTc&{c(5!Sl_h`P3*UdAouEX+K?ZHKYhi1N*O>biEOo7;2 z1pgY_J=Ykua>-?0nz(t#v*;Y@Mz&|7JVBEtgGH$>3QmI9tZ!d`CST_zxnk+*ZzrE8 zdW@#@Q1mWUa4f|vx$(8Fs}&KIeY0W4;lRG^$h}&y2#nj5FTMSP@98msu&tR{Qf0L( z`Cw{YZpx$Ib^Mq4yDeT>mXLBR!L1Xn-^o!`20muWjzokw;@=619+kMl! z-YZ`XeW?t(JC3@xSx-6P@+C*ySFTIDc5h)(^LhIWdWxMwP`XcOX5yiUj;%WSKH3(J zda~1f;a3zR6&}Op?{@I&Ha`=trsj=w#R-WsyLR(pr)0u`2xM;N!KZcA0sWCbla7ut zMx-cmA{`So_RqTR7)f1z1Ru5M=O`QB&?Hij)}11ecc zDxGW+xvJ*o944nUjS`E%=>}#|+SifOW&@MAMwH|^35K-Wyz_c=iW4JI8~hwYt}(K2 z<6mk{J{K1Sdis(eFeEbIu>MJGTBRh)*uHTh(b(QT>TA$S*%%PLONH5`U-KXPLTAW7 z6l`Br&qKm*C*K_-zz1+U+&<5Xb#%osU-9dG?pS)Q!hl617YBO-V7G6_{x=un80;Zn*(k*m5PK(~%|)>-C_@e1F<1~>#ptYJ^({dAVY1#TV* z!eE`2tlUcfP|Tl*-EPmK9_)K^rYVJ7#!*WwkEE?4kf_yG*Y=moQ?l%%hbsj@tGdM3 z>Q(LQTT7i-YC>>88^_*GabaS_W}HS}OeESgFG9{NzMA>OT->nQ)ummrg7ko!RPxhs-+Z>0%nMg}D>$iJCLS5Ys4X?=Zn9?%d8a30I?Q077e2N#Z zHn->mW#&u$by`TcUE59qjD7ylJzs{<6?<+%Wni!ArOD-K7zPgX$a5 zNcsph0!%A$Do?C);jB2z`IX#8;!k$sJS9CkpF2F(O@6d`n$TD7P3NoLNcSvzNYtRx zXz!y!D^n9DjdTE~VH1B7ju`QsBGKU?jlA!$ub^|*WAx&W*a4}I>X=*c$R_;1ImCs6PcJ)XTSmWyG?n+@+n-D-X(k*Drz$BLIHrtdBGH}!9 zVG;13BaGsxfY2^DKq5V~_swiF{aU(mj+rX#*tQGVBuj00G9oC~Vhr`@3{<=C?UG0# z6YMqp2@R@`s^*={J-Q_>;)|PI^9IH4qcGK5Y|r&nPnjfGl_FhJ62hFAfG=-RK5TuL z4ZLNYyRBVmwTO0{RNs4i0TC1brQ!JeTDatn7}z)dHqReqsZRcSiy_r-Dswc{9@F0! zQy73z!unwPQXy~j?lO!$mm8)K|$ti7$Tivl8{s$rq4O~^=Qm5_lmLhd=eiN7tbV|f-a|DcW;8dN4G}SAA*%mLD z`60fL-P~MpPL(nH(ny%JOcVNMDHcrWisd~7l^H2w*z@WX-txychBt66Ito)Xqa1MP z%e$4L>w5h;nxS}^Ne7jmR_Ol$g7SE(v$oLfx7Ac(VUHO1dkHrCLzo4-v)1Ry)#LTOZ9^y~k2T2PHz z+EMcAYX5{=MzoXSZ~N0qFfaOhYR1Y3pL5ZwWv&Nu&rHfe{@}&?OJ46ddd8}o8yI2X z0B&XR`m3WiwS{Ck{xETD{76N^lE1@4OLn#wZNWW8oBBNCPhq9+&V1&7B;Ql!RFWg* zBh@dBTC`75hS&ZF*a0^>v>?I7k&cU%-b{rN&yEG6+mjq8 zk+4@*!-u+ZVSlBvUle@_^HedUn{^!Ne&d^F)SVD&Q#&^FW^VAW=}hMT0DcYIhA;k> zx<>3i{HHX=*uI}?C9|F!`_rOCdc<-<;_#3%-t#m~Oa{k8w7v3f}{mCF(#cDI+V?2+q z?YFG&pixk02zaJak^#TB+Slo+b9|?=JRo;S=*JU2%-kZQoi(N}6TPD4$52wq`dJiS z6lTk}vgk0;@{(8=U{t(^D)mg$`3u;)P%}k3Z_bggGd80@lSm?>qQ-<5Hi7=lQ~$Ip zSk~dmD-9}+*Fb(Uzc$11p`{7WLl2=y{elO60S4OtS;$^xcWY`=2$;r@riu+|0Ajui zv7JkEUa?B>So-p~+?;c!ueK+@e87ToNX|jS1~bEe$3@7h9Ju0s=hN26$)ui!Z|IkR{9x^l5|lIeP<_!H~KWSA3a=+74*`| zdz_K7L`B{m&P15l#e`NBG!+kds2d5^oMbUNyXmZyPaH$364WuKH+b9-gWWI z736qpDw0?#<#g}8Cc3Z67^E#pI837l23;ZOSGyu}F4JSFe2WmH#d*d4T`R>du139# z14_dMKVzR-(P>j7DDN%uusufyhH1@x%#I|^9A0kUnI#Xr2eYt3b$j z+Z94RgPh5}k!SyPNDd^)@6h#V&a07XOtt1yVyK{SS(+cIV5Wx|1Ib|Di`lwZy^fa=G zSLtiZxc^(!(aD^*{#^z6yE#ZVJ3atA{YXa&dQphrQO1^K>Y&3{NS5;5leCn_i_sfx zF#H^UL$oc)hI7aCzd@%KR!4bV@h>)q}5@%_aTYTTb8we zBQlsJDOb0>Rv%i;vcafxwhF$&)xe?@8xnZM7Gc02O2pt~m=f;0nR;PAk$zcxjYCiaP!U}36m0Cf@9#`&*+s*-lXfPSLSsGFRHzx*C; zaY73DjCeUabR2YGfxOjhe&IehSIYB8;;&8=4e4u46ha^q>H-0YXy=kcnx30v zq5U80IrG)70zAf@%i~2Fxe(5cPmn{xC?Y8zUe=^Ims(&Z(MxhqiT={Pmu|h$)kT&m zhSFE7U%IZphAS+?Y-GCAMHcB=Y73Tr5;AT3Iv}8THn_8Y^-w{|LhX&`Cqbeq!aI0{;8OT9C>7o$h6qilVko z>4d-q>+gyLTu=R0Gpvv*v?_k-Jxg&p+Gwq0h>xN4B65q0No1V|LJ%B=uQ>^}`23Mw zUo#@TG2}aRo(Tap{2UfJ0zhBb0gX9O(GLN<0RkU%hoL(6EKAf>t?C%~gKIwmCO!q; z5u>dOxnj7mXT|OR+Kvvx*^EC9$cE7|sty3?6D8-jz9+*ahn~G{OGmSzf&JNP_h)hY z69}hI^K_RZIRf?D>cltx6@=(aXcY+D> zRW(i~8N%izPlMjQ5-HEAuHDEv&=A|iVCv(rsta!=fXtgr-;&zx0+_ipQw?$+8@%A* z>)qh?ZqOP%l@DR{{(=8+EnB~>PmXL)lUu5><^wM`CRzM8Sr zC<$6bRjyiIaS%Vgf1Z8ZADFNsm{YxNBS$c{wF&{(s%V2ZEKM&gI3xOg@PLV?qO8(J zDn*?Y^?drW#mn~0lFvHVvuwc1WFN;fx+7_?ePIWaU5W{IOdjgIXhEpS9{=ybj#+a=InPLo%58q017G1XW?WoBhJ#HBhF*s0 
z2914)vd_t=O(-6V(LYc>z@|VVO&m%RcimlD^zDb54kFBU%tI>?8jLkZ$!`Y&uJD9F z+moND(1_j@|J@$TMrCe^1&pY-olO^IHkp#-y^YWF1gdz43>Kqdy?dC__o<#nnEk=Bfe7_HZUB091{>yc3!$y%G&<7XS9&s;!%yN>(@1npbi7Y|D1d8&xpR~DGIj4 zTWBvXX*j9RB;fuj!AtO0)PjM5PBS)h*`xeXYS6Ew=?Z+UqRaK%miQ0x=|H}(W(R0= zDM~=WonpDJ5%poS@StSbu_TyH{d(DXPvgSEcCtrbt-94KP23*1^Y@)mkhW$FZ?Yqr z9v(OObOqkt8<(!IY(!?GrP39ghas{}nTHz|=6=RtG}ucCxX9=K17wUYlM%;`#fa>` z$TiEYyDDjO7^=(-YG0DMbus?wZ$D!5cGHbpN|;?GO?kTg+pJiTAf~$U)?xZ(6MOT0 z_I_TUqqhU0@`un5`!dM=SH&G!-2COc91e=}g?q(YFNG|3Eic+vtAZ2)NtUXW=(8Fs zd+n*U+I(T2|0_|3giPvW@yCiwjcgzH;r@&vm$Jm{>A?Nel)ZO`*#>YAyj{B=6fKelY-Ck2|k#2S)uk{lPi#k2IF zVpC%T3aTWJW`4esd-y&f5{FFU+KtKfz5uxhcl0UdlAZa4`SYCZXe1BHJe&?G)(N<< z9l;CvhjOji4(te1B&wHk=Ou@AEsw}lpQV!fnh@|LJ17Kq71_)$dckczY$VlKAv@cQ zWo?q@%G4qXt>*f=$XOHHX=dH&)DmRv0{(7Et;9c9dg3OFb0?f*)U~}pNWc#r zBbO#O6F*K5b#!UpKJr^JUH`hunN0)GjXaI-`b=!})Vb?fhuMNq<%;rM@EeuqXajjmgxV+UF`N10~BR3)0#FDBj23{cWVPvA)9532&O8&3s zbEVEGDgU=xR%0!kcq@Xh-sL*^u;-RK%DFCNllD>vZak#3IY0pG!|fM2r&P%bb?L`T zM?iS@Y%w%SNoFS1NZIt(y5{P)u0?w_I%nESwxqHZF84SRwpVA3up*T=a%9Hp^|Y&S zAJb?k_ER)k9zr_>y+p?_>naP5%&8$Ua}^{mI+>Xpei1on8*4Fw1h`h-8`QA)5roft z=k8zN_`}Z~md#&kLz)Jv!*bd_TOIt@VKZr4tC>21G9iT?8ATjH`A}zkhCLl zTVK`Xcw=c}BiBbca2FJIFu&D;oTK9Dh(L$L<1ANtG) z(CG15yrLwM4ekvWp_f?1flM2q63>%AML_KDv>>jkn6**rj8t3ux);vcMx>CnjB5|E{PdtDA zCyr3-n&j(sv&~V~>NWedIOcU#KrHb?pl}>9*MBh<-Ir;)zQ-*jinQd^&DY;malIvnH zG6H9VN=_+wtGvJ4%&RsCasgM)b0}%Mzj=fTQMij!5gjX!**i4bRerQI zTia?z3Y6l=t!{&7(k^)+026qmP0zeAI5nrq5Bq$sO8zbQtypO{p)eFHu}GsCo$<5G ztL;;e`%<4VTW{Ad@2!1(zVd4@`KSTNQuh0tC)6|LURaJL`M`Ti;4C)#$)LoG#8G1~ zm4>{CR9I3wfEUR zj+)~m$Pgvy1g`d;=d$cT%*ET@aJ!Ys0K89RMkYQ_+x_cjJicn+q_krmrYXg?r0dIH zA&4{mZT`O*ufHzB0PQ`y)N5>DGN5Ft%~6zPh*>Pg>}>!UY}3x$K*4 zN2fpJ-+s=0*Q87b(`l8tIR9QQrx7+q#C?j25N+kDg{hVFU;NIjDf(MV zBR|AxemCf33Vo;fQ&`UKWUm_Ik~KlcpH7LI)LAGlj>#mz@R{nA_maWwm_Y$P_Uv2X z(x~KlyiX_Ds%X+pT|m)OOWU}i9=&9hNZ?NsIn+G>6_kfvHAXa;O=A7m6GEz^XB4;s zszzS8>_wP4EG(854kCfH%wVp%7yD|jOM~iNo6OD};MPbMOa#d2bu6XI@GzLSYw!)!y0AVykdbftm4s}@ytwT`@w zy0P}^a&!{6d1uW%@leF{S}yFyB2+;g$uSC>e^{5avYNr90Aqgw8;%hdVc@AsMZm#K z+PZ2VW?ih>wJ>`SQz;OJUAGlraT_l*)X=;Fpl@ia`&8sHR-Bcj0(`?~tGQyu-DoAw8((2b2dLX+t|hNjAu zm5>lrXIvLupE5tW<{%zUW7=>%@&wyiD)cTMp65=w(d#5Fg4&8VK7n{9y)|`%BD3`Mld1xoME7ctp%F~#R{F9ReI zhD>q;V``7=>p6%ca{dl;NGm)^g1(^f%jE|zVY0*u`=#TFUTn^^NgEZanOtYB>!2I* zkkY!N7obK&JBN4ep>HP(W$ejT{8&u6mh@ZbV*?`@eusmEJd#$w()D9wV3~yWJ6nJmdZ*J$&&6!lk?v2nKkpxN1 zsucm-DCycMIEVa$G|=F|&Q47l&;g-EDQ#;hLGh+s+JB z+sij&l#&$DU$AW%kZIh-on*e4#nvjHds4Yv^DL83Bs3LT?{1udVn@?K1OVcUR!wOh zg92>b48-76``od&G6~c!?jrEb#`y?qPA#>EuWFhs5_owD2{$aZ9xj8OzyU-xFt=F7 z@Z-_FS}OkIhb`7T#1o?w5dtAhW5e&^;Q~r4a(~upufU+Wy-nIzy7qqnVfiwf+q7$D z^QHs{$4TMn8DDp=xLIbZj4uewR6SUG$AX3J6>;fO4LixgOeE6L6u0P1| z4Qb zy6dn2xo0pu*QSVwLbY<*PE|dfq{a75v(vCTMLXp|{1S z5ZNtn0*so?t&4!bWV3{W-AwG?FBab&^q`wR5%#E*#%i_uz3*h;aVd3I&9B3EnD^5) zo1xjbkw(6z?Amus7580rq2|AfQwE02c<{_mXWqJ)q>YX8d|BDRg|%slbq$iAhq@nbPG{(OGCogQ zt0-s&u!#lS31axORm8BF*Kbl(r$mVhx)b1S^xsc!Ixu&UEUm*tEP0MwdL_7Tv#y7KrgIz`F%5Y&99)SgCMX5Sl`C&ZuevRcdq#V>y`&Y`~QwF2b}s?dR!mBBQLqQEV=F9-E*}N@v_O z`XnxXiGBN^{QFk|`t~cGvgu~)?5)UyMyyOR0Vu%yGtOsTif}x&Cl}E*Yu7+)LbAm` zGTzK7j;auX?W_jbdl&4vCr|a;wcNfo1B_(0#?R1&Iy&Uk%g#>QG^bww-HD+02_1A3 zPLJtDU0zQ#A`=R1c0sW2ln}Whys|IQPNKInNFwBLNxocv94S7@_6Gb6Tv*KP#qHM- zbbLTEEj#dzTiN4>Y7V&OQLV2i8E7mvg-nK5y*VL_uP&}nX4cl(1fLe;C7DOh^4>l? 
zcGXrE%G5zqQm|#D5GV6j@&73+c0=;?#k@o#nAMUnyX96+e4cKp!|BesTAVwY#?Az1 znHwazbUUtszf$iirL02)GQWK&Ddv!oxI*-K02p58|7QcI)qPIXNF!Y?gI1J(s`z_J z1PZ4S20AQN1#%z#7{`pC%7Fh%V%89I7&5DM)J9_wb`V_Q#FL`AVcKBClT6UG$mr8{ zws5_{T7h2g*XN2iOtS+@!=O<>$wZ5jalb!4?$r;aL-Sy7g1 zl*0qWR)8x(A1IBDdyin(^>L_1T=rAX*}|MFy?+(zp4Sb>r@csf+!Pw3-=F#Q-)Gmw z+*hF%4N387_pi86c=;#_MT0eY1z{~zJO1FDYbNs5Ak`SV<*D(Y*~rGdGM1k98zR*l zyb(X2DTi>lSsfa+bsS#KZ`pBNDb(3lk~40z!fYr!VwzxeFTg(E*F)+XSyc=nd~4%A zy-}am?pF%GUVqOZ7#fZlb$Qrq4bic$ya@36earR89Hi-P3q=Gxm|V^H{#F`qhfyVw>)_I;G-?R8fK zX~K>G$&mQ)Du@7UoVM_&H%E+G7R;*{@g=dP5SOn@o7ka4hh|z~YLQ+niKYF!DX$ufmnz2bX&r5Y{k5 z&}<_6&AIyVKoqFB0B9U&FY;SI{>JUTo$TZ@1G9GMi;|iFYYXv_lIt?;DN=5-iUUiR z95mm)kw@(E7wt)SwbVDhYliz)I9%6ZlP|&rkbkq}`P}c=O;%P3`aG^IvW#v_cStIRjlKZlBS-((z&|!L!Ct|*l0SIIi@V2?j*KzUl zA-|{(??A|ad@p!zXaL0~MqqZd6L}q3B)C^?{I5NS{ASSW&DE!0>&j~7-h&fYVzq+Y z<;v5+Mvb_H3~AJ-Ppk^=AI>vm)CBF8W{#)7mpaBBG{Zs8;RI2=+%h_Zf~CZnK>M=` zE3*8={BDF%&Q6Hl2d=+$BggL(vA&v-p%KyY*R}iYiC5sG*On?}-$oJn1={+NBp6V)cU|FPVsyMpGIHLmyG?HsMyHooz{26X zB=kU^QX5osw=YwfM>UttVl?Q)&x3dL?^K<~1w1xA=3^Mm`^3P%Nj@57FWm50-r>({ zizG20sTBEpE@%)w3i742XLU?+M@7u}L_qaX8gI}6xY)E2D}T@HJe}i~EYs^E!?Npw z{8FBG*AjRMV4}|QNjK3m`l8N(#3Wdjb!_;a*hpJi8w8lU8acMI-}=%VgA`gH0X;EE~V%AwSGRa-RQlzKs<*f{kmKj)>zrw) z8i?>U7m|0GT8+={F_SM}yCTzk-PJI?3puI=d8cV?#)C74F7#$QVMzS^s693ZvEg0I zt7-ExxH2Sk27w+%gQV;5btx5V$=wT2;w$!E(N|Q4W9116tibv0SMgcVFCYjrTD_`L zn8dy>Oc28C|4I=wdb4DraWrTjLDCs69Rl_zhK=%=YrNIq>&G;qZ7|&I|80Ed$&)LP z4_S^RyFOzSR9V(4!AhbuA+UWgsb$e35=#}%^9MH_q0)E^D%kb}=cX37Ckaqb`3~Zl>(L3pu#=Ult-d(t5}5Jd+vazI$_hW{ zxq6M;y)C_+?+^P6e-!U1ECGcq&}y*_dd+; zFQ-*QfX->l;rnv^i0uR;d3=vz%`euYaFIF_ra>C64XtRCrmL(x%%zZGO>4taf3JBo zkI%d$hW5&*YNrbQNF<3-0)ivDT%li_+;bOQ7ts?l6`K_6G*sTX;_xY2S;Zx ze{^`Ou)+HN<1+S@g1wG_I^c$Guqaf6jWqI4man4pOr740g5ZQQvKK9x#e5WLTKbuc zo}MvrIlR>zbC;Jh z(e&w>2cKundIPvE<8#d=@(H)PwC_B!Ti!L2r~shSMe?;$0FeyX|GvaqDrH(v1VYL{ zRl6Mhng1 zyP91tS4)qBayj0s#j$0fi@c*#b^%7xdw?lV6AC#I>f|KYJFnq^NSOv5f*S|=8YQ6L z)Y7*JA=u*>2zn`ciV68qgzc748P?k)a~eV9WzAXd%Z z?%{W|Jx$k5qr%y$>6&yazQdiyPWPYD_ncmDQsvVR?$EH|#%1M+cEuZ$Z`Mb_eZEHW zoxwfFW(Bu@hvx%VsBmO zXNXV`E9bm=Z}Q{G5v=F7R#5#&W$=R1k^;@S0u9-jC)*%9rtz3mt#@|!-ldx($kk>3 zhuf-OmC+!&@ApsU7f6o$yttIUk}cu~qmklq$PY)eWhX}Q4n&^Cpvs8- z@DhQa$=Q!LEQYI)fnb7-Q{+2^-hpW8^gtOM?9Cn5=K@X%tMLTYy>cZ}HQIU4Tp)Tz z0fXZ^)yedZ$#Nd)R!-~p#h5KBx|+s64L*d|l=Pm$q=JP-B^`ne9m3Cl~rtkZFOWM(~sK4Y!l3;OjP?B($ zTl^g={FnWtKRWWDgaGIDnlMT8PlX=MNLTFK*Cy6`;max(l8tsH#^z8HN_f6ZJTLUNa@UFYh#4uK8)>Z=PG}JG(UTBcwGZt6 zMy|-K?P^=J5FjA*f>DLQPq)gAj+$D#BKs^no!T!z zaT2sKfzNdG~jt*@b>HvzCejRkAjjpZ)_Z32o z3|hU6mo=TCLV-#Hl>$;({X@zw*F0#o^1$ArOE1n(njQmB3HgnZ`W;Jw5;3`-?x6z~ zT3{LM;1^ndrOhm@6#es z4S{a**(zr7KH>XDEx@5-s zKOZw?XJ7EeK=|j0%)MW@g!X{l-`DBNa-=in9>!~g(9dUP;n69D6v>o?Uy+@p#p}97 zvz4?PAKR%}IhQR2u!RW4WRaQJWlb=#*@gZ3%*o(T_SCa8y#(0Ji@=}gXfKHgb`Y>g zD~F@e0g9_!>s>*z4Ex-5`40EdztV%W=n{R}rG};Kfew@nwW-BFwYx|Io{hHw7YKe{ zwD;(0n$#Um_kf@3fyXk!V{9(t;+{u_y=t5GsAbA zuW|$bPtkceQvLmZ{Nmczy7s=f7nf^qDO;{OGy1BKbH~V~U``&rx5-zd^$`C9EPlU41oDw*IRtx^B(@0Gj|PmJ-%Y z%!Nt9>t#qEfDs1iiHM-`Esc}3gSyV`FNiq=*P>WIFRStxN{=hq@BoYzs;s0EJH1mm zI=r*2I`3>4o3%#n^S@yvJxTeF6_-6)Xl;&~HTQUrxVRurH*t$Oo-vx#d6(kJ`;TR+4U4h}BrGo7fG%a=sML3ng3L zc*imD0SeGJvUW6&$kNCmH1z_ve+&GM{PpnwW%?$nct@ihPHX-2*K`3X-dg@tbV-L& zA*0Zr;A2c1`9(1*-T>WN#?_`!rL$|A*)9F&Jyi-syzR~8JJl|w0@aK@n+x;~RI(#q zK}6ScS)~gh!AZ$+_fV@O-ALGGgPL=-D<`7wTuQUlw~?k)S^a&gIwoLVj-jc{pgt8m zN7lJ<&gOV1%(&hIeAm?AR1fW`KzQq3RjI$;v{3Z>AR=XM|6-cS%=oxLq!VQcsL>mF zw=U6eXu4QofwQd|4%xVj!f+1v6b*ZH6e2;@^hoz--*rUK2Ai{%?;)k3U=yQp$J`ky zbJ^kR9;?i8({yB`$&zz_HPteLnN3>YrtZySDJM`##tt~($?LlIGH95tS1z$sGoaMx 
zg^S=k=uqmONfEhiFjHgd-L$`nWCM)eib2<)Z9i(m+a|e&&dMI!u?=7NC#J;S)H2*Q zKS`EAp@p2cwNu^b7GKvf8sUY~{$P|~-*6en9Mw1?^z~Pbm%fXWI*N~%AU^iQq}%cC zWBA0GnfZsj#FDPM*2zJZut7GP77<)$<*+tty{7QZoakm z-tOYPzrMPO2@HA52?l!k`b@0feHx8Af|@p#<<)C4(R;Qc6R+vbN(|U;@GqMn{hwbB zXJ)=V+`y_Ymy_8-y=Q9t(H)$W1KihILr#R%3FDC2h4KFYWLjU-t*vpf#kCOuE33Py zMPq7Mb5mU@3CqZS#e9$N(oA9ZX>eOypgX>i&_aZO^DmEgI@AgFm74?p{vh@e(B7T9 z`BLdXf$^`!ae%CPHRJAcdB7R~Tdlk|MFHia2MU!AcNjlls|qZo55j{BJbM;+*s@5e zVKWDNK{g0CoXQqxQefzVVjOi^NQ^rLKe)P@rc__ma~!Kv80+5jr4}ksf%yRK?xf5pS_)E^ zvntb`rmg=Jc9TvbrX2%^nW|N=GcVD4M}28u-?q%<_~%im0@C1`x%bO&D_KIe-}&&j zL{)#{UwKeYm0>#6zR_-vSKhoSB&Yc>6D8$QE?uJO>p2m+Ky>u3`)80>4RrobFjIni zaMRTJzXDk3<~Flp1LjbGzm^b@c_*3nlTp?~l-|d}g5qqQ#esVLwRj+w(7VrbS1?!e zEm-Lp`qix=jW;qS--mYRHn&5LdoOwM^heUrlCoalz#(v;^m>HzeLl;Iz3Ffr_${nu z@K*4h_lU}OQ8yhTyInW!O{_V-TnoWmW<{!KhdJtK;I}m9u9t~oR!2>RDFS1mu_3BL zgcA0&80JP}&{hPvTQUKS+@xx?YQ3EC{wHTzq367ehhKMPaD5Sa+<@e@;aJNl(>FHR zGR5_}#i7x6C%`uKNTSBYQM-?~JC87E9uqTiLcBt7tGb!iCS`pL`zgy3BL@a=P5APL~$rl^O-{9H27sA1SOWz6Dh8fe4 zW4^V7JcP4D0QpSz)iLR#Wi^xV$x6Jr4T+lmUSZw<)00kK1{^v?4@|*4DRBEn zst33pO0>3Ys~z;Cu8PRcJ%oqiLZ4y*RPtW{*G>S-BFQ*#UWSW)3jT&7uLAGe~TVB}3 z?l1KTwU62yIt)Ixcqe~3g0r%fLT4-lipLtXKiSk2I-EOb8rk|R=UK4cQH1Z4Kqb%C znrQWom3HHnFHc@1sHOo431E&gc{b1noUvn$;~P=NBm4=$Ov*2Fi(KH2f2)@3*f;`u z484v+BQ5gcFx0fP!3V9QThQ%lp9G+qxCceA{}B2lWNh|4+gqBae3CJbCMhkLkhiCD zGJwYzZ}L9crMnSF_vV{26{E|u@ppo7D-o&98#;+(n}L^l?*tjl(R!KexB)7*ain|b zMw+8hJ=y8=T>)AT_Tb;3ptaHIQXVG404YUQU`VJ)Xlq$dH#T5EJ=j#8Hx<4WE!cII zL?wu{k_mivjk(E4W7(TZxbH&_FwY|a=%Bsr3lv!!k}M;C%CGWBTLR->&gGGay5d*t zF&~ns94MAQ`J@2i=%Q(VDkmYW^ERLQD%8+Zsh;&KDdLQxRjS$eu;1FH+#no<8amM~ z>$WL&*+uRBbmS#u;OO;iI#v)L?i@^lrCB~;ZkYN}vShry|C>>sLU*Mii%M&NYa5a( z%faFc_iCJ}s|)?Yi$3CYGc8h6)D5qo;ENu@rhHvs>b7zU@$L($ZJe(e3%q^*pPQ6pLiwaRf8eR{mJ!g3rnMQb8L+IbwLz(f8y_mO>*|()af1PJFdJ|dAffc&vti)jeM;TLU)V(BtO4LK( z!;H{xw6cF)YK?5J!e9nKzsiKeg(&y5|5U~M9wu*^GE*!Tf5%m$Dg0Bp|9L9WmNz>a z8Jsk2nT(wHX0AB~eBW}q$$gE}+$1#WxOTb!c}y9SWyAI-rShQZ2i%6Pk{EA+;E&e% zbEmX9NCX{YQgV4)Y-IJD`l9E{nC$l?dMppC*4TNDy?~vi@SVFiELiKGEH%{0DR&lY z+5hQU?{oBrQtE$>8e3EB59Hpo=jptc0N=)OkY3)7!a7&Wmv#$OD{T53@W@)3bobJG zbI%BwIrsIMY%3OmNY@a{iT;}bt|geupe(XtvZxM~tb0hqYD5;Z{JVSdAkDxS9Vuqq z7#79b^7EhOSyq6B(6Bs6)zC=drfy(~@+gJ1_-FUFJ&%I#%&^w=0L%IL4p@#O8hn9~ zsJ?2CO#Ktk&5o{`dHZW@o$t@~`dnJ(Oku_OUJ3 z^Kgw&JuLqpUk(~JOFr5()w{n zk{-vYnLHTK$Y#A>eO{gexK{r?VXj^v+G=2$h#Bkclzma@e2JNzfYe}0_B3NdtU@Rc znT47Z#eHn5*KMY*qY3{s4d15K`(ms0bWxN;m#5jZY?j7P@rnk+#WKwJs-aqy9CKdD z3ZnQYQ1fPp=o?YJdG-8ArtRy^!)|bpNKvtjk_&>7W5+YAEFKP5%y>9eS0?$UwD=*9 zp_%G^89(FkZK|kGh^UPoNk_$z7p`4Fq z?hY?A*r_Sr`#aEGE$uqgIW;0J+dt;$5U^C*`5~-)GlEcbeYk<=ONuWh9?p(oK05A1HAU2=KsQn#TJ$--E}p4fkkj z{qV|#RQFp0#nCW2G%oKs0)wKxzsn&xgm^n%{)ujLr=@vo(m}59vwM z>A8Gg+V6x*zXc7U*0jUZtK4!FL!*-tFA}-`ghrQV#ln=UYqKUfzPH|{@e|N>Wlya$ zFtdgO0`NgMczT1&O^Uc`4ZVf6EC|JbaczoKsTH1MQEvsBjon~Iqy40w+P5i(7MXd^ zRc#iIHzYfJSe9Hsk7RDu56<;zT{-)q5~sIMBzq&dVnpP7A!oAFPuKb%rKbJFTcwyw zh?U@WwWB5eH4EOzG@7b`U~BQ5{6g)EW=N->TRq*w6gt?2kn6rnjf(XBGF@t1!Gz~4 zT3HXs` z^sGX4`64_|Wc2u}-NTKbRz=S@TGF95{IVH&TCwH-ujgBrLy*_f+~I^VpE1M+JkV|Z(_ z#9@oH{;O-ntVkbGQd)Eza_%RIKg~x32i6tYf7u0Kl*|MK{suh5B-Sy#YtuCk(yyG_ zfSXe=YB-%m53>TiA+&*x--d|+0F#$t)u{kM70R_5%nuRlhJ5T)%ERES>Md`d>!xJ9 z<`9ubs^9O0ij(8xhr;eWXe~436IIUc2v!EBLAQ@Q;}dB%`^_ zS%&CaCOcu9UR!1k#;+j)Pt6_8Rx4i)ilmPWy}2Le(1B#$U~cETqtqnO=U^~f@@2D? 
z4`O8cL(|+PrcUz)FQ=v=Y7?kb2G%#QzN7CkY+SOcH5$7mS{zz$7ZQ~r^8}f!xw8CM z*Lv9-lf3x~SEwNHl_b14i0U&2#BPaC2JM%RQp}N=DcmE{(Z-@9=Q6AcKj~s9((`w@ zRkvG>T=?71F|$MX^wu>iyzaDQ=4tkigfW_o9znei|2%UK#`fhI(oLu!!T&tN*7g+i zhp!DKwrRH3Gc(EBC7NzV|5UD5ee1>b1q@bpVM>zk9Iu~*QLccdNS{E_`gh(qt63;V zFEi=gZ0=+Gfd;fi-7sGDjk|R^HpTIYHK(l1yW$EzHmM3Q5z}aa(0tRtLVxgUDtl7p zLc!;gdlL8+W-!5K_QRm}vp3DQh-%^S(zdFLjH9=?Pf|pmm8F95yo6*sGxM*%EcQI! zR#>SLh!3o}H6HcNp{fFFWmH)%()OOxNXX1?YF8Pp~(}Lba?;fE#V!T^; z@j5KwyUjSP-jK;{nNz1d|CK%8|Js0!@FG{KF2iA9&-jUDA0M zXd3iZd=>g)^i%tb=R$$d(KwN!KdCX_D=T*k@o`9FA7PwN@TuwOic<2&cAI4lM!Qrs zQD-ATH|1FaVj`|;C;)LtdLr$+_pC=l%R8Hn&x7(a@J}_8)0qP=4AXP(_tuOA&+($?6-0l(%8mN2U@9jvAp>P^TBK`l zf8b4k&GF;1v%7O1d8PZ~OW?r#@lV+f9m-x~DRQ0higrZ+Noc|+?s!F;Fy~ZN8qM^s zc%7uIj*+u%p6#LW5p{P2?JnPc8~EZY{z@h@fw)$($WbS|@6WdW>oYW?M*$@+ef=Bt zP~8|tf9hOio-N45Zoq~HAj1h1Q#>CFWa@$hZ+$VF=SUjE)fUoAX`4@H-xT%wE{mq% zH3t6kRP8mg6^(OM8HhBde zW=~79Y_`zQCUFy+m@wwXC6R%BA>Bhk|B7{Pc6_V%l{C8OmpI$s+po8`>u0QIraz!i zc%1yNkHXqL%b`Ap&)Lk8yb>M{q=nvvW<_@W>2HGOY7;*kAwSF}tG2pKj5+8uN_NDH ze1LGt$G<6(yN}b%m6s^)Tvk=6DaMaCKj{bziAH+{*loW zW156|zhPQOoVNBNJVSr9o7X0wN9^fe53@a=1Ui7~pHG)6yU&er8$MLHE#cEF-*3${ zQc%U%+R$ub-r=n=8`A2Xk847C?G1+rSrVLg9&VQ*RRcSdfoS!Gf8=aUPQ2T(rp-4jO_rF4pUHG7@_mZYQ;p2Zq%ScdA2T^B@3YRlZJLw&LD)k(`foScXI{?->A%lSMIg zB2~;j6)r)IEPE*8kw76)|C|wWy&X>UICw{cF2fFgb})o=YMZ(5EkGD)yhZ7CG@_j* zh%*2YIF&q->gYvV30(Hh+%H+DL(zTz8$UnEQxE5E&H%LfoL>(DoT->IrHeWPqaLI_ zsFuEVt!kDn^RRxB^Ntz$p1t$n!}V$i^X8Y3m|9@^%+z&zPSMSke)%5}5=|MQlrsMV z0OXyVsud&y+@oMsXVE3A?G5=~uuFLtJ#YI?9`*>M&>ngKo z&V9~8`hvx4Dj2d(6;;z?HbodqFPr}lT)06L9c$z7JE99|pNnZ&R&xyt|D#jsy^bM~JO4Th! z+K69$Op@)7XNpv1_gjSxXt>EgfUYjb+^~QWR zb>}(B5qB{%?O2tF06W@g+k0Q=Q@ZSJI3}^4mf>=J*|&jUoi6hwV_&Ews|tvuFYmUs zwLf#PQO`vd$aAsOw4srQdq&&Up6Z~Nz#10G5I zN@$mDNyk?|y6!2vo)X#bKU0>rep{6BN|#nn@!R}ZTmd#l8s9p9>-|MtuH#C!^p?zs z{78a!8k)MUZR1d`h)xu_z`U&0bxp2seU|QNnvuvV&!N6EPFN$cZl zSirzjw6}>yZgt9ycmdE@z~}Ruvan#Ij9MKaM`IC+u`_MNpqdCrV=1H%n#w-jw;ZIk zqh;M`8?8}ScY;cPV0WojfG}PS6N)BJ=pIWCF#wCTGNSWxT zm&w zKbBRML$gN3q?4e$8J^Kk9_iiB5II*>{%LM&+ER)m(<X_FnDLpD=@`ZJrY^ysEZ+*9a3+@=`R340!HBnCKnpW=sRfmAE z_Ytsah=#zh5@w1Wh~`b>{;!%P>?qaks@jd!4WM6ig9==`&HVoH3(tK3o>f4d>A5c+ zfz`7)zZ^*e4#Y9&ym`gqyBl>sH)KVS9+$rU1`p^jt+l!Se*L=HMjdM4;zJ|v4)U;h zdSR54#T~pb+M_;Np0TdGQ4;A6G!RT}w{;sindNeO0#R^vSbfg6jB@fBHCP(=h}(*Lx6gvm$Xw7ur|QsySm>Elx>yBjv3_@E$#td9B2`w3rCqGl0;} z#_nSDO(~kEIPQ~MzK5Gh7#HNlI!^>c#oQ+T`XQ>m1?yCLHLL)A{jWO-%WH6hqM@Ri z$55x&sxxjw=Y?~#H$WwO4bfqjzoV@Kq>^m9e1X;nGxA!klu$s$fBx-^%|l4PiBl@Iz!f_<{m#jBAq2qm6&8SdKI z{PtJrNq6U~L~gIPqd0(TuvnQKu`XQGNCPVf!XoAPAF<-ZKUYPqss9g`=n~Y!;R!yAs&? 
ziPDyqOXfHu!2!_uHsQ{u(XU%-QtvkhXTTMH53nc)?8XrD2z(r1-Zm%_C)Xm+ipvkC z_wfLkLHiV627RVB{p}yPz(6)ZSKPnolBYOnV+mZFVGfFb1BLP!=iJ8~9_4)PQgVjh z$1imy7KY}bV>FfNLfh(aWeMvPj1rJ*yv}Jv{M1s=ZQt<4o|}vJm$?F(M6k|5%4czu zZJ`PD62b1tg`xHO^)-mTIsMN2%yDQeeMjcvfI^%sf%vbI6{OEfzw(6^-s8b{{smfC z56s)hnIbWDDpH-Ue#@Q{xFPJXgGIV2G5;NcE)Q^=aQgz<9@9|Yl5}eU;fSar&w1$W zm-3#Q@bIySh+Kv!q-FCHwgYN>sx2BI*A@K~ogqacB~E-{|6Mq8kdgQsHuO7XHbhYI z*3mWkS~(MgtEGro1cO5Q1v5eB$#m!$aY_s}+~+EP`|^?@z!lU~_Y5mOYf*@qD14O+ zSxnp)vRS{y`s`ph-JrYK>lc{lWj)Wo`-gkl?S2p2_`KWqP|Yi0qQ#$3HQOV51kAC5 zrO|F`U{-6?P-AHns_cc&Xns6u?n9;j(qY&DKSrc4uI1rpd~-OSsTs)Ee2TDv*^y~3 z&$7dlN;=W_9{R|aW^dKMdJyR?q>Mat#slw1{n)gD%+GD!mO_2mh#r}i&I=74ZzYVU z>WRzDAOKVj%h(x8EBv#izV~ucO;6Z}VdQi#nCLi4dy_x?)R&%)op%-Yl)5%7U0pFa z!4lV$!2aifsOlWjv=0Nh7+-v-^E`yC~my5Yv7aub@= zG*&3kw2HFoD>UE0UPtMFa%d#~=7krL0l*iR24Ml!g}2__;5lrPx?MGT^XSFHAHN@D zU#N4@$CDva4K=m}I;(Wl*Zzwm`%NKdq^x3R6ML)ElY2s2!^x;UW&4QK9rZOBOJBqTtB_&-?=*Eyo*Cvm51N+OO2Yt1h;1gANXCqJdrok9s7~q3!e&m)7+K;b^cp6u60?P`lK*=JLHt z(0{%5;Wiv29YCJzTxLS*SW&vc+ZgCW)#7MGuvCRVBNk{O6V>Aa-dxtFQQ=#m#;_Tk z&H1%o^oY*bhJ(;PLH>TZ^uzxU4K7^LXq^A%)4LulSGKopb=NE9;03O!R<<-^iuAbW zt<(WG{5V~|ykaUSWCN?DA!{X4o z7e>YY6_2+bZi`-3vRDyY(R5>dZYVDk+t7#%pVWR!Qd1tx6Z|9%NjV!&UtSef^Uq^u zZBlgT@Y3}LV+$iXs8=d{;CP40*q=l z(PO^}vdIr^Y7r<&HY^!c-4!Z5Y@|?s&oo{p>hwvo_y@J-veuZqwAdyjKIEG8QW{W1 zHRY3hgJR<2`i!VZIy3Xofmd8LmridO8^3~^(?UAtXASqur9m^pbO>PG`^U0Okm9=3 zUADE-RwshE#}p-$96eqdkbm0yOSJN@y4e{Bkk2+6UB$K7!A z`)4WR4=DcC%~g2^$=Hbv9S8Fz_y4EhGb;X3#3=^Exnr(%@JKWtcSG6pkc(bFcg_0X zKXW~anMqKWWUOZmUm_z@NeNXc9%^@_>44@n<@!^YNYF@Rts+{_H0E{^J%3cNT|JR4%O%JcX2wnYqTSMF! zZ=p5G$Y4$1Ouon=Fb^2U-NV?<{&0}P$D35H#)dSE9NcMe?j1EC6r%4Kbe5fFSSyb{ z%XpXKB)(ptNJ#giaT^ckckRipk5b&IhxRQ}E3pRt=GR2<#Tdb?H2r`CJiWe@06DA^6Kcnn9-lDNT*t{{Tsq!V~MiOzgQr5ZH2A$B9(*IPHG! zzn%F1w8?4mfr212uF&^n~~=delaqt8dgvvCYTL?UWk%VwQw_%Usk2T?=?Ou&)f zez#5C@qATIbwlf6(eFqSvB)C%fP}vr`bL%rw;eLXvb<>zg`0nr`&iuDIoC^`Q1C-9 zAbYK3U0;s_PQe1FzIdcH(Xdm=;Xa7OuO6QLXdr1dzsGG5;|8vEOE50+ECw(ov2H?s z*43I%F`WMTazN+RhkHPa(p~aF7c@Ml%AdWCg-0YB13~durc)&FXSi414j;dMUjnbr zjaLZ}bk=+P>Ht$TX;4bi*gY^XDx8v5H`FL=`LBL_F)135GaYxY&PbhyT#2GZj(psZ>HH@CM|S-!lsPm6^sDY`6{^+eI!OjqTS z`+YG`INAn{J7-~jDjW=e6~_kGV|k6Y)F5icG`z1UNo*x?GXOuHYf(lIoesP?;MoPw z+%o>{ZR9)wU>T=J$C%5C;7lfiJUZRB-fx52&_M!z1A4mQw4H*0Wxy{4A-*Kd6Zaso zv8lqsQp$I00&kP7DUL!A&YtYD`DhU!l9`n6rgq0S;McWu=n3GRT-~pjy7(v}w$dgm zjHN=ghJj2xxp34cBESCN5@JtBdvX}J4EwR&s;~=RB%D|-zAt6;wj$GP#OBXLMUo^^ z|74)ISbZh(#fi{v{GWe=UkUeTcRu?mDSFS2v>5(f+?*!e3e;IyU6Wm~0@TlXKCQuc zmgvVk7c_Efz9BW2J~s;^iWpmZD{MXoM};4pvyK;SvZPj;)Z=}(4tKg>bLRh;^WXhx z99E9{AArToX5mjSQru~vM^wwaN_;dX*Iqp+TgX~k@fZW7e$^WtB&t#f=vgUKGMEQ@ zqe^?jED7UxZcKbei&|1XioouD18)m(}fov)q1G8-BT%I z?Sii7ooBaIfAj1beMbm5SkaBT&$K=Rt9t!x)cpI^zAXz`8S`+CUN$mAv^U@t>>Zr0 z?ofQ&TV_FkpuzPi?5`e6;m_F+I%y@xctjuw&NKzCI#jMWWmP*5=9(P%eJ<>J$gES-Q+~2OR~(E9XE;a}3kc zSU!szO4zJ?2-uqE^r>?IlIAVzxb%$&&jquKx?9sC;b79ediiDwFhKeDt9jlfzjIk4 zhyvL2R-c&$U1-gAp-*vPU5vjM2vT66m>}_?>P=t zx{qRoS@YgXJx=z+GgoA%?SOo0b2&)f;~WlwUfcIP%OM+{orPgs-2C4K=;IH?1+7Wl z>M^P+D&o$`6I*DjXP9gNKinyhukTtnS7)g78nyn@WLAPuyKrQ5P>n_B>E=8-_#dG3 z^S9k68x=7UDrQ5An_AbDt0fh$Kz!z2s>CS#1C9dg6bl`kw;8#&W4JYiWr8{* z>i57=UL8s;ZP&&yU);4QY}L_8I#%=0lVL=D6Ey&j%Bbu%>0fR&5CLCb#j9J?Hut%n zeE-zTxz2kt|Ng;`l(|32ro_kYw`}Iebc*E(=0Jq@Gk+J|Gp-#|=TO)f9wCbU|EA`8 zh&;eo?1u`MO2%5xc>mp8jw&%uy-x9ZMyZlxDcImoXP(nf33rvI54(D!jFd6x5Gk;c z)tKzlZkx{eXVxcOej_RBc!S|8<&IYc9b0^KmZOlksq{?HmP){halT@}3?(d~7U@lS zc>I<0E9SB7`{GGSZ_cqYty;?%SncqMw!>-e@}10FreNSUM+^KJDRcV(S+xSpJ zfF#KLJ*xEQ{W$YhmVOx>pVx2lBACkVKK13`Z=#N)A$1P1@vU$)jEKc<@lqrpH<@A& zrCm$lHz<`^ZrP9F*!+8=2dkB6M21(iaAa^ 
zgx9Z$Uo~Wl{#!P$;~DI&2%>oyY%Q3&Ye78Dt?>r4^V2S#&^VVwHJRU(*FXM*YYf@N zFdV`b3P9a!m8at}Ho?+G0#meCYv9h`HpjeYz=GXW!EQD-9c!avXr5t;(Uk;}Zv)7E z2H4%Gfsqg0KT5Ecgrh#x25@LbvhgoDQjqs9wfm_AO~fRPtr-E4Ody-;lu0}w;Fe`4 zJWfPpX-r329#Jb?(`LU(N9DN1&AsL3#A^3rgW}T{KEHVd^IAqgo(>2wr9*UnuV;P# zi7!Cj$db0zWhhlW9iePYAk-evq&|&qsP;OX^(=Umb6zbhG^5bX5bLl3mA+;~SL z5gr%%wJ9=do=av*K1;K!f37#e_u%YNNA_Vs!j{$8)lJhOr?g}6_}PM#hA+ML=d{xw zJhNlZ2e!~3jcvL`p54`IOrPcxK-ZR-{*-AAd|wc{>@%7}AY@4;gIQC&V47tH^tL6L zH@1QQZ%n|3vu+A9box>G^MMIj1%(%?kk&>8De~S&P~LT}MrDkamCPJta2nAC530jM zibRL?L2-yME@=0tf3$u|xXNH;V2$LlVA+*>_7GCD@b^9)jb_$hI`=-q?+A&0 zZAlLsIDkKzn2dudbs-{o_WOt{Ugf%g0T++GG^D}N&cZpKZ zAZ1Zq9M`Dsak`t|Y-&o@0G}LSx@T`4LTS7yzoQRGH)qFC^8GD;?9$P*YW?RoZ2`X}{2)#RThu|F#oH54qPw!c#gQ=;XVhVZhD=@KH-h*geJP1BOEwUKA> zhu=q*V`+!+$=`O1wA20A$raYd1%#2Dx8#b`87-GPwwkZlD{NhEVa%N8@e^e#=EMx~ zjfyp>-p?X1Y3I-x-SCQ?SYq#@5$m;d%(d%y)53g#N|&-34?P0V{p3@~7|f2FUT5odYSR*FWX;f* zQfomw7R@&=E)1RY^j^b7p{3p9PPpD?D+xv11L#UsNf_Ivf7uYt@2qx@JpEY9jMcc! zK-8=A1U-7&iuraXN=Eo`7!y4Uoeirj1)$0QHz;8|03cirAu|gA_Wl-!XQse|Reas# zBaot#as<)ZdQ7OQ;Cly8ueRXtyL3u%%MbVr^3o~wGIV-p z&>_hQNP4tlcTRG5Dm~gnFOGrUN38YsbB1-$RUX7UtYqM4FHIBtR41K-qZ(Vm)I9?rtoRw#lK>F9Xo2Op zN4nH<(ttY=%u`qz50g=*rdJX1CM7ev6uRlFH2@ShXdqLAmC?%1E8Yk3X!vPoo2;vu zduFm~yvXF<7H9lyR9JeJT-L!7DaoX?&O-i0f*U|PU_939r-~S2zTO?e2%g@IPFD56 z^5VbVU$`kTq&~3D%vVEVq~A3{Zu6V!l3l;N{~>4wyt1d_Y8F_!juW^3lH;+xjhL<$ z%<=|PzG&a(=?3to;#GeO>4rb@w9gVaN+SzL9skJW=-3WP_b&v4NGodlLP;bY^2SzR z-w5M(I29Km7Ht{yk=y;>kkWT{=}5KC%BdX)aS2Q1+?#aMAQi~3qjGZJ_h&iQ!ixJm zbPC&1-s`G#Iz#QmGWSK;k_vP8rRD(DPn&V`iP6@1eQIypIDX0?_g1{K)VFN+ zv^Mg9VO-eGPt)7dj(+bregaM2h3P#Jpv1(DWq9 z<}E-@7Qlw_b4B-Y07x~DK>b9@!K8_Po}OXlltmHlQS>fl%Hvsv)V3%967Fp&`!^6& zz3{sAF6(qRviCPl?FPQhO8lQ2Ma4CVTo(0t8$^@q&P2w_@>@m^g0aqDc<2}ls=xN5 z5HEb-Z2Wnj^A6#vpgW&K0D{64o{R%Mwb&jmV-0$SlQ*fmLLG;mBRl4dYf=x~OZ*i6 z(=&Q>50;jPl~eJoR;3DC`abf(OJF;=jij)syB?Ko2TPl5eO^w)vuDr!4alFOa1%9+ z9ARXC-c0iUvEaeUF10H$t|T#0+);08wNkZjJc53Pgi-uD*BnBZAh^Z@G*43tlZY66 z;pZ<5qKwQB!uiVzK;u}~$^lB1bdCTMYbM|)x! 
zILn{uT&lkf@1gwn>?~cE3|0a&&HLY2)VVj!BxmD?|pWgG>2dGuY}e) z1`yTgHXT(>eb|ueLAFs|MsoaTqa=9&xOu>x+vzB%f{Z`gCxR&T<%?8nkZvC~wyh~J zyEq=7_~n}}7#qMD7D|)Hpw;L(r$;lsKlg0s`PwzS@pXHMU5iU zdgw$OI5)cm*!Ba!w?lQKNyLs?SN0A{nPH_#GU?JV zZcbmu0H#}h_urFsR)22Z`#6nHU#@FTja-#{!W;V)={!(Bm)S4$B@3ahy{S@eWt_}O zFQ*3N6|n%Ji?sVVb(sie6|oRt){ca9;hf69Z+GCqLzg1~Ko0bUO@k&I*UJQ+=NIMn zj3Tw}54X5;mjd)M^m|#@41FAmm|2*l&7KSn_^a>N;E(;j4WLP3 znBoSy1|5i8IGimAz;T64`~6CxnGQwNHF1z8%5k(i6}im+t1sqcCnCW*VKDv4GApq5 z)*YIZAiL@8xOgk;1bvEEgH5hDYB}_~DF7K19(=2fSuie8re*MY2vKO#1v^kv0}EBJ z4X|JnUAfq0{Xz&AePB0hBBpJYps5pjo3L{LAW_3$gEZgh4DI%WCRjXbiXFHpsTm*% zSf1P!9V%LSnZR0u=W=T=XQ*Kl!-4|(2|E^+Poe-NN;VicuC*C7pv;hl?u_H{yV<8t zORekG7xzlbn3VzT03KQV0@LF)4&@+$omu%-->To2Nk|KRf2ew&4As+XKy88q0CoTV z4g%1_dQaHI%-cTc1V+iFHXPmtqt`-FweAxdZjUexJ91x*&wfySQ$E@hL1Nip^mMp9 z1UPTN5Ox;5ri4nB>T5t620@MNBMtH!uj(dFa?$QMVtu_qR~&%Nx?wD~cv_1H&dUSm zQH3Yil=#g?ytqR~#e3zl=h{uJw{;f+7l9$RKjr%XpxmfZzE8~nB0rAVpE>*AEa$1w2oY?eSpD=+C2 z97!qSl7R#GtdIr@1K!*DdJ!bbYF2jX4B%EF#LmKPV6q?>AUu9xiU^{rD%T7bsBY(L zfW|#YMqDSf`F^==CMeZQNTN;X`0_oz<|UOpOylxAVl{yvTFe4O-PBzb$XH#VLbQ&# zlM940;ansA$u_9Ft!Im3u|iDH8Z>E3gYv5&NmF{k@9}2PBN3~=sd{a5;^&l8^g@#s z7uEY29O4&UvgBE{&Jjjs+u(XbxG0$^^zS|^Y-owPl1U>DATwQIHwXWUqs+{JvS~BMwmjUTMH9-%haRU|8M8#kxZz69B8|z+F;qJTb zGwUev-H~pgNdXi*v!yKZ-~YmJq1&cOu?*yPWX7-e^P*lmFV~eQ?y>dploQvxxIs%@U<{{9h-nUrU#Z> z>OPnfS}P9vigUNHWS+)#;y48j+~(kOcV-2=C+9=m&6#z^_v)?_DsQwZrJ~;Rl+uqK zO6fU2k)%4ywU!-=Iel}uQSNcCBGZFCo#EKe`?DyX# zOr~D_gpP)@VAl)s$zqb#&gL2j#)S(KT<=v2gg%WCXR~QF180sah)HbB@h zD>0|I*6DQ3X-0ruWwrTktZTO9TZn#~K-;qo!uilMzQm%^MjFm`0Kod{zpDSF@yvos zN`~7t&;Jb~k4G{vqW@F7T_+T(tG@aRLvvls^iBaJU!;7dc^Xf$;qgh7MRJESY+5K! z2&y)Apw!aXvTIE;gt?B)ebIp>R+7MW2=lR-Y~&CN{~UW8b=^K!;3wAoHv9V_3A8xL#W4I>L z_kb&~`@xb1No#D?n7SOU7ZNSbmVe%Ska6WXh{=5Uvv&|7t<&5;2!n5@@<8&#TS-15bvHN#8V=d996ZjU zCmaXGCISv8+O-v0rmyl*R`JsH&Y?J{GH6$I!yeF3@;y~zpbiI?gs$sdfN;9qx%yB# zKKwcLC#x4FqUSEQ3POg;nUGmin#L5HhphWgD*~GB57ezM;U8@7J$d65Vkm|kP5BQP zs;bs9G&FzZ%hT9T{#L?3@PcRR#yi zQQr`icxB9=X%9EUGSma=_>zGmNu|ho6uj-hheFk39i`!%cLCbvZj&;m^2Q#Wr zmQwX?dPbg^wF zcyF3&qo5BNta6Ip8}vPf)pxK>XV@Ix8xf^7FpBlh%}WFJC)KN^TW1#1_;+OMXSnh~ zv6gg`>G=tqU}&6>jaN*>zYWFo|4${zHwL9Fmmjp@dtp{INjXLyD2>$-gh;xC>JHkx zOT2>Wi4#Dz{;ACjiyj)$%2r75Tpl(OOB&Gj6RC9{pS^IBO`d;m`L+9L4NtP;= z<%WLd9PMg#-r(UC3mX9x%Pimd@UTP|)t_@G3`JfnDOkHv56&de13x%-<=QK+taRo% z!Sk+7%D)G%F`drz4&mgY$*K-^W20Thz~gH^aPbUPO*5GJ+p|LY_gMAsMcXfQ@ffaW zxj_5U`eP{PRZ^m&iNZ6_-rhyno1FNdiOK|RV_?7WMe|wR`;P*g=+YW@eBWw6WnJF! z+ZbHy39|Ypfww1~%87L@vhe%BqK72T>=?}*tq!S$^LZi6nn+UT-cH=NJ!G5pQ*rbl z>`P0nh*0AB+)91v2JuG9rR1!qkM|05C7HrS{#jDHOi_EHYGtu? zqJ-9#3j?&{E-b}#`I$y6IVLe_jay(=sk_DHE594Ac~QJ@>*|q5dw26foIQLb<{Cl& z0Z%JYXTN%$ezO-Q{j?#w-||Y@eLw4*V{>55r*j>|VqZTZ$5^6g`SQi!I7d6ek?joR z@-m=t?6b1RD{*$sb3$VT4&(gTnUvZUo#s-H3F~2+b0=lvV7!;Z z&6^p_%O^~Uo)Y6j-lJV=NMUXr!x?3iNXH2Wy_1L@-g@$$`$%S;`|Lp0fH8BmzK(Eg zf=*I1_<7EVUvISg67i7=d5)A@7VFlyeB*02N&0G=7rYz7hzf_>EWkifE&z#6YhBre z?2IsCS4Yqe3ESrU>nK93C);}o7NmEn`PVQ%_~_@p6J&^Seka;o(O~>-Vj! 
z&RvQAwK$2YFRgVYT|~BwR8b#pstvAhkJQ&1Mt}!)&{sto#f1zp_}HcYfKO_bdl6Zo z!72uhrYD!5#ot)h&RYR<_c`|&oDcruNp_k}tvz^nzguwCKofl;GI028bI67QGsO8D zW@@+c|a?f3z@ zH}%`Z6oMuR{pM;xIC}A)d)Zb&Br)es%zf<`l|P}qPAW1a!=pBXt)A0WFMi}}yJK@x zX5NYMlAnYvPZ#_(Gt=UZb6yi?E7`;;ZX4xT*k4gUCaymAjx~i=;~c=Kq+q71OU=L1 zy=9n^qqzbW^w)I^*{J|&c3bK=y}e{%D-m(zohU=LNZo;1Z}i~%3~CE;0z2%CX(JN( zieIvmTt)fxX~N@*Y55c~8p$+6)D^ZVIW(w40N>*)qC7&j@v$be7vwxZW8HCh`8{O0 zE|t8!m)(lz&I}yRkN=(mSAp@}DY-PmMT!bHlk+K~>$&0b?>5_V`3810CB}X}x?d$z zSLmhTE#+ektJgLbAvk$idVqAA;<@|2jtk3H|Jpa zV)~9FbvSyL*IDqWf4&+#e^~1HRs4ac)_(U!i5bHDg!qi7WIV6$Ea1-X(axKGA)6jU zZ$U@?1HfnZ9jXLzqzhU!SR$j`{)=qkiF(tn8w^u5Ojo$=N4KDI;^;7YdHyk7dt|7a0Hm?OKi^T1mwC zd$!Z*fZD4YHiFe<^vb1`JD3&{7LED^8{*U2^ z^JiT-Kf*yj=sRu|p!zh02QXXj2Dcr^sflmYd|*~JezYP0hEq3+{*-u;9TB@;YA1d5 z-u0fXm)zn5ziWz;3rhvXZqehTb<-<$JzLNXSC)HeD?wbIT7~!;98TSSi+g#jLi;2$ zKsNp_^@~n>1No8xoGfd_pB+WXCpy>f-k}fkQ51e3cqA!QSc)hY@iU9=p@te@v$Kg$ zOl#ax&{zyPMgiTk3CMT7iWxFL__e}={`eNrph+lD&@#UWx=96|;2dx4v%!d?; zfQCv}n~-oX>}weg8c~fylT_H7xJy0s_^;%O9vSoY_m7M+so$=e#Gfubtu;xq+^3i! zPl0!6=s&*BqhcAKie;TW-UN~Vl%2j(C+l31D(uE3tmw*Hz`qA4?MiQK8HOLq!O)+3 zH?ZnCX6-oE)i(j1MFxaqAl+n_&)bdg=z)?;uba^5r+>NyhcQK&bli1jdN|ER>dE}P3`p}X^iQXAr2?t z?R%Y!S)s7Nb$QHR^3#pg!$ZZr`1o(6d5&AjF&(CLNC<2}>gP~)aB@Pv4WZ+f)B0V@ zzbEQW5~!VY4nBIqs7@S+KO&)?@N9NuiZvT|EslH(z942U9<_w5%*08<(#tAt`k_mO zV>5kua_loiQKZnFPdAwC>LL`6MlaYps}$uh1FJenuXYT2R;|H{ZK_@$bn;xnN1WAa znrr1Die1Q|N~uj%?uz!uO|;z=-!^%ls=r=kiYF4WG?l@~dg_{0bnt=cYbc!Q9IM_} z*VzpDt?m5-GQXz_RXAnNR1$o?YZG)w8Pk=ncKMStO*U0_W5J;&dKC2rkDEG$|JY)` z?|24!XyVy-+oQKko@xIKg=@=fn#w`F_nD5rmh*xq4i`T)jU=Bwj_^S`Gzzq=0`ZS+?b?y8gdBk#qwVmTkGn?A0} za&u^KV4aPY>1>C-Yj(*5p#|Uqsv0WG=@p2#Lh~~r7<+4Y!>+Wa#!+3dz7CN;8{9oT zC56k1Di){Pxg!i>iNzOE%r~k=m$ZIYDAzeH+~$81opGCnsC#wj!t7I4T{~QU2d{3~9*W4dDoooO%&%#9#5ongy4Idqbg+Sp97wcHQ;ThE zy5{fi;p=ZaeY2sgu|v+eNQa64RT^F2eB1T~4&cB`6iF?lL^{Eny(1)!^J(CX<0rdC ziRNLQ!Z^>Vo0NlhP3}I@?{q(wfEtHeE$!Z`&eDy(F|#aYp-DcuEPX9&PD%ph7D(9S z&wVcnm6TzNv~ACaCWX!I@YCoLDE+98Arz#M8DM^0GaOPReUSOcJ|K$MIWLyE38`s+ zeQ$>sBQqAQp!FPKtjdaNPk|jPyI#DJCXpUBT}M2fH1nDXuc~m(Sby4-4I@Hs05^9( z=N?Fp#(Ul7fLN01*dhB2(uR!C&?fP*+o?+z z@xo#1k3374qs0YJtUrbYzaQb1U2H!&YLjwAf{e@#@h0X+s|1ME&RtT(JLH20pP<_7 zE6kEhYYt(K{PoS}Ft;vzteho}cB^ll^eAvK9{c>@uk`JGtN-;IDGetbKW^>g>}*(I zj;{f{Qfp1{2xWnYgrw<~0jo@jB6MSWc)iA*3NL5(TU_vo=(>e}9``m9px1@rMsrpf zjynKSH!<&{6hcKaRckt6)IyfUwzM>P-%n5qzu-Lb%VP3jLd>jwpTNbKzEO)V!(on$Ca&mn%F6*+uea@6_;H$Jfb z)2-t!3vJv0=Lh?tS2Mz6E{NYvPs*=xJ7AiB(=UR9cMOttXwT@@3Hv=*Z88~i3>>^1 z_fW+nhvdDEXM|nw&?~K@Pq?d*Pu9GKXWG7(vrsoT{bSCP^ygqs;;fgm!;@jZoGkeY zj3hS@x>Om?dVBk3x3xX^-;jD^>R?TA(nY*kiGiEzr5k6?9{Ud<6iV+^sSe+unx58+ zX7(yV8@=&`;^$AhJ9D1=Tl=U)8u*$1VCo5NDeTCs9sk&cpYt}fBhS|_TN=KukJaeb zodQjQ11RCR0-nhh1Vc>n5Scmu|mDdv+Dq1n_LM4GnUN|Dx zh_6SH(1T{TO2+*HJia~X;6!W(B`46tW$ zQ2m%q$;6haqvK4f_xbj;`5C-(LD4ruhW_E!lmwWfHul9}10eQ;|Jvm50dKIcc)dO- z1Clyo;=Skid6SrmS33MyydKek%4=DqNES5Mv@I(*Dl#|G^07wySODqE+Eyy4lk@DR zprNZDXSHbGJ&VHdq`O>r1Nl+Y=WpS~$4`1NYoMCgV5RfX+woazPgZXevp#L#p4uQl zK}Y}G)7!N3uKf?#mm+Je$s&rE6?P7hZZ%jE}H!Q!VsjpLfs*;+%}9=l=>tK|BF1_(gcES;TGJHQd%f?pDV^6GZV$Zofw5 z2mLWcUJ>w{)})}fHux!-@)EdoomLd)5C;fP^R;;GA!UsbkP3ha4 z|4wuCJ|o7uupf(>_4hJOCt}^c2aB1MPh2=L0%Jn@y&L8KSYLM63_7}DB=e*WnpBf} zdDU51`@IS?qb%3%ib+C1d48+K&?1HK>1>W9=9+0v%z?-7e_W|5=T~z2>HMBEv#uoG z@Oh%-7srDvvkZ`{!eeiA>q}sS7k-$&EPQovG3T)doj`Enh=?4y<`Swo=7+E24>B8j z%|ZJwNXy)0W)M=QUPxjD6dqg)3XOf9eI}@<#KLwXUo4wq{#DC0vZv#}nJ(KcnpMObFkQYrP5o?_%r(%AN}RfOz7{Zy#d%`435&XG$w zd4}p+R3V3V4*SdrACtDYbNO;En>fGFx)t3X<@|!Dgi%MCE>D=&`Jfr0HU#;gN0G%d zrPkM;JBBtJkw;~3)DE2vbNNY--0%)6($fsuRjI9LP6vJs8lK*A4VYl-Mn3fLP^&pd 
zBLjg*)QjZN0?&qN>oZXqHO%wb0CNI3Wa=xX4>6zyeB0+$fo%m2Ld?D;r68;Tox`EZDM>G*TLk+36{`!}y>uP+8;=^W zFw1_%#;+Vv_(~6RN9;HT513i!KcN3i6CD1@%xLfZU!jJ}@+KZb641oRJVvzCCpY_Nz5~m$gh&)J<=pzZ-LSZENP>=QxWy{F;;!Y2swipk zdd?`gQ#@D`8*4Out zrb*(NzvKDh{LRJp{huyQC#&MZh85~KftOj~;W7QknZ*G)0cYh4wfs!8!<)&3WTA!V z>x!Ad7f7x~Dz2_)M`ZO|LIMWjI^^Mf%SUDIYYxX9xK|2vZI&-t^rOZWEF?x1qC!rL zxYRLg!p?_NLDV0X-T8M4r+i5zM@ge`;fEzOhUT~K+3w0)NDKwoB>DBa>{#p77dc#U zG!2_SX6lFqKi!g)`wz(2RktqLshK(#Lc?5H}hG zHjsn;)OEZ`iOUV+ z{P4|{S?3PB?7SXU{{%t*lNINHQSkQUtIR{T>t3E4b6%xERnEAsmP&y;?3214!v6t{ zjg&Wu{lu=)ejfIF#8_A9^hwGMX;!wWtW6z`iiUHXx2)}@&-6BgW}p&3EV-T#a1@$Ykvp7|zhZ#_a;it!VY*JQaN8*%NQyC18y#+6&y@*O7!U>alN zB4!z#vmXApS(C4%)(sWC5ZQ5iMJ-+`HydL^rkbyJat)$Hv|=JS_N^FlRsU2i)?}v$U8l z!vjMJRsq}iEnd%nHz%AztJ3G>e%MlL-@LJSAof2%tit5^-Gox5IdC9lSe7j-`BykQ9=Y}_{-H8};-B26@x$z3A} zA@vy+-}56)@ySPKZAISY{np&_Rh?*^9ShR7?>i-)i*)5}MEL+*r*LXI`7NbFvgLO7 zETTtG%0nOX?AyFTD*9yO>3tJ=08&X_8`EqS;@T-sq11yUM#Dr`XU-wV30@J}od6Es z6M4_}J><$A5X!n!*w%reHg>CQ*FZhZEK1=iUf3;4yzP3AkcZ35!EK4Yv9kMYRP1uA z@ZHWB3G~L(+R1(&Yp+ATxs8Lw_o12JNtd29D5lREmh>1YCe;mF!dgn9lcpC&U-VuH zdVvXhuS#waN32Y=*&^f!cu`#PFu)!mtFFq0K^1fD!?C3zztPySGaeMf4?)~46k^ac z!+jnF8J3G)5^n|8LJ{?^^FW=4s|V$KaL?E{NG=T2Co-5Rf;5UUBNqT*oImimsRP5W zM*X4mBgJPb`5T={7E?Z!6wg;@1-)DWtH`)C)1?nS+Cvw3^VxVH8fZlw(HlxW>{)P| zD@`^NyVVSJC%9Gi4)DRuthge6D#a9)%=@q>6m_94lF$xhYipaYq1uy7n|6eCyTJ}1%C~`>(+-@%4+A0P~ zZHw8#hkgSl^;bUIFfm6K^g za&hAJq0#oJ;VgXfa$ zj^yLdAB6{_5BpykC~wtL?P*z)d7+bp!wbbEGevO{7Sw=nhl3jMo#t8{n8^iP!o6p)7{H>HyM?6TH1`(QRMBz#A*9Nnmw;3} zjK8S*MeI(q4F z8cscWJzF<9K8c0P7WNa;i9d&*JF_>4XrTY0DP!Y9tqHjqyW(bbypGL0>l%uc()9@5 zxfjMEl|YTFd?sqC?0Am|wPBu>)pVTd`Qf*i1x>;&T0-)Q-<4VKM-%3veDZmto;{&$ zr}2}9e!|%c#wIv-AD7{6X`6}?LANJoJgDt$`LJ8pJ?XA8UsO)Vn+#>fCtrolo{xHD zXEymo3mK=SU*+Kes)_?VKfs7TLg^)+3vU=JDG%$Jrsk^LYWk>}@MNR>b%=n^bSGWG zSQ5L5`5vvFg?FY`s#bpd2$Gul3)Ob5KpP-RM2+sZ`aq+ImKW7?QT`@JaDe(0dYW~A zuf?Mgr!n{~WuYel6e`wlD>wZU|B_rp+thcRylgo0R;hs9V+_y3%(M!q{I3)A3k%{$ zG=zz-y>!Yv6Jgp5`0=KL;{SkYNX@o05_HD}unrM@%r%(j2eE`KX5Xby$>4z!4Mkmb zG}Bpi4ak+_;*#Iaq-dT4A3zZULY}pdhoYS&6?Xxl4<*{Sz8M=1$eLcCZq~X$ZY|O0e>~0eVrEQ|IG9C^AQKtg>;VPzX#faf8c(~~wbdtAMi;SD^J#blJtfwO0n(ZB9CwFU@vqHt za$A;XZvAOQ_^GDAgW06TDjXsc2LX=a1SYEyz?h_BmW5yv1+SlS zeWZfZ@_b=fz&-L&W%eql5-ySsV0b2b!UFBp*>}LOpk=JZk|J`ax%zgDxMXKTYLS&V z#>JG=I+MndR2VSXHx-fE?HBTTu`FnubEaN}gO;8_w+d{FOW<6ng0tYjG-G-#qpxDH z<;4)xHQV}u)M7QOn3OG+e2xa%EThg8u{NwhS+pN#A1( zM1;$lvfMJOJxzcHdRtbONo;vYaPYk3mJa`EV0%7dBy^BTt`wLVUAPpu&PD=|FQI^P z<%v!WtR%BN(B6c41SfN&LOrdp>8xl`2y*97{d373u-78>d$$1pf?%6!kWPq?g;!qw z9tt5Ht0uyC-~X<`z)GLPLs(Vj`nQHclLAu$bKqN2Dm8_Qmrzic8!CfxwC7CkFFll# z>d}1Xx2UsT{#&KY37aRgb_l*i&%1L-4LUR=2MI2>YDnP~9<7u1lPKc51>lm7R6DwHJ8pSK3k)G;>l8PpZ$`(^ez zz=4Y`dnr_Z=M!~+)bVP#Ew6WQN`oiIIG9Ak{O|f0?~hY%7<3E+IYbN;A>KOxyjP;{ zLJ1h<)RjxE>kbzE!x3n3@D1}`A*3JWTc((+orhPda)T1mU z!z6erBC})SAZs+)dxMfoc8j(JNdVfW*^zE5X}=W~6;y7vpd1C&1=+J-)bS}Gf9bvu zN&r#qiv^dY5THo5;JG+(fQf$O!(=pH)i)6J*Pr!L6X7ZGUMzn_D=FhUooz;LwF?r8 zzfT-AJex!bG~~iXQlma+it|}>EhSk>0}WSy*n?oEB`4DpGlR;X?j{N-tYkg8R-Ko7 z^M~yaDEE%HL|-Qjma(UO%(?RD0 zhRsZznJ*kt$ zQOlA3TkQ#HlaI(McQeGO%($~cpl)JO*zWT>1)ZsS-HW$GJsEhF2Xz`vl4?9E334FT z;E93Q4D4~E1*FMp?S9K4z5B}tqY(#f2A8}L!|wL!MN z{lak(Pzs5~eB~Ns;{M0+?94j}1MV5sv0E`>@i#~@k*#BTGBD<=hE9Ypx}liFWLKIm zLHVIPDK(o{I%f=8UEv)K56sq-1*Al2h=O%O>yz_56@S?m8H5O$(NrfgLX8?O;?B`* zICntJ%4B_qbjL2Y+fAJ&MEiTUBW#CtQt<|CyBc%j9!txo{l;28ji)2oeJSj$;=$GD zVS;yYur6zBkWqQ5*>P1wBjc^O8duasYXseOuT)@ubIQ8CE7ovttzy)=97UNHLE)^K zHxTBxBtgz3Ni6X-csrw~m=AXGtK=?M2oPvZk9>Ebd1!Q8xvWkwE) zbV$x4BxknPS%kdhlsEvqS@t15Ro{?d#C``IbQcF-{Az58auyFW9A)$MYrRT#YEm<> 
zFS<5=Bx%V&|9n=rlGD{}tyn_wAC5zzO0&qDqYplv+MS*Ln`hx@k{md(s!18?ZAM?D zTlMInBsy}jgkHu&Z<|${uoT7I5>zNPD1ORZe?tYP@VQkYN$RI>`Ra4JiH}II25(OsO8V!v_yCgW9M?o^Y|b@VSi|X#Nf*wfOJ@V$xJNV-6Y*`{YBli!iRkty z!|q^6ubJohnNazs2DdwfF*XazJtj|Lw&eP?bM#Jy;Xku*Wg(5|O|92f^9Nnts+EfX znop@^6`m!H!;`9oF{-);=p`1^1S!1$D+uG@ydtVS%K|sI6QnHXt`ybj#(jtgL7tsc z4HyP5eNEDT?W+%!@AHE(8B@OL%c)}D_c0U%O+hTHuhRq$z7@r@zGqRzDqkk?{`=E; zBn&sw&L_6GV^%PuKWK_+EBQWDzdaFR zMw7%GCg^0k&&k4wJ1Of{!gz8_tplRBojJN5W^La%)-prKc@|S~Pw3ES=sDFZPBO$I zGuty~MGY-bKl*Kf!ONQ$cf%!HVNJPtN(zm>_Hrq__Uh4fV<}D3?AAGR_Jj^0g`yxY zm!wjyD^VRG-<^p>D6{TQ^p#>2#JzODID>}l`I8*|^deafoWXo&0&-nGZ4^DG^1oOH zbWU2#7&d4|$eLR)vlp050MKu3hNjlczy~$jFjQCApFc!tGb5CgDnZ7S zS|t(g_F?&YeJAx9Q;+LEO~-)nw+QxNmPAw#Ha}F5WOqA%;*Z&%T5oKN2<)JQ5WIWy zNFg>|m2b2yi;4m7WeY?9=*}$7RNE^IeYY`*O*Gzgbe`ryjiS`CG=8>bI0bJn9hxo- z5=SR4t_ntvcBKmgDj^8#zG0aCpK2@(EO;)~sQK{FK1f*^rjE@37{Z-FzlPv^wL|F= zGlG24-Ne|KiGerlpy3e7}x6O;zFj3$n5BoU>hL^3_^wIsnN| z5r1e*LguZ>*y-Zn48tI7CIyjaPkUv^$ePhRG0$5?wj*W#;raD{mKj`h{n)!OFD;kG zR06`XY+JNQ-A$oxB{ywhC|&WJF)_Epjd{k5vz?H8OxoKEKl=9Po#I;Q67)4wSw{oB zv^!IcP*Q}#W~5It(S2FV)Bvv&dQeU!(L`x`zk_t$NxxCUOKDd86qAuzN;90?l)E49 zUR$>9s+CBrd+m8BFjre?LpNvQbihQAI_$a4dv~{+9?Q&Qn@Z?<1k&&b&b_vkT`N;C zP#(*w``tEcZ)W{=X%KYnoRaSEs5%{Y?!w1$e^%uP)V-%8KHQzHpyRHMz2u(q2ydY* za)sc-xE|4Sdr)IrGI fTM)`aJ#bcC$7UNVHv9f_Tn7}k_o}gq|1JI>;z3O6 literal 0 HcmV?d00001 diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc new file mode 100644 index 00000000000..e830ce34842 --- /dev/null +++ b/tensorflow/examples/label_image/main.cc @@ -0,0 +1,295 @@ +// A minimal but useful C++ example showing how to load an Imagenet-style object +// recognition TensorFlow model, prepare input images for it, run them through +// the graph, and interpret the results. +// +// It's designed to have as few dependencies and be as clear as possible, so +// it's more verbose than it could be in production code. In particular, using +// auto for the types of a lot of the returned values from TensorFlow calls can +// remove a lot of boilerplate, but I find them explicit types useful in sample +// code to make it simple to look up the classes involved. +// +// To use it, compile and then run in a working directory with the +// learning/brain/tutorials/label_image/data/ folder below it, and you should +// see the top five labels for the example Lena image output. You can then +// customize it to use your own models or images by changing the file names at +// the top of the main() function. +// +// The googlenet_graph.pb file included by default is created from Inception. + +#include + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/default_device.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/command_line_flags.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/tensor.h" + +// These are all common classes it's handy to reference with no namespace. 
+using tensorflow::Tensor;
+using tensorflow::Status;
+using tensorflow::string;
+using tensorflow::int32;
+
+// These are the command-line flags the program can understand.
+// They define where the graph and input data are located, and what kind of
+// input the model expects. If you train your own model, or use something
+// other than GoogLeNet, you'll need to update these.
+TF_DEFINE_string(image,
+                 "tensorflow/examples/label_image/data/grace_hopper.jpg",
+                 "The image to classify (JPEG or PNG).");
+TF_DEFINE_string(graph,
+                 "tensorflow/examples/label_image/data/googlenet_graph.pb",
+                 "The location of the GraphDef file containing the protobuf"
+                 " definition of the network.");
+TF_DEFINE_string(labels,
+                 "tensorflow/examples/label_image/data/googlenet_labels.txt",
+                 "A text file containing the labels of all the categories, one"
+                 " per line.");
+TF_DEFINE_int32(input_width, 224, "Width of the image the network expects.");
+TF_DEFINE_int32(input_height, 224, "Height of the image the network expects.");
+TF_DEFINE_int32(input_mean, 117, "How much to subtract from input values.");
+TF_DEFINE_int32(input_std, 1, "What to divide the input values by.");
+TF_DEFINE_string(input_layer, "input", "The name of the input node.");
+TF_DEFINE_string(output_layer, "softmax2", "The name of the output node.");
+TF_DEFINE_bool(self_test, false, "Whether to run a sanity check on the results.");
+TF_DEFINE_string(root_dir, "", "The directory at the root of the data files.");
+
+// Takes a file name, and loads a list of labels from it, one per line, and
+// returns a vector of the strings. It pads with empty strings so the length
+// of the result is a multiple of 16, because our model expects that.
+Status ReadLabelsFile(string file_name, std::vector<string>* result) {
+  std::ifstream file(file_name);
+  result->clear();
+  string line;
+  while (std::getline(file, line)) {
+    result->push_back(line);
+  }
+  const int padding = 16;
+  while (result->size() % padding) {
+    result->emplace_back();
+  }
+  return Status::OK();
+}
+
+// Given an image file name, read in the data, try to decode it as an image,
+// resize it to the requested size, and then scale the values as desired.
+Status ReadTensorFromImageFile(string file_name, const int input_height,
+                               const int input_width, const float input_mean,
+                               const float input_std,
+                               std::vector<Tensor>* out_tensors) {
+  tensorflow::GraphDefBuilder b;
+  string input_name = "file_reader";
+  string output_name = "normalized";
+  tensorflow::Node* file_reader =
+      tensorflow::ops::ReadFile(tensorflow::ops::Const(file_name, b.opts()),
+                                b.opts().WithName(input_name));
+  // Now try to figure out what kind of file it is and decode it.
+  const int wanted_channels = 3;
+  tensorflow::Node* image_reader;
+  if (tensorflow::StringPiece(file_name).ends_with(".png")) {
+    image_reader = tensorflow::ops::DecodePng(
+        file_reader,
+        b.opts().WithAttr("channels", wanted_channels).WithName("png_reader"));
+  } else {
+    // Assume if it's not a PNG then it must be a JPEG.
+    image_reader = tensorflow::ops::DecodeJpeg(
+        file_reader,
+        b.opts().WithAttr("channels", wanted_channels).WithName("jpeg_reader"));
+  }
+  // Now cast the image data to float so we can do normal math on it.
+  tensorflow::Node* float_caster = tensorflow::ops::Cast(
+      image_reader, tensorflow::DT_FLOAT, b.opts().WithName("float_caster"));
+  // The convention for image ops in TensorFlow is that all images are expected
+  // to be in batches, so that they're four-dimensional arrays with indices of
+  // [batch, height, width, channel]. Because we only have a single image, we
+  // have to add a batch dimension of 1 to the start with ExpandDims().
+  tensorflow::Node* dims_expander = tensorflow::ops::ExpandDims(
+      float_caster, tensorflow::ops::Const(0, b.opts()), b.opts());
+  // Bilinearly resize the image to fit the required dimensions.
+  tensorflow::Node* resized = tensorflow::ops::ResizeBilinear(
+      dims_expander, tensorflow::ops::Const({input_height, input_width},
+                                            b.opts().WithName("size")),
+      b.opts());
+  // Subtract the mean and divide by the scale.
+  tensorflow::ops::Div(
+      tensorflow::ops::Sub(
+          resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()),
+      tensorflow::ops::Const({input_std}, b.opts()),
+      b.opts().WithName(output_name));
+
+  // This runs the GraphDef network definition that we've just constructed, and
+  // returns the results in the output tensor.
+  tensorflow::GraphDef graph;
+  TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
+  std::unique_ptr<tensorflow::Session> session(
+      tensorflow::NewSession(tensorflow::SessionOptions()));
+  TF_RETURN_IF_ERROR(session->Create(graph));
+  TF_RETURN_IF_ERROR(session->Run({}, {output_name}, {}, out_tensors));
+  return Status::OK();
+}
+
+// Reads a model graph definition from disk, and creates a session object you
+// can use to run it.
+Status LoadGraph(string graph_file_name,
+                 std::unique_ptr<tensorflow::Session>* session) {
+  tensorflow::GraphDef graph_def;
+  Status load_graph_status =
+      ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
+  if (!load_graph_status.ok()) {
+    return tensorflow::errors::NotFound("Failed to load compute graph at '",
+                                        graph_file_name, "'");
+  }
+
+  session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
+  Status session_create_status = (*session)->Create(graph_def);
+  if (!session_create_status.ok()) {
+    return session_create_status;
+  }
+  return Status::OK();
+}
+
+// Analyzes the output of the Inception graph to retrieve the highest scores and
+// their positions in the tensor, which correspond to categories.
+Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
+                    Tensor* indices, Tensor* scores) {
+  tensorflow::GraphDefBuilder b;
+  string output_name = "top_k";
+  tensorflow::ops::TopK(tensorflow::ops::Const(outputs[0], b.opts()),
+                        how_many_labels, b.opts().WithName(output_name));
+  // This runs the GraphDef network definition that we've just constructed, and
+  // returns the results in the output tensors.
+  tensorflow::GraphDef graph;
+  TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
+  std::unique_ptr<tensorflow::Session> session(
+      tensorflow::NewSession(tensorflow::SessionOptions()));
+  TF_RETURN_IF_ERROR(session->Create(graph));
+  // The TopK node returns two outputs, the scores and their original indices,
+  // so we have to append :0 and :1 to specify them both.
+  std::vector<Tensor> out_tensors;
+  TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
+                                  {}, &out_tensors));
+  *scores = out_tensors[0];
+  *indices = out_tensors[1];
+  return Status::OK();
+}
+
+// Given the output of a model run, and the name of a file containing the
+// labels, this prints out the top five highest-scoring values.
+Status PrintTopLabels(const std::vector<Tensor>& outputs,
+                      string labels_file_name) {
+  std::vector<string> labels;
+  Status read_labels_status = ReadLabelsFile(labels_file_name, &labels);
+  if (!read_labels_status.ok()) {
+    LOG(ERROR) << read_labels_status;
+    return read_labels_status;
+  }
+  const int how_many_labels = 5;
+  Tensor indices;
+  Tensor scores;
+  TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
+  tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
+  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
+  for (int pos = 0; pos < how_many_labels; ++pos) {
+    const int label_index = indices_flat(pos);
+    const float score = scores_flat(pos);
+    LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score;
+  }
+  return Status::OK();
+}
+
+// This is a testing function that returns whether the top label index is the
+// one that's expected.
+Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
+                     bool* is_expected) {
+  *is_expected = false;
+  Tensor indices;
+  Tensor scores;
+  const int how_many_labels = 1;
+  TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
+  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
+  if (indices_flat(0) != expected) {
+    LOG(ERROR) << "Expected label #" << expected << " but got #"
+               << indices_flat(0);
+    *is_expected = false;
+  } else {
+    *is_expected = true;
+  }
+  return Status::OK();
+}
+
+int main(int argc, char* argv[]) {
+  // We need to call this to set up global state for TensorFlow.
+  tensorflow::port::InitMain(argv[0], &argc, &argv);
+  Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
+  if (!s.ok()) {
+    LOG(ERROR) << "Error parsing command line flags: " << s.ToString();
+    return -1;
+  }
+
+  // First we load and initialize the model.
+  std::unique_ptr<tensorflow::Session> session;
+  string graph_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_graph);
+  Status load_graph_status = LoadGraph(graph_path, &session);
+  if (!load_graph_status.ok()) {
+    LOG(ERROR) << load_graph_status;
+    return -1;
+  }
+
+  // Get the image from disk as a float array of numbers, resized and normalized
+  // to the specifications the main graph expects.
+  std::vector<Tensor> resized_tensors;
+  string image_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_image);
+  Status read_tensor_status = ReadTensorFromImageFile(
+      image_path, FLAGS_input_height, FLAGS_input_width, FLAGS_input_mean,
+      FLAGS_input_std, &resized_tensors);
+  if (!read_tensor_status.ok()) {
+    LOG(ERROR) << read_tensor_status;
+    return -1;
+  }
+  const Tensor& resized_tensor = resized_tensors[0];
+
+  // Actually run the image through the model.
+  std::vector<Tensor> outputs;
+  Status run_status = session->Run({{FLAGS_input_layer, resized_tensor}},
+                                   {FLAGS_output_layer}, {}, &outputs);
+  if (!run_status.ok()) {
+    LOG(ERROR) << "Running model failed: " << run_status;
+    return -1;
+  }
+
+  // This is for automated testing to make sure we get the expected result with
+  // the default settings. We know that label 866 (military uniform) should be
+  // the top label for the Admiral Hopper image.
+  if (FLAGS_self_test) {
+    bool expected_matches;
+    Status check_status = CheckTopLabel(outputs, 866, &expected_matches);
+    if (!check_status.ok()) {
+      LOG(ERROR) << "Running check failed: " << check_status;
+      return -1;
+    }
+    if (!expected_matches) {
+      LOG(ERROR) << "Self-test failed!";
+      return -1;
+    }
+  }
+
+  // Do something interesting with the results we've generated.
+  Status print_status = PrintTopLabels(outputs, FLAGS_labels);
+  if (!print_status.ok()) {
+    LOG(ERROR) << "Running print failed: " << print_status;
+    return -1;
+  }
+
+  return 0;
+}
diff --git a/tensorflow/g3doc/api_docs/index.md b/tensorflow/g3doc/api_docs/index.md
index a1539d45a28..d6b1b86fcfe 100644
--- a/tensorflow/g3doc/api_docs/index.md
+++ b/tensorflow/g3doc/api_docs/index.md
@@ -14,3 +14,5 @@
 Note: Many practical aspects of usage are covered in the Mechanics tab, and
 some additional documentation not specific to any particular language API is
 available in the Resources tab.
+* [Python API](python/index.md)
+* [C++ API](cc/index.md)
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 8cb1f96e09d..640ceb1bead 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -245,31 +245,54 @@ strided according to the `strides` argument.
 `strides = [1, 1, 1, 1]` applies the filter to a patch at every offset,
 `strides = [1, 2, 2, 1]` applies the filter to every other image patch in each
 dimension, etc.
 
-Ignoring channels for the moment, the spatial semantics of the convolution ops
-are as follows. If the 4-D `input` has shape
+Ignoring channels for the moment, and assuming that the 4-D `input` has shape
 `[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
-`[filter_height, filter_width, ...]`, then
+`[filter_height, filter_width, ...]`, then the spatial semantics of the
+convolution ops are as follows: first, according to the padding scheme chosen
+as `'SAME'` or `'VALID'`, the output size and the padding pixels are computed.
+For the `'SAME'` padding, the output height and width are computed as:
 
-    shape(output) = [batch,
-                     (in_height - filter_height + 1) / strides[1],
-                     (in_width - filter_width + 1) / strides[2],
-                     ...]
+    out_height = ceil(float(in_height) / float(strides[1]))
+    out_width = ceil(float(in_width) / float(strides[2]))
+
+and the amount of padding on the top and left is computed as:
+
+    pad_along_height = ((out_height - 1) * strides[1] +
+                        filter_height - in_height)
+    pad_along_width = ((out_width - 1) * strides[2] +
+                       filter_width - in_width)
+    pad_top = pad_along_height / 2
+    pad_left = pad_along_width / 2
+
+Note that the division by 2 means that there might be cases when the padding on
+the two sides (top vs bottom, right vs left) differs by one. In this case, the
+bottom and right sides always get the one additional padded pixel. For example,
+when `pad_along_height` is 5, we pad 2 pixels at the top and 3 pixels at the
+bottom. Note that this is different from existing libraries such as cuDNN and
+Caffe, which explicitly specify the number of padded pixels and always pad the
+same number of pixels on both sides.
+
+For the `'VALID'` padding, the output height and width are computed as:
+
+    out_height = ceil(float(in_height - filter_height + 1) / float(strides[1]))
+    out_width = ceil(float(in_width - filter_width + 1) / float(strides[2]))
+
+and the padding values are always zero. The output is then computed as
 
     output[b, i, j, :] =
-        sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
+        sum_{di, dj} input[b, strides[1] * i + di - pad_top,
+                           strides[2] * j + dj - pad_left, ...] *
             filter[di, dj, ...]
 
+where any values outside the original input image region are considered zero
+(i.e. we pad zero values around the border of the image).
+
 Since `input` is 4-D, each `input[b, i, j, :]` is a vector.
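To make the `'SAME'` and `'VALID'` arithmetic above concrete, here is a minimal Python sketch that mirrors the formulas for a single spatial dimension; the 6x6 input, 3x3 filter, and stride of 2 are arbitrary illustration values rather than anything taken from this change:

    import math

    def same_output_and_padding(in_size, filter_size, stride):
      # 'SAME' padding: the output size depends only on the input size and stride.
      out_size = int(math.ceil(float(in_size) / float(stride)))
      pad_along = (out_size - 1) * stride + filter_size - in_size
      # Halving with integer division puts any extra pixel on the bottom/right.
      pad_before = pad_along // 2          # top or left
      pad_after = pad_along - pad_before   # bottom or right
      return out_size, pad_before, pad_after

    def valid_output(in_size, filter_size, stride):
      # 'VALID' padding: only full windows are used, so nothing is padded.
      return int(math.ceil(float(in_size - filter_size + 1) / float(stride)))

    print(same_output_and_padding(6, 3, 2))  # (3, 0, 1): no padding on top, one row on the bottom
    print(valid_output(6, 3, 2))             # 2

The same calculation is applied independently to the height and width dimensions.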
For `conv2d`, these vectors are multiplied by the `filter[di, dj, :, :]` matrices to produce new vectors. For `depthwise_conv_2d`, each scalar component `input[b, i, j, k]` is multiplied by a vector `filter[di, dj, k]`, and all the vectors are concatenated. -In the formula for `shape(output)`, the rounding direction depends on padding: - -* `padding = 'SAME'`: Round down (only full size windows are considered). -* `padding = 'VALID'`: Round up (partial windows are included). - - - - ### `tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)` @@ -412,14 +435,8 @@ In detail, the output is output[i] = reduce(value[strides * i:strides * i + ksize]) -for each tuple of indices `i`. The output shape is - - shape(output) = (shape(value) - ksize + 1) / strides - -where the rounding direction depends on padding: - -* `padding = 'SAME'`: Round down (only full size windows are considered). -* `padding = 'VALID'`: Round up (partial windows are included). +where the indices also take into consideration the padding values. Please refer +to the `Convolution` section for details about the padding calculation. - - - diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md index f21c97b7c62..95f52dafcbc 100644 --- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md +++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md @@ -136,7 +136,7 @@ that the evidence for a class \\(i\\) given an input \\(x\\) is: $$\text{evidence}_i = \sum_j W_{i,~ j} x_j + b_i$$ -where \\(W\_i\\) is the weights and \\(b\_i\\) is the bias for class \\(i\\), +where \\(W_i\\) is the weights and \\(b_i\\) is the bias for class \\(i\\), and \\(j\\) is an index for summing over the pixels in our input image \\(x\\). We then convert the evidence tallies into our predicted probabilities \\(y\\) using the "softmax" function: diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 5c6e08ae44a..e6cccbd715e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -839,6 +839,7 @@ cpu_only_kernel_test_list = glob([ "kernel_tests/save_restore_ops_test.py", "kernel_tests/segment_reduction_ops_test.py", "kernel_tests/sparse_concat_op_test.py", + "kernel_tests/sparse_matmul_op_test.py", "kernel_tests/sparse_reorder_op_test.py", "kernel_tests/sparse_to_dense_op_test.py", "kernel_tests/sparsemask_op_test.py", diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 38e01798ac1..637def9b3cc 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -333,6 +333,9 @@ class BaseSession(SessionInterface): # Check session. if self._closed: raise RuntimeError('Attempted to use a closed Session.') + if self.graph.version == 0: + raise RuntimeError('The Session graph is empty. Add operations to the ' + 'graph before calling run().') # Validate and process fetches. is_list_fetch = isinstance(fetches, (list, tuple)) diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 99bf356cc23..3af737801db 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -445,6 +445,12 @@ class SessionTest(test_util.TensorFlowTestCase): sess.close() t.join() + def testUseEmptyGraph(self): + with session.Session() as sess: + with self.assertRaisesWithPredicateMatch( + RuntimeError, lambda e: 'The Session graph is empty.' 
in str(e)): + sess.run([]) + def testNotEntered(self): # pylint: disable=protected-access self.assertEqual(ops._default_session_stack.get_default(), None) diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 0e813896ff2..4d6a9c1a58f 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -214,7 +214,7 @@ import_array(); "but got %s" % type(config)) status = TF_NewStatus() config_str = config.SerializeToString() - _TF_SetConfig(opts, config_str, len(config_str), status) + _TF_SetConfig(opts, config_str, status) if TF_GetCode(status) != 0: raise ValueError(TF_Message(status)) return opts diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 00369816ea8..25befa6dc32 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -167,7 +167,14 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, """ # Type checks for inputs. if not isinstance(graph_def, graph_pb2.GraphDef): - raise TypeError('graph_def must be a GraphDef proto.') + # `graph_def` could be a dynamically-created message, so try a duck-typed + # approach + try: + old_graph_def = graph_def + graph_def = graph_pb2.GraphDef() + graph_def.MergeFrom(old_graph_def) + except TypeError: + raise TypeError('graph_def must be a GraphDef proto.') if input_map is None: input_map = {} else: diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e5a8f8a4dc1..1dc98d42988 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2458,6 +2458,7 @@ class Graph(object): control_ops = [] current = self._current_control_dependencies() for c in control_inputs: + c = self.as_graph_element(c) if isinstance(c, Tensor): c = c.op elif not isinstance(c, Operation): diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 8656228edbd..3ade3ee228e 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -632,6 +632,20 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): # e should be dominated by c. self.assertEqual(e.op.control_inputs, []) + def testBasicWithConversion(self): + g = ops.Graph() + a = _apply_op(g, "const", [], [types.float32]) + + class ConvertibleObj(object): + + def _as_graph_element(self): + return a + + with g.control_dependencies([ConvertibleObj()]): + c = _apply_op(g, "const", [], [types.float32]) + + self.assertEqual(c.op.control_inputs, [a.op]) + def testNested(self): g = ops.Graph() a_1 = _apply_op(g, "const", [], [types.float32]) diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 4d88809c88e..3d323100074 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -183,6 +183,13 @@ class Conv2DTest(tf.test.TestCase): stride=1, padding="VALID", expected=expected_output) + def testConv2DEmpty(self): + expected_output = [] + self._VerifyValues(tensor_in_sizes=[0, 2, 3, 3], + filter_in_sizes=[1, 1, 3, 3], + stride=1, padding="VALID", + expected=expected_output) + def testConv2D2x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. 
expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0] @@ -1008,4 +1015,5 @@ if __name__ == "__main__": setattr(Conv2DTest, "testInceptionBackFilter_" + str(index), GetInceptionBackFilterTest(input_size_, filter_size_, output_size_, stride_, padding_)) + tf.test.main() diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py index 14d657c8059..040d9eb3bde 100644 --- a/tensorflow/python/kernel_tests/listdiff_op_test.py +++ b/tensorflow/python/kernel_tests/listdiff_op_test.py @@ -10,57 +10,56 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf +_TYPES = [tf.int32, tf.int64, tf.float32, tf.float64, tf.string] + class ListDiffTest(tf.test.TestCase): - def _testListDiff(self, x, y, out, idx, dtype=np.int32): - x = np.array(x, dtype=dtype) - y = np.array(y, dtype=dtype) - out = np.array(out, dtype=dtype) - idx = np.array(idx, dtype=dtype) + def _testListDiff(self, x, y, out, idx): + for dtype in _TYPES: + if dtype == tf.string: + x = [str(a) for a in x] + y = [str(a) for a in y] + out = [str(a) for a in out] - with self.test_session() as sess: - x_tensor = tf.convert_to_tensor(x) - y_tensor = tf.convert_to_tensor(y) - out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor) - tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + with self.test_session() as sess: + x_tensor = tf.convert_to_tensor(x, dtype=dtype) + y_tensor = tf.convert_to_tensor(y, dtype=dtype) + out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) - self.assertAllEqual(tf_out, out) - self.assertAllEqual(tf_idx, idx) - self.assertEqual(1, out_tensor.get_shape().ndims) - self.assertEqual(1, idx_tensor.get_shape().ndims) + self.assertAllEqual(tf_out, out) + self.assertAllEqual(tf_idx, idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) def testBasic1(self): x = [1, 2, 3, 4] y = [1, 2] out = [3, 4] idx = [2, 3] - for t in [np.int32, np.int64, np.float, np.double]: - self._testListDiff(x, y, out, idx, dtype=t) + self._testListDiff(x, y, out, idx) def testBasic2(self): x = [1, 2, 3, 4] y = [2] out = [1, 3, 4] idx = [0, 2, 3] - for t in [np.int32, np.int64, np.float, np.double]: - self._testListDiff(x, y, out, idx, dtype=t) + self._testListDiff(x, y, out, idx) def testBasic3(self): x = [1, 4, 3, 2] y = [4, 2] out = [1, 3] idx = [0, 2] - for t in [np.int32, np.int64, np.float, np.double]: - self._testListDiff(x, y, out, idx, dtype=t) + self._testListDiff(x, y, out, idx) def testDuplicates(self): x = [1, 2, 4, 3, 2, 3, 3, 1] y = [4, 2] out = [1, 3, 3, 3, 1] idx = [0, 3, 5, 6, 7] - for t in [np.int32, np.int64, np.float, np.double]: - self._testListDiff(x, y, out, idx, dtype=t) + self._testListDiff(x, y, out, idx) def testRandom(self): num_random_tests = 10 @@ -78,38 +77,37 @@ class ListDiffTest(tf.test.TestCase): else: out = [] idx = [] - for t in [np.int32, np.int64, np.float, np.double]: - self._testListDiff(x, y, out, idx, dtype=t) + self._testListDiff(list(x), list(y), out, idx) - def testInt32FullyOverlapping(self): + def testFullyOverlapping(self): x = [1, 2, 3, 4] y = [1, 2, 3, 4] out = [] idx = [] self._testListDiff(x, y, out, idx) - def testInt32NonOverlapping(self): + def testNonOverlapping(self): x = [1, 2, 3, 4] y = [5, 6] out = x idx = np.arange(len(x)) self._testListDiff(x, y, out, idx) - def testInt32EmptyX(self): + def testEmptyX(self): x = [] y = [1, 2] out = [] idx = [] 
self._testListDiff(x, y, out, idx) - def testInt32EmptyY(self): + def testEmptyY(self): x = [1, 2, 3, 4] y = [] out = x idx = np.arange(len(x)) self._testListDiff(x, y, out, idx) - def testInt32EmptyXY(self): + def testEmptyXY(self): x = [] y = [] out = [] diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py index 8c1cda5e15a..f1a379bfdb2 100644 --- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py @@ -30,7 +30,7 @@ class InverseOpTest(tf.test.TestCase): self.assertAllClose(np_ans, out) self.assertShapeEqual(y, tf_ans) - def testBasic(self): + def testNonsymmetric(self): # 2x2 matrices matrix1 = np.array([[1., 2.], [3., 4.]]) matrix2 = np.array([[1., 3.], [3., 5.]]) @@ -42,6 +42,18 @@ class InverseOpTest(tf.test.TestCase): matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1]) self._verifyInverse(matrix_batch) + def testSymmetricPositiveDefinite(self): + # 2x2 matrices + matrix1 = np.array([[2., 1.], [1., 2.]]) + matrix2 = np.array([[3., -1.], [-1., 3.]]) + self._verifyInverse(matrix1) + self._verifyInverse(matrix2) + # A multidimensional batch of 2x2 matrices + matrix_batch = np.concatenate([np.expand_dims(matrix1, 0), np.expand_dims( + matrix2, 0)]) + matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1]) + self._verifyInverse(matrix_batch) + def testNonSquareMatrix(self): # When the inverse of a non-square matrix is attempted we should return # an error @@ -58,22 +70,10 @@ class InverseOpTest(tf.test.TestCase): # The input should be invertible. with self.test_session(): with self.assertRaisesOpError("Input is not invertible."): - # All rows of the matrix below add to zero + # All rows of the matrix below add to zero. 
tensor3 = tf.constant([[1., 0., -1.], [-1., 1., 0.], [0., -1., 1.]]) tf.matrix_inverse(tensor3).eval() - with self.test_session(): - with self.assertRaisesOpError("Input is not invertible."): - # Determinant of the matrix below is zero - tensor3 = tf.constant([[1., 1.], [1., 1.]]) - tf.matrix_inverse(tensor3).eval() - - with self.test_session(): - with self.assertRaisesOpError("Input is not invertible."): - # Determinant of the matrix below is zero - tensor3 = tf.constant([[np.inf, 1.], [1., 1.]]) - tf.matrix_inverse(tensor3).eval() - def testEmpty(self): self._verifyInverse(np.empty([0, 2, 2])) self._verifyInverse(np.empty([2, 0, 0])) diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index e682787f5b4..ada45a1c107 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -175,7 +175,8 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): with ops.op_scope(t_list + [clip_norm], name, "clip_by_global_norm") as name: # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm scale = clip_norm * math_ops.minimum( - 1.0 / use_norm, constant_op.constant(1.0 / clip_norm)) + 1.0 / use_norm, + constant_op.constant(1.0 / clip_norm, dtype=use_norm.dtype)) values = [ ops.convert_to_tensor( diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py index 8bab14c186f..01ace83d4c4 100644 --- a/tensorflow/python/ops/common_shapes.py +++ b/tensorflow/python/ops/common_shapes.py @@ -89,8 +89,8 @@ def bias_add_shape(op): return [output_shape] -def _Get2DOutputSize(input_height, input_width, filter_height, filter_width, - row_stride, col_stride, padding_type): +def get2d_conv_output_size(input_height, input_width, filter_height, + filter_width, row_stride, col_stride, padding_type): """Returns the number of rows and columns in a convolution/pooling output.""" input_height = tensor_shape.as_dimension(input_height) input_width = tensor_shape.as_dimension(input_width) @@ -184,7 +184,7 @@ def conv2d_shape(op): # in the kernel implementation. stride = stride_r padding = op.get_attr("padding") - out_rows, out_cols = _Get2DOutputSize( + out_rows, out_cols = get2d_conv_output_size( in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding) return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])] @@ -246,7 +246,7 @@ def separable_conv2d_shape(op): # in the kernel implementation. stride = stride_r padding = op.get_attr("padding") - out_rows, out_cols = _Get2DOutputSize( + out_rows, out_cols = get2d_conv_output_size( in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding) return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])] @@ -294,7 +294,7 @@ def avg_pool_shape(op): # in the kernel implementation. padding = op.get_attr("padding") - out_rows, out_cols = _Get2DOutputSize( + out_rows, out_cols = get2d_conv_output_size( in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding) return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth])] @@ -346,7 +346,7 @@ def max_pool_shape(op): # in the kernel implementation. 
   if ksize_d == 1:
     padding = op.get_attr("padding")
-    out_rows, out_cols = _Get2DOutputSize(
+    out_rows, out_cols = get2d_conv_output_size(
         in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
     return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth])]
   else:
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index c4337268bfa..e14d571453b 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -42,7 +42,8 @@ def Print(input_, data, message=None, first_n=None, summarize=None,
     message: A string, prefix of the error message.
     first_n: Only log `first_n` number of times. Negative numbers log always;
              this is the default.
-    summarize: Only print this many entries of each tensor.
+    summarize: Only print this many entries of each tensor. If None, then a
+               maximum of 3 elements are printed per input tensor.
     name: A name for the operation (optional).
 
   Returns:
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 5a5c06f975e..925ae76b98c 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -38,31 +38,54 @@ strided according to the `strides` argument.
 `strides = [1, 1, 1, 1]` applies the filter to a patch at every offset,
 `strides = [1, 2, 2, 1]` applies the filter to every other image patch in each
 dimension, etc.
 
-Ignoring channels for the moment, the spatial semantics of the convolution ops
-are as follows. If the 4-D `input` has shape
+Ignoring channels for the moment, and assuming that the 4-D `input` has shape
 `[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
-`[filter_height, filter_width, ...]`, then
+`[filter_height, filter_width, ...]`, then the spatial semantics of the
+convolution ops are as follows: first, according to the padding scheme chosen
+as `'SAME'` or `'VALID'`, the output size and the padding pixels are computed.
+For the `'SAME'` padding, the output height and width are computed as:
 
-    shape(output) = [batch,
-                     (in_height - filter_height + 1) / strides[1],
-                     (in_width - filter_width + 1) / strides[2],
-                     ...]
+    out_height = ceil(float(in_height) / float(strides[1]))
+    out_width = ceil(float(in_width) / float(strides[2]))
+
+and the amount of padding on the top and left is computed as:
+
+    pad_along_height = ((out_height - 1) * strides[1] +
+                        filter_height - in_height)
+    pad_along_width = ((out_width - 1) * strides[2] +
+                       filter_width - in_width)
+    pad_top = pad_along_height / 2
+    pad_left = pad_along_width / 2
+
+Note that the division by 2 means that there might be cases when the padding on
+the two sides (top vs bottom, right vs left) differs by one. In this case, the
+bottom and right sides always get the one additional padded pixel. For example,
+when `pad_along_height` is 5, we pad 2 pixels at the top and 3 pixels at the
+bottom. Note that this is different from existing libraries such as cuDNN and
+Caffe, which explicitly specify the number of padded pixels and always pad the
+same number of pixels on both sides.
+
+For the `'VALID'` padding, the output height and width are computed as:
+
+    out_height = ceil(float(in_height - filter_height + 1) / float(strides[1]))
+    out_width = ceil(float(in_width - filter_width + 1) / float(strides[2]))
+
+and the padding values are always zero. The output is then computed as
 
     output[b, i, j, :] =
-        sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
+        sum_{di, dj} input[b, strides[1] * i + di - pad_top,
+                           strides[2] * j + dj - pad_left, ...] *
             filter[di, dj, ...]
+where any value outside the original input image region is considered zero
+(i.e. we pad zero values around the border of the image).
+
Since `input` is 4-D, each `input[b, i, j, :]` is a vector. For `conv2d`, these
vectors are multiplied by the `filter[di, dj, :, :]` matrices to produce new
vectors. For `depthwise_conv_2d`, each scalar component `input[b, i, j, k]`
is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
concatenated.

-In the formula for `shape(output)`, the rounding direction depends on padding:
-
-* `padding = 'SAME'`: Round down (only full size windows are considered).
-* `padding = 'VALID'`: Round up (partial windows are included).
-
 @@conv2d
 @@depthwise_conv2d
 @@separable_conv2d
@@ -79,14 +102,8 @@ In detail, the output is

     output[i] = reduce(value[strides * i:strides * i + ksize])

-for each tuple of indices `i`. The output shape is
-
-    shape(output) = (shape(value) - ksize + 1) / strides
-
-where the rounding direction depends on padding:
-
-* `padding = 'SAME'`: Round down (only full size windows are considered).
-* `padding = 'VALID'`: Round up (partial windows are included).
+where the indices also take into consideration the padding values. Please refer
+to the `Convolution` section for details about the padding calculation.

 @@avg_pool
 @@max_pool
diff --git a/tensorflow/python/summary/event_accumulator.py b/tensorflow/python/summary/event_accumulator.py
index 4c4963c1b9e..2712234157d 100644
--- a/tensorflow/python/summary/event_accumulator.py
+++ b/tensorflow/python/summary/event_accumulator.py
@@ -143,6 +143,8 @@ class EventAccumulator(object):
     self._is_autoupdating = False
     self._activated = False
     self._compression_bps = compression_bps
+    self.most_recent_step = -1
+    self.most_recent_wall_time = -1

   def Reload(self):
     """Loads all events added since the last call to `Reload`.
@@ -156,6 +158,31 @@ class EventAccumulator(object):
     self._activated = True
     with self._generator_mutex:
       for event in self._generator.Load():
+        ## Check if the event happened after a crash
+        if event.step < self.most_recent_step:
+
+          ## Keep data in reservoirs that have a step less than event.step
+          _NotExpired = lambda x: x.step < event.step
+          num_expired_scalars = self._scalars.FilterItems(_NotExpired)
+          num_expired_histograms = self._histograms.FilterItems(_NotExpired)
+          num_expired_compressed_histograms = self._compressed_histograms.FilterItems(
+              _NotExpired)
+          num_expired_images = self._images.FilterItems(_NotExpired)
+
+          purge_msg = (
+              'Detected out of order event.step likely caused by a TensorFlow '
+              'restart. Purging expired events from TensorBoard display '
+              'between the previous step: {} (timestamp: {}) and current step:'
+              ' {} (timestamp: {}). Removing {} scalars, {} histograms, {} '
+              'compressed histograms, and {} images.').format(
+                  self.most_recent_step, self.most_recent_wall_time, event.step,
+                  event.wall_time, num_expired_scalars, num_expired_histograms,
+                  num_expired_compressed_histograms, num_expired_images)
+          logging.warn(purge_msg)
+        else:
+          self.most_recent_step = event.step
+          self.most_recent_wall_time = event.wall_time
+        ## Process the event
         if event.HasField('graph_def'):
           if self._graph is not None:
             logging.warn(('Found more than one graph event per run.'
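The `'SAME'`/`'VALID'` arithmetic documented in the nn.py docstring above can be
checked with a small pure-Python helper. This is only a sketch of those formulas,
not TensorFlow code; the helper name and the max(0, ...) guard (for the degenerate
case where no padding is needed) are assumptions added for the example.

    import math

    def conv_output_size(in_height, in_width, filter_height, filter_width,
                         row_stride, col_stride, padding):
        """Mirrors the SAME/VALID formulas above (illustration only)."""
        if padding == 'SAME':
            out_height = int(math.ceil(float(in_height) / float(row_stride)))
            out_width = int(math.ceil(float(in_width) / float(col_stride)))
            # Guard against a negative value when no padding is required.
            pad_along_height = max(0, (out_height - 1) * row_stride +
                                   filter_height - in_height)
            pad_along_width = max(0, (out_width - 1) * col_stride +
                                  filter_width - in_width)
            pad_top, pad_left = pad_along_height // 2, pad_along_width // 2
        elif padding == 'VALID':
            out_height = int(math.ceil(
                float(in_height - filter_height + 1) / float(row_stride)))
            out_width = int(math.ceil(
                float(in_width - filter_width + 1) / float(col_stride)))
            pad_top = pad_left = 0
        else:
            raise ValueError("padding must be 'SAME' or 'VALID'")
        return out_height, out_width, pad_top, pad_left

    # 5x5 input, 4x4 filter, stride 1: pad_along_height is 3, so the top gets
    # 1 padded pixel and the bottom gets 2, as described above.
    print(conv_output_size(5, 5, 4, 4, 1, 1, 'SAME'))   # (5, 5, 1, 1)
    print(conv_output_size(7, 7, 3, 3, 2, 2, 'VALID'))  # (3, 3, 0, 0)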
diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py index a28906c71b8..90646d9bd12 100644 --- a/tensorflow/python/summary/event_accumulator_test.py +++ b/tensorflow/python/summary/event_accumulator_test.py @@ -102,8 +102,8 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): def testTags(self): gen = _EventGenerator() - gen.AddScalar('sv1') - gen.AddScalar('sv2') + gen.AddScalar('s1') + gen.AddScalar('s2') gen.AddHistogram('hst1') gen.AddHistogram('hst2') gen.AddImage('im1') @@ -113,7 +113,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): self.assertTagsEqual( acc.Tags(), { ea.IMAGES: ['im1', 'im2'], - ea.SCALARS: ['sv1', 'sv2'], + ea.SCALARS: ['s1', 's2'], ea.HISTOGRAMS: ['hst1', 'hst2'], ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'], ea.GRAPH: False}) @@ -123,8 +123,8 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): acc = ea.EventAccumulator(gen) acc.Reload() self.assertEqual(acc.Tags(), self.empty) - gen.AddScalar('sv1') - gen.AddScalar('sv2') + gen.AddScalar('s1') + gen.AddScalar('s2') gen.AddHistogram('hst1') gen.AddHistogram('hst2') gen.AddImage('im1') @@ -133,7 +133,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): acc.Reload() self.assertTagsEqual(acc.Tags(), { ea.IMAGES: ['im1', 'im2'], - ea.SCALARS: ['sv1', 'sv2'], + ea.SCALARS: ['s1', 's2'], ea.HISTOGRAMS: ['hst1', 'hst2'], ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'], ea.GRAPH: False}) @@ -141,13 +141,13 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): def testScalars(self): gen = _EventGenerator() acc = ea.EventAccumulator(gen) - sv1 = ea.ScalarEvent(wall_time=1, step=10, value=32) - sv2 = ea.ScalarEvent(wall_time=2, step=12, value=64) - gen.AddScalar('sv1', wall_time=1, step=10, value=32) - gen.AddScalar('sv2', wall_time=2, step=12, value=64) + s1 = ea.ScalarEvent(wall_time=1, step=10, value=32) + s2 = ea.ScalarEvent(wall_time=2, step=12, value=64) + gen.AddScalar('s1', wall_time=1, step=10, value=32) + gen.AddScalar('s2', wall_time=2, step=12, value=64) acc.Reload() - self.assertEqual(acc.Scalars('sv1'), [sv1]) - self.assertEqual(acc.Scalars('sv2'), [sv2]) + self.assertEqual(acc.Scalars('s1'), [s1]) + self.assertEqual(acc.Scalars('s2'), [s2]) def testHistograms(self): gen = _EventGenerator() @@ -311,7 +311,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): with self.assertRaises(RuntimeError): acc.Tags() with self.assertRaises(RuntimeError): - acc.Scalars('sv1') + acc.Scalars('s1') acc.Reload() self.assertTrue(acc._activated) acc._activated = False @@ -321,17 +321,17 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): acc = ea.EventAccumulator(gen) acc.Reload() with self.assertRaises(KeyError): - acc.Scalars('sv1') + acc.Scalars('s1') with self.assertRaises(KeyError): acc.Scalars('hst1') with self.assertRaises(KeyError): acc.Scalars('im1') with self.assertRaises(KeyError): - acc.Histograms('sv1') + acc.Histograms('s1') with self.assertRaises(KeyError): acc.Histograms('im1') with self.assertRaises(KeyError): - acc.Images('sv1') + acc.Images('s1') with self.assertRaises(KeyError): acc.Images('hst1') @@ -339,21 +339,43 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): """Tests that non-value events in the generator don't cause early exits.""" gen = _EventGenerator() acc = ea.EventAccumulator(gen) - gen.AddScalar('sv1', wall_time=1, step=10, value=20) + gen.AddScalar('s1', wall_time=1, step=10, value=20) gen.AddEvent(tf.Event( - wall_time=2, step=20, file_version='notsv2')) - 
gen.AddScalar('sv3', wall_time=3, step=100, value=1)
+        wall_time=2, step=20, file_version='nots2'))
+    gen.AddScalar('s3', wall_time=3, step=100, value=1)
     gen.AddHistogram('hst1')
     gen.AddImage('im1')
     acc.Reload()
     self.assertTagsEqual(acc.Tags(), {
         ea.IMAGES: ['im1'],
-        ea.SCALARS: ['sv1', 'sv3'],
+        ea.SCALARS: ['s1', 's3'],
         ea.HISTOGRAMS: ['hst1'],
         ea.COMPRESSED_HISTOGRAMS: ['hst1'],
         ea.GRAPH: False})

+  def testExpiredDataDiscardedAfterRestart(self):
+    """Tests that events are discarded after a restart is detected.
+
+    If a step value is observed to be lower than what was previously seen,
+    this should force a discard of all previous items that are outdated.
+    """
+    gen = _EventGenerator()
+    acc = ea.EventAccumulator(gen)
+    gen.AddScalar('s1', wall_time=1, step=100, value=20)
+    gen.AddScalar('s1', wall_time=1, step=200, value=20)
+    gen.AddScalar('s1', wall_time=1, step=300, value=20)
+    acc.Reload()
+    ## Check that the number of items is what it should be
+    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300])
+
+    gen.AddScalar('s1', wall_time=1, step=101, value=20)
+    gen.AddScalar('s1', wall_time=1, step=201, value=20)
+    gen.AddScalar('s1', wall_time=1, step=301, value=20)
+    acc.Reload()
+    ## Check that we have discarded 200 and 300
+    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])
+

 class RealisticEventAccumulatorTest(EventAccumulatorTest):

diff --git a/tensorflow/python/summary/event_multiplexer.py b/tensorflow/python/summary/event_multiplexer.py
index 9070d75e3ca..d7abea0ab0c 100644
--- a/tensorflow/python/summary/event_multiplexer.py
+++ b/tensorflow/python/summary/event_multiplexer.py
@@ -128,13 +128,15 @@ class EventMultiplexer(object):
     return self

   def AddRunsFromDirectory(self, path, name=None):
-    """Load runs from a directory, assuming each subdirectory is a run.
+    """Load runs from a directory; recursively walks subdirectories.

     If path doesn't exist, no-op. This ensures that it is safe to call
       `AddRunsFromDirectory` multiple times, even before the directory is made.

-    If the directory contains TensorFlow event files, it is itself treated as a
-    run.
+    If path is a directory, load event files in the directory (if any exist) and
+    recursively call AddRunsFromDirectory on any subdirectories. This means you
+    can call AddRunsFromDirectory at the root of a tree of event logs and
+    TensorBoard will load them all.

     If the `EventMultiplexer` is already loaded or autoupdating, this will cause
     the newly created accumulators to also `Reload()` or `AutoUpdate()`.
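The restart purge exercised by testExpiredDataDiscardedAfterRestart above boils
down to: when an event arrives with a step lower than the most recent step seen,
drop everything already recorded at or above that step before recording the new
event. A minimal sketch with a plain list standing in for the Reservoir
containers; the class and names below are illustrative only, not the real
EventAccumulator.

    class TinyScalarLog(object):
        """Illustrative stand-in for the scalar reservoir plus restart check."""

        def __init__(self):
            self.events = []              # list of (step, wall_time, value)
            self.most_recent_step = -1
            self.most_recent_wall_time = -1

        def add(self, step, wall_time, value):
            if step < self.most_recent_step:
                # Likely a restart: keep only data from strictly before `step`.
                self.events = [e for e in self.events if e[0] < step]
            else:
                self.most_recent_step = step
                self.most_recent_wall_time = wall_time
            self.events.append((step, wall_time, value))

    log = TinyScalarLog()
    for step in (100, 200, 300, 101, 201, 301):
        log.add(step, wall_time=step, value=0.0)
    print([s for (s, _, _) in log.events])   # [100, 101, 201, 301]

The resulting step sequence matches the expectation in the test above.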
@@ -156,25 +158,16 @@ class EventMultiplexer(object): if not gfile.Exists(path): return # Maybe it hasn't been created yet, fail silently to retry later if not gfile.IsDirectory(path): - raise ValueError('Path exists and is not a directory, %s' % path) - paths = gfile.ListDirectory(path) - is_directory = lambda x: gfile.IsDirectory(os.path.join(path, x)) - subdirectories = filter(is_directory, paths) - for s in subdirectories: - if name: - subname = '/'.join([name, s]) - else: - subname = s - self.AddRun(os.path.join(path, s), subname) + raise ValueError('AddRunsFromDirectory: path exists and is not a ' + 'directory, %s' % path) + + for (subdir, _, files) in os.walk(path): + if list(filter(event_accumulator.IsTensorFlowEventsFile, files)): + logging.info('Adding events from directory %s', subdir) + rpath = os.path.relpath(subdir, path) + subname = os.path.join(name, rpath) if name else rpath + self.AddRun(subdir, name=subname) - if list(filter(event_accumulator.IsTensorFlowEventsFile, paths)): - directory_name = os.path.split(path)[1] - logging.info('Directory %s has event files; loading', directory_name) - if name: - dname = name - else: - dname = directory_name - self.AddRun(path, dname) return self def Reload(self): diff --git a/tensorflow/python/summary/event_multiplexer_test.py b/tensorflow/python/summary/event_multiplexer_test.py index 4b6b29f66d8..5410b145657 100644 --- a/tensorflow/python/summary/event_multiplexer_test.py +++ b/tensorflow/python/summary/event_multiplexer_test.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import os +import os.path import tensorflow.python.platform @@ -13,6 +14,20 @@ from tensorflow.python.summary import event_accumulator from tensorflow.python.summary import event_multiplexer +def _AddEvents(path): + if not gfile.IsDirectory(path): + gfile.MakeDirs(path) + fpath = os.path.join(path, 'hypothetical.tfevents.out') + with gfile.GFile(fpath, 'w'): + return fpath + + +def _CreateCleanDirectory(path): + if gfile.IsDirectory(path): + gfile.DeleteRecursively(path) + gfile.MkDir(path) + + class _FakeAccumulator(object): def __init__(self, path): @@ -122,34 +137,33 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase): x.AddRunsFromDirectory(fakedir) self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect') - if gfile.IsDirectory(realdir): - gfile.DeleteRecursively(realdir) - gfile.MkDir(realdir) + _CreateCleanDirectory(realdir) x.AddRunsFromDirectory(realdir) self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect') path1 = join(realdir, 'path1') gfile.MkDir(path1) x.AddRunsFromDirectory(realdir) - self.assertEqual(sorted(x.Runs().keys()), ['path1'], 'loaded run: path1') + self.assertEqual(x.Runs(), {}, 'creating empty subdirectory had no effect') + + _AddEvents(path1) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ['path1'], 'loaded run: path1') loader1 = x._GetAccumulator('path1') self.assertEqual(loader1._path, path1, 'has the correct path') path2 = join(realdir, 'path2') - gfile.MkDir(path2) + _AddEvents(path2) x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(sorted(x.Runs().keys()), ['path1', 'path2']) + self.assertItemsEqual(x.Runs(), ['path1', 'path2']) self.assertEqual(x._GetAccumulator('path1'), loader1, 'loader1 not regenerated') - loader2 = x._GetAccumulator('path2') path2_2 = join(path2, 'path2') - gfile.MkDir(path2_2) - x.AddRunsFromDirectory(path2) - self.assertItemsEqual(sorted(x.Runs().keys()), ['path1', 'path2']) - self.assertNotEqual(loader2, 
x._GetAccumulator('path2'), - 'loader2 regenerated') - self.assertEqual(x._GetAccumulator('path2')._path, path2_2, + _AddEvents(path2_2) + x.AddRunsFromDirectory(realdir) + self.assertItemsEqual(x.Runs(), ['path1', 'path2', 'path2/path2']) + self.assertEqual(x._GetAccumulator('path2/path2')._path, path2_2, 'loader2 path correct') def testAddRunsFromDirectoryThatContainsEvents(self): @@ -158,21 +172,18 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase): join = os.path.join realdir = join(tmpdir, 'event_containing_directory') - if gfile.IsDirectory(realdir): - gfile.DeleteRecursively(realdir) - gfile.MkDir(realdir) + _CreateCleanDirectory(realdir) self.assertEqual(x.Runs(), {}) - with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'): - pass + _AddEvents(realdir) x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['event_containing_directory']) + self.assertItemsEqual(x.Runs(), ['.']) subdir = join(realdir, 'subdir') - gfile.MkDir(subdir) + _AddEvents(subdir) x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['event_containing_directory', 'subdir']) + self.assertItemsEqual(x.Runs(), ['.', 'subdir']) def testAddRunsFromDirectoryWithRunNames(self): x = event_multiplexer.EventMultiplexer() @@ -180,30 +191,45 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase): join = os.path.join realdir = join(tmpdir, 'event_containing_directory') - if gfile.IsDirectory(realdir): - gfile.DeleteRecursively(realdir) - gfile.MkDir(realdir) + _CreateCleanDirectory(realdir) self.assertEqual(x.Runs(), {}) - with gfile.GFile(join(realdir, 'hypothetical.tfevents.out'), 'w'): - pass + _AddEvents(realdir) x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo']) + self.assertItemsEqual(x.Runs(), ['foo/.']) subdir = join(realdir, 'subdir') - gfile.MkDir(subdir) + _AddEvents(subdir) x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo', 'foo/subdir']) + self.assertItemsEqual(x.Runs(), ['foo/.', 'foo/subdir']) + + def testAddRunsFromDirectoryWalksTree(self): + x = event_multiplexer.EventMultiplexer() + tmpdir = self.get_temp_dir() + join = os.path.join + realdir = join(tmpdir, 'event_containing_directory') + + _CreateCleanDirectory(realdir) + _AddEvents(realdir) + sub = join(realdir, 'subdirectory') + sub1 = join(sub, '1') + sub2 = join(sub, '2') + sub1_1 = join(sub1, '1') + _AddEvents(sub1) + _AddEvents(sub2) + _AddEvents(sub1_1) + x.AddRunsFromDirectory(realdir) + + self.assertItemsEqual(x.Runs(), ['.', + 'subdirectory/1', 'subdirectory/2', + 'subdirectory/1/1']) def testAddRunsFromDirectoryThrowsException(self): x = event_multiplexer.EventMultiplexer() tmpdir = self.get_temp_dir() - filepath = os.path.join(tmpdir, 'bad_file') - with gfile.GFile(filepath, 'w'): - pass - + filepath = _AddEvents(tmpdir) with self.assertRaises(ValueError): x.AddRunsFromDirectory(filepath) diff --git a/tensorflow/python/summary/impl/reservoir.py b/tensorflow/python/summary/impl/reservoir.py index 6af510a814a..8d322f7f441 100644 --- a/tensorflow/python/summary/impl/reservoir.py +++ b/tensorflow/python/summary/impl/reservoir.py @@ -77,7 +77,7 @@ class Reservoir(object): key: The key for which we are finding associated items. Raises: - KeyError: If the key is not ofund in the reservoir. + KeyError: If the key is not found in the reservoir. Returns: [list, of, items] associated with that key. 
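The recursive run discovery that AddRunsFromDirectory now performs amounts to
walking the tree and treating every directory that contains an events file as a
run named by its path relative to the root. A sketch under the assumption that
any filename containing 'tfevents' counts as an events file (standing in for
event_accumulator.IsTensorFlowEventsFile); find_runs is a hypothetical helper
for illustration only.

    import os

    def find_runs(path, name=None):
        """Return {run_name: directory} for each directory holding event files."""
        runs = {}
        for subdir, _, files in os.walk(path):
            if any('tfevents' in f for f in files):
                rpath = os.path.relpath(subdir, path)
                runs[os.path.join(name, rpath) if name else rpath] = subdir
        return runs

With a layout like logdir/train/... and logdir/eval/... this yields runs named
'train' and 'eval'; event files directly under logdir show up as the run '.',
which is what the updated multiplexer tests above expect.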
@@ -102,6 +102,19 @@ class Reservoir(object):
       bucket = self._buckets[key]
       bucket.AddItem(item)

+  def FilterItems(self, filterFn):
+    """Filter items within a Reservoir, using a filtering function.
+
+    Args:
+      filterFn: A function that returns True for the items to be kept.
+
+    Returns:
+      The number of items removed.
+    """
+    with self._mutex:
+      return sum(bucket.FilterItems(filterFn)
+                 for bucket in self._buckets.values())
+

 class _ReservoirBucket(object):
   """A container for items from a stream, that implements reservoir sampling.
@@ -128,7 +141,7 @@ class _ReservoirBucket(object):
     # AddItem are thread-safe
     self._mutex = threading.Lock()
     self._max_size = _max_size
-    self._count = 0
+    self._num_items_seen = 0
     if _random is not None:
       self._random = _random
     else:
@@ -139,13 +152,13 @@ class _ReservoirBucket(object):

     The new item is guaranteed to be added to the bucket, and to be the last
     element in the bucket. If the bucket has reached capacity, then an old item
-    will be replaced. With probability (_max_size/_count) a random item in the
-    bucket will be popped out and the new item will be appended to the end. With
-    probability (1 - _max_size/_count) the last item in the bucket will be
-    replaced.
+    will be replaced. With probability (_max_size/_num_items_seen) a random item
+    in the bucket will be popped out and the new item will be appended
+    to the end. With probability (1 - _max_size/_num_items_seen)
+    the last item in the bucket will be replaced.

-    Since the O(n) replacements occur with O(1/_count) liklihood, the amortized
-    runtime is O(1).
+    Since the O(n) replacements occur with O(1/_num_items_seen) likelihood,
+    the amortized runtime is O(1).

     Args:
       item: The item to add to the bucket.
@@ -154,13 +167,43 @@ class _ReservoirBucket(object):
       if len(self.items) < self._max_size or self._max_size == 0:
         self.items.append(item)
       else:
-        r = self._random.randint(0, self._count)
+        r = self._random.randint(0, self._num_items_seen)
         if r < self._max_size:
           self.items.pop(r)
           self.items.append(item)
         else:
           self.items[-1] = item
-      self._count += 1
+      self._num_items_seen += 1
+
+  def FilterItems(self, filterFn):
+    """Filter items in a ReservoirBucket, using a filtering function.
+
+    Filtering items from the reservoir bucket must update the
+    internal state variable self._num_items_seen, which is used for determining
+    the rate of replacement in reservoir sampling. Ideally, self._num_items_seen
+    would contain the exact number of items that have ever been seen by the
+    ReservoirBucket and satisfy filterFn. However, the ReservoirBucket does not
+    have access to all items seen -- it only has access to the subset of items
+    that have survived sampling (self.items). Therefore, we estimate
+    self._num_items_seen by scaling it by the same ratio as the ratio of items
+    not removed from self.items.
+
+    Args:
+      filterFn: A function that returns True for items to be kept.
+
+    Returns:
+      The number of items removed from the bucket.
+ """ + with self._mutex: + size_before = len(self.items) + self.items = filter(filterFn, self.items) + size_diff = size_before - len(self.items) + + # Estimate a correction the the number of items seen + prop_remaining = len(self.items) / float( + size_before) if size_before > 0 else 0 + self._num_items_seen = int(round(self._num_items_seen * prop_remaining)) + return size_diff def Items(self): """Get all the items in the bucket.""" diff --git a/tensorflow/python/summary/impl/reservoir_test.py b/tensorflow/python/summary/impl/reservoir_test.py index 6bc5a7ed76f..d2325689b7b 100644 --- a/tensorflow/python/summary/impl/reservoir_test.py +++ b/tensorflow/python/summary/impl/reservoir_test.py @@ -90,12 +90,14 @@ class ReservoirBucketTest(googletest.TestCase): for i in xrange(100): b.AddItem(i) self.assertEqual(b.Items(), list(xrange(100))) + self.assertEqual(b._num_items_seen, 100) def testDoesntOverfill(self): b = reservoir._ReservoirBucket(10) for i in xrange(1000): b.AddItem(i) self.assertEqual(len(b.Items()), 10) + self.assertEqual(b._num_items_seen, 1000) def testMaintainsOrder(self): b = reservoir._ReservoirBucket(100) @@ -119,12 +121,14 @@ class ReservoirBucketTest(googletest.TestCase): for i in xrange(20): b.AddItem(i) self.assertEqual(b.Items(), [i]) + self.assertEqual(b._num_items_seen, 20) def testSizeZeroBucket(self): b = reservoir._ReservoirBucket(0) for i in xrange(20): b.AddItem(i) self.assertEqual(b.Items(), list(range(i + 1))) + self.assertEqual(b._num_items_seen, 20) def testSizeRequirement(self): with self.assertRaises(ValueError): @@ -132,6 +136,29 @@ class ReservoirBucketTest(googletest.TestCase): with self.assertRaises(ValueError): reservoir._ReservoirBucket(10.3) + def testRemovesItems(self): + b = reservoir._ReservoirBucket(100) + for i in xrange(10): + b.AddItem(i) + self.assertEqual(len(b.Items()), 10) + self.assertEqual(b._num_items_seen, 10) + self.assertEqual(b.FilterItems(lambda x: x <= 7), 2) + self.assertEqual(len(b.Items()), 8) + self.assertEqual(b._num_items_seen, 8) + + def testRemovesItemsWhenItemsAreReplaced(self): + b = reservoir._ReservoirBucket(100) + for i in xrange(10000): + b.AddItem(i) + self.assertEqual(b._num_items_seen, 10000) + + # Remove items + num_removed = b.FilterItems(lambda x: x <= 7) + self.assertGreater(num_removed, 92) + self.assertEqual([], [item for item in b.Items() if item > 7]) + self.assertEqual(b._num_items_seen, + int(round(10000 * (1 - float(num_removed) / 100)))) + class ReservoirBucketStatisticalDistributionTest(googletest.TestCase): diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 3f9371fead4..a712075d077 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -445,7 +445,8 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue, capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, dtypes=dtypes, shapes=shapes) _enqueue(queue, tensor_list, num_threads, enqueue_many) - full = (math_ops.cast(queue.size() - min_after_dequeue, types.float32) * + full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), + types.float32) * (1. / (capacity - min_after_dequeue))) # Note that name contains a '/' at the end so we intentionally do not place # a '/' after %s below. 
@@ -513,7 +514,8 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity, capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, dtypes=dtypes, shapes=shapes) _enqueue_join(queue, tensor_list_list, enqueue_many) - full = (math_ops.cast(queue.size() - min_after_dequeue, types.float32) * + full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), + types.float32) * (1. / (capacity - min_after_dequeue))) # Note that name contains a '/' at the end so we intentionally do not place # a '/' after %s below. diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts index 0b0103d6971..4ea7fdf83c2 100644 --- a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts +++ b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts @@ -4,7 +4,7 @@ module TF { type TFDatum = [number, number, number]; type tooltipMap = {[run: string]: string}; - type TooltipUpdater = (tooltipMap, xValue, closestRun) => void; + export type TooltipUpdater = (tooltipMap, xValue, closestRun) => void; let Y_TOOLTIP_FORMATTER_PRECISION = 4; let STEP_AXIS_FORMATTER_PRECISION = 4; diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts index 64e537154b6..ac3a5211630 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts @@ -39,13 +39,13 @@ export class SlimGraph { } } -interface NormalizedInput { +export interface NormalizedInput { name: string; hasNumberPart: boolean; isControlDependency: boolean; } -interface BuildParams { +export interface BuildParams { enableEmbedding: boolean; inEmbeddingTypes: string[]; outEmbeddingTypes: string[]; @@ -352,7 +352,7 @@ export function joinStatsInfoWithGraph(graph: SlimGraph, /** * Execution stats for the node. */ -class NodeStats { +export class NodeStats { constructor(totalBytes: number, totalMicros: number, outputSize: number[][]) { this.totalBytes = totalBytes; this.totalMicros = totalMicros; diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts index 36692dd0f04..17f7a0b4053 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts @@ -11,7 +11,7 @@ const LOG_PREFIX_MSG = "Graph hierarchy: "; /** * Class used as output for getPredecessors and getSuccessors methods */ -interface Edges { +export interface Edges { control: string[]; regular: string[]; } @@ -370,7 +370,7 @@ function findEdgeTargetsInGraph( }); } -interface HierarchyParams { +export interface HierarchyParams { verifyTemplate: boolean; seriesNodeMinSize: number; } @@ -640,7 +640,7 @@ function detectSeries(clusters: {[clusterId: string]: string[]}, * which is an array that contains objects with name, id, prefix, suffix, * and parent properties. 
*/ - let candidatesDict = {}; + let candidatesDict: {[seriesName: string]: SeriesNode[]} = {}; // Group all nodes that have the same name, with the exception of a // number at the end of the name after an underscore, which is allowed to diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts index b4970a9cdce..50216bce226 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts @@ -65,7 +65,7 @@ export let SeriesNodeColors = { /** * Parameters that affect how the graph is rendered on the screen. */ -interface RenderGraphParams { +export interface RenderGraphParams { /** * Whether to extract high degree nodes from the core part of the graph. */ diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts index b5aafc55e5a..2ec0f6d5e96 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts @@ -28,7 +28,7 @@ export function detect(h, verifyTemplate): {[templateId: string]: string[]} { // Sort the templates by minimum level in the graph at which they appear, // as this leads to optimal setting of the colors of each template for // maximum differentiation. - return _(templates).pairs() + return <{[templateId: string]: string[]}> _(templates).pairs() .sortBy(function(pair) { return pair[1].level; }) @@ -101,6 +101,7 @@ function clusterSimilarSubgraphs(h: hierarchy.Hierarchy) { function groupTemplateAndAssignId(nnGroups, verifyTemplate) { // For each metanode, compare its subgraph (starting from shallower groups) // and assign template id. + let result: {[templateId: string]: {level: number, nodes: string[]}} = {}; return _.reduce(nnGroups, function(templates, nnGroupPair) { let signature = nnGroupPair[0], nnGroup = nnGroupPair[1].nodes, @@ -137,7 +138,7 @@ function groupTemplateAndAssignId(nnGroups, verifyTemplate) { }; }); return templates; - }, {}); + }, result); } function sortNodes(names: string[], graph: graphlib.Graph, diff --git a/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html b/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html index e4cd153113e..107e3ab7a09 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html +++ b/tensorflow/tensorboard/components/tf-graph-common/tf-graph-common.html @@ -1,6 +1,7 @@ + diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html index 5984cb67eca..51ea6497188 100644 --- a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html +++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html @@ -1,8 +1,7 @@ - - +
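The fraction-of-queue-full fix in shuffle_batch and shuffle_batch_join above
clamps the numerator at zero so the summary can no longer go negative while the
queue is still filling. A plain-Python illustration of the same arithmetic (not
the TensorFlow ops themselves):

    def fraction_full(size, capacity, min_after_dequeue):
        # Mirrors the summary value computed above, with the maximum(0, ...) clamp.
        return (max(0, size - min_after_dequeue) *
                (1.0 / (capacity - min_after_dequeue)))

    print(fraction_full(10, 100, 60))    # 0.0  (was -1.25 before the clamp)
    print(fraction_full(80, 100, 60))    # 0.5
    print(fraction_full(100, 100, 60))   # 1.0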