diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 6e76db02de0..62ae5ae78c6 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -208,7 +208,6 @@ filegroup(
         "//tensorflow/compiler/jit/kernels:all_files",
         "//tensorflow/compiler/jit/legacy_flags:all_files",
         "//tensorflow/compiler/jit/ops:all_files",
-        "//tensorflow/compiler/plugin/executor:all_files",
         "//tensorflow/compiler/tests:all_files",
         "//tensorflow/compiler/tf2xla:all_files",
         "//tensorflow/compiler/tf2xla/cc:all_files",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 7ebd8422181..306e704415b 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -15,10 +15,7 @@ package_group(
 )
 
 package(
-    default_visibility = [
-        ":internal",
-        "//tensorflow/compiler/plugin/executor:__pkg__",
-    ],
+    default_visibility = [":internal"],
 )
 
 load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 97f3512a6c4..ed204b81821 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -2,7 +2,6 @@ licenses(["notice"])  # Apache 2.0
 
 package(
     default_visibility = [
-        "//tensorflow/compiler/plugin/executor:__pkg__",
         "//tensorflow/compiler/tf2xla:internal",
     ],
 )
diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD
index 2e5875705f2..9bc706abdf6 100644
--- a/tensorflow/compiler/plugin/executor/BUILD
+++ b/tensorflow/compiler/plugin/executor/BUILD
@@ -11,11 +11,9 @@ cc_library(
         "*.h",
     ]),
     deps = [
-        "//tensorflow/compiler/jit:xla_device",
         "//tensorflow/compiler/jit:xla_jit_headers_lib",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:xla_headers_lib",
-        "//tensorflow/compiler/xla/service",
+        "//tensorflow/compiler/xla/service:hlo_evaluator",
         "//third_party/eigen3",
         "@local_config_cuda//cuda:cuda_headers",
        "@protobuf//:protobuf_headers",
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 432b24756d2..044857d4222 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -175,11 +175,6 @@ tf_xla_py_test(
     name = "slice_ops_test",
     size = "small",
     srcs = ["slice_ops_test.py"],
-    # TODO(b/62962492): Test fails with assertion error.
-    tags = [
-        "manual",
-        "notap",
-    ],
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
@@ -461,11 +456,6 @@ cuda_py_test(
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_ops",
     ],
-    # TODO(b/62961789): Test fails with SIGABRT
-    tags = [
-        "manual",
-        "notap",
-    ],
 )
 
 cc_library(
@@ -534,12 +524,8 @@ cuda_py_test(
 # --dump_graph_dir, and the config file was written by hand.
 #
 # Run the following to build a minimal benchmark of the computation on Android:
-# $ bazel build -c opt --cxxopt='-std=c++11' --linkopt='-lm' \
-#   --cpu=armeabi-v7a \
-#   --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
-#   --crosstool_top=//external:android/crosstool \
-#   //tensorflow/compiler/tests:lstm_layer_inference_benchmark
-
+# $ bazel build -c opt --config=android_arm \
+#   third_party/tensorflow/compiler/tests:lstm_layer_inference_benchmark
 #
 # Currently the resulting binary size is ~190KB
 tf_library(
diff --git a/tensorflow/core/kernels/sparse_reduce_sum_op.cc b/tensorflow/core/kernels/sparse_reduce_sum_op.cc
new file mode 100644
index 00000000000..074aab9f9e2
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_reduce_sum_op.cc
@@ -0,0 +1,305 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/sparse_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+// TODO(b/31496047): Fix non-standard include order.
+#include <numeric>  // clang-format off
+
+using tensorflow::sparse::SparseTensor;
+using tensorflow::gtl::ArraySlice;
+
+namespace tensorflow {
+
+struct ReduceDetails {
+  // The dimensions to call Reorder() with.
+  std::vector<int64> reorder_dims;
+
+  // The dimensions to call group() with after Reorder().
+  std::vector<int64> group_by_dims;
+
+  // The shape after reduction.
+  TensorShape reduced_shape;
+};
+
+// Compute common reduce parameters that'll be used for SparseTensor
+// reductions. Usage:
+//   ReduceDetails reduction = SparseTensorReduceHelper(sp, axes, keep_dims);
+//   sp.Reorder(reduction.reorder_dims);
+//   for (const auto& g : sp.group(reduction.group_by_dims)) {
+//     ...
+//   }
+//   // Set output shape to reduction.reduced_shape.
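+//
+// For example, for a SparseTensor of shape [4, 10, 3] and axes = {1}, this
+// yields group_by_dims = {0, 2}, reorder_dims = {0, 2, 1}, and a
+// reduced_shape of [4, 3] ([4, 1, 3] if keep_dims is true).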
+ReduceDetails SparseTensorReduceHelper(const SparseTensor &sp,
+                                       gtl::ArraySlice<int32> axes_slice,
+                                       bool keep_dims) {
+  ReduceDetails reduction;
+
+  std::vector<int32> reduction_axes(axes_slice.begin(), axes_slice.end());
+  int ndims = sp.dims();
+  for (int64 i = 0; i < reduction_axes.size(); ++i) {
+    reduction_axes[i] = (reduction_axes[i] + ndims) % ndims;
+  }
+  std::sort(reduction_axes.begin(), reduction_axes.end());
+
+  // (0) Calculate the grouping dimensions:
+  // group_by_dims == {0, .., NDIMS-1} \ reduction_axes.
+  std::vector<int64> perm(ndims);
+  std::iota(perm.begin(), perm.end(), 0);
+
+  // Requires perm and reduction_axes be sorted; group_by_dims will be
+  // sorted as well.
+  std::set_difference(
+      perm.begin(), perm.end(), reduction_axes.begin(), reduction_axes.end(),
+      std::inserter(reduction.group_by_dims, reduction.group_by_dims.begin()));
+
+  // Now append the rest of the axes (the complement of group_by_dims);
+  // result is used by Reorder().
+  reduction.reorder_dims = reduction.group_by_dims;
+  std::set_difference(perm.begin(), perm.end(), reduction.group_by_dims.begin(),
+                      reduction.group_by_dims.end(),
+                      std::back_inserter(reduction.reorder_dims));
+
+  // (1) Calculate the shape after reduction.
+  auto sp_shape = sp.shape();
+  std::vector<int64> out_dim_sizes;
+  if (keep_dims) {
+    out_dim_sizes.reserve(ndims);
+    auto beg = reduction.group_by_dims.begin();
+    auto end = reduction.group_by_dims.end();
+    for (int d = 0; d < ndims; ++d) {
+      if (std::find(beg, end, d) == end) {
+        out_dim_sizes.push_back(1);  // A reduced axis.
+      } else {
+        out_dim_sizes.push_back(sp_shape[d]);
+      }
+    }
+  } else {
+    out_dim_sizes = sp.PickDims(reduction.group_by_dims);
+  }
+
+  reduction.reduced_shape = TensorShape(out_dim_sizes);
+  return reduction;
+}
+
+Status ValidateInputs(const Tensor *shape_t, const Tensor *reduction_axes_t) {
+  // indices and values are validated in SparseTensor ctor.
+  if (!TensorShapeUtils::IsVector(shape_t->shape())) {
+    return errors::InvalidArgument(
+        "Expected input_shape to be a vector; got shape: ",
+        shape_t->shape().DebugString());
+  }
+  if (!TensorShapeUtils::IsScalar(reduction_axes_t->shape()) &&
+      !TensorShapeUtils::IsVector(reduction_axes_t->shape())) {
+    return errors::InvalidArgument(
+        "Expected reduction_axes to be a scalar or a vector; got shape: ",
+        reduction_axes_t->shape().DebugString());
+  }
+
+  const auto reduction_axes_flat = reduction_axes_t->flat<int32>();
+  for (int64 i = 0; i < reduction_axes_flat.size(); i++) {
+    int32 axis = reduction_axes_flat(i);
+    if (axis < -shape_t->NumElements() || axis >= shape_t->NumElements()) {
+      return errors::InvalidArgument("Invalid reduction dimension ", axis,
+                                     ", for input with ",
+                                     shape_t->NumElements(), " dimensions.");
+    }
+  }
+
+  return Status::OK();
+}
+
+template <typename T>
+class SparseReduceSumOp : public OpKernel {
+ public:
+  explicit SparseReduceSumOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+  }
+
+  void Compute(OpKernelContext *ctx) override {
+    const Tensor *indices_t, *values_t, *shape_t, *reduction_axes_t;
+    OP_REQUIRES_OK(ctx, ctx->input("input_indices", &indices_t));
+    OP_REQUIRES_OK(ctx, ctx->input("input_values", &values_t));
+    OP_REQUIRES_OK(ctx, ctx->input("input_shape", &shape_t));
+    OP_REQUIRES_OK(ctx, ctx->input("reduction_axes", &reduction_axes_t));
+
+    OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));
+
+    // TODO(zongheng): we will call Reorder() below, which will modify
+    // in-place the underlying indices and values buffers.  To avoid
+    // surprises of this kernel being stateful, we work around the above by
+    // making deep copies here.  Remove this if/when we change Reorder()'s
+    // semantics.
+    const auto shape_vec = shape_t->vec<int64>();
+    SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+                    TensorShape(shape_vec));
+    ReduceDetails reduction = SparseTensorReduceHelper(
+        sp, reduction_axes_t->flat<int32>(), keep_dims_);
+
+    Tensor *out_values;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output(0, reduction.reduced_shape, &out_values));
+    auto out_flat = out_values->flat<T>();
+    out_flat.setZero();
+
+    Tensor tmp_group_sum;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+                                           TensorShape({}), &tmp_group_sum));
+    auto group_sum = tmp_group_sum.scalar<T>();
+
+    // Compute strides, and use it to convert coords to flat index.  The
+    // coordinates returned by .group() have the same ndims as group_by_dims.
+    gtl::InlinedVector<int64, 8> output_strides(reduction.group_by_dims.size());
+    if (!output_strides.empty()) {  // Do this iff we don't reduce all.
+      output_strides.back() = 1;
+      for (int d = output_strides.size() - 2; d >= 0; --d) {
+        output_strides[d] =
+            output_strides[d + 1] * shape_vec(reduction.group_by_dims[d + 1]);
+      }
+    }
+
+    auto CoordinatesToFlatIndex = [](ArraySlice<int64> coords,
+                                     ArraySlice<int64> strides) {
+      if (strides.empty()) {  // Reduce all.
+        return 0LL;
+      }
+      CHECK_EQ(coords.size(), strides.size());
+      int64 idx = 0;
+      for (int i = 0; i < coords.size(); ++i) {
+        idx += coords[i] * strides[i];
+      }
+      return idx;
+    };
+
+    // Each group maps one-on-one onto a value in the reduced tensor.
+    // g.group() provides the coordinates of a particular reduced value.
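+    // For example, with group_by_dims = {0, 2} over an input of shape
+    // [4, 10, 3], output_strides is {3, 1}, so group coordinates {2, 1}
+    // map to flat index 2 * 3 + 1 * 1 = 7 in the reduced [4, 3] output.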
+    sp.Reorder<T>(reduction.reorder_dims);
+    for (const auto &g : sp.group(reduction.group_by_dims)) {
+      group_sum.device(ctx->eigen_cpu_device()) = g.template values<T>().sum();
+      const int64 idx = CoordinatesToFlatIndex(g.group(), output_strides);
+      out_flat(idx) = group_sum();
+      VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
+              << "; idx: " << idx << "; group sum: " << group_sum();
+    }
+  }
+
+ private:
+  // True if the number of dimensions should be maintained.
+  bool keep_dims_;
+};
+
+#define REGISTER_KERNELS(T)                                              \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("SparseReduceSum").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseReduceSumOp<T>)
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+template <typename T>
+class SparseReduceSumSparseOp : public OpKernel {
+ public:
+  explicit SparseReduceSumSparseOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+  }
+
+  void Compute(OpKernelContext *ctx) override {
+    const Tensor *indices_t, *values_t, *shape_t, *reduction_axes_t;
+    OP_REQUIRES_OK(ctx, ctx->input("input_indices", &indices_t));
+    OP_REQUIRES_OK(ctx, ctx->input("input_values", &values_t));
+    OP_REQUIRES_OK(ctx, ctx->input("input_shape", &shape_t));
+    OP_REQUIRES_OK(ctx, ctx->input("reduction_axes", &reduction_axes_t));
+
+    OP_REQUIRES_OK(ctx, ValidateInputs(shape_t, reduction_axes_t));
+
+    SparseTensor sp(tensor::DeepCopy(*indices_t), tensor::DeepCopy(*values_t),
+                    TensorShape(shape_t->vec<int64>()));
+    ReduceDetails reduction = SparseTensorReduceHelper(
+        sp, reduction_axes_t->flat<int32>(), keep_dims_);
+
+    sp.Reorder<T>(reduction.reorder_dims);
+    // Count nnzs in the output SparseTensor.
+    int64 nnz = 0;
+    auto iter = sp.group(reduction.group_by_dims);
+    for (auto it = iter.begin(); it != iter.end(); ++it) {
+      nnz++;
+    }
+
+    Tensor *out_indices_t;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(
+                       0, TensorShape({nnz, reduction.reduced_shape.dims()}),
+                       &out_indices_t));
+    typename TTypes<int64>::Matrix out_indices_mat =
+        out_indices_t->matrix<int64>();
+    // For keep_dims. We don't explicitly set dim fields for reduced dims below.
+    out_indices_mat.setZero();
+
+    Tensor *out_values_t;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(1, TensorShape({nnz}), &out_values_t));
+    auto out_flat = out_values_t->flat<T>();
+
+    Tensor tmp_group_sum;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+                                           TensorShape({}), &tmp_group_sum));
+    auto group_sum = tmp_group_sum.scalar<T>();
+    int64 i = 0;
+    for (const auto &g : sp.group(reduction.group_by_dims)) {
+      group_sum.device(ctx->eigen_cpu_device()) = g.template values<T>().sum();
+      std::vector<int64> group = g.group();
+      for (int64 j = 0; j < group.size(); j++) {
+        if (keep_dims_) {
+          out_indices_mat(i, reduction.group_by_dims[j]) = group[j];
+        } else {
+          out_indices_mat(i, j) = group[j];
+        }
+      }
+      out_flat(i) = group_sum();
+      i++;
+      VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
+              << "; group sum: " << group_sum();
+    }
+
+    Tensor *out_shape_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                            2, TensorShape({reduction.reduced_shape.dims()}),
+                            &out_shape_t));
+    auto out_shape_flat = out_shape_t->flat<int64>();
+    auto out_dim_sizes = reduction.reduced_shape.dim_sizes();
+    std::copy(out_dim_sizes.begin(), out_dim_sizes.end(), &out_shape_flat(0));
+  }
+
+ private:
+  // True if the number of dimensions should be maintained.
+  bool keep_dims_;
+};
+
+#define REGISTER_KERNELS(T)                                                    \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("SparseReduceSumSparse").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseReduceSumSparseOp<T>)
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+}  // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 936348e01d2..22b18b9cde0 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3056,7 +3056,6 @@ py_test(
     srcs = ["client/session_clusterspec_prop_test.py"],
     srcs_version = "PY2AND3",
     tags = [
-        "no_gpu",
         "no_pip_gpu",
     ],
     deps = [
@@ -3081,7 +3080,6 @@ py_test(
     srcs = ["client/session_list_devices_test.py"],
     srcs_version = "PY2AND3",
     tags = [
-        "no_gpu",
         "no_pip_gpu",
     ],
     deps = [
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index 66643622260..14eb2cba68a 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
+import unittest
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -605,6 +606,7 @@ class SparseReduceTest(test_util.TensorFlowTestCase):
         self._compare(sp_t, reduction_axes, ndims, True, False)
         self._compare(sp_t, reduction_axes, ndims, True, True)
 
+  @unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13 bug")
   def testSimpleAndRandomInputs(self):
     if np.__version__ == "1.13.0":
       self.skipTest("numpy 1.13.0 bug")
@@ -644,6 +646,7 @@ class SparseReduceTest(test_util.TensorFlowTestCase):
     with self.assertRaisesOpError("Invalid reduction dimension 2"):
       sparse_ops.sparse_reduce_max(sp_t, 2).eval()
 
+  @unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13 bug")
   def testGradient(self):
     if np.__version__ == "1.13.0":
       self.skipTest("numpy 1.13.0 bug")
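For reference, a minimal sketch of what the new kernels compute, assuming the public `tf.sparse_reduce_sum` and `tf.sparse_reduce_sum_sparse` Python wrappers (which the C++ kernels above back on CPU) and the TF 1.x Session API:

```python
# Sketch only: assumes the tf.sparse_reduce_sum / tf.sparse_reduce_sum_sparse
# wrappers around the SparseReduceSum / SparseReduceSumSparse kernels.
import tensorflow as tf

# 2x3 sparse tensor with nonzeros at (0, 0) = 1 and (1, 2) = 5.
sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 5],
                     dense_shape=[2, 3])

dense_sum = tf.sparse_reduce_sum(sp, 1)          # dense Tensor: [1, 5]
sparse_sum = tf.sparse_reduce_sum_sparse(sp, 1)  # same sums, kept sparse

with tf.Session() as sess:
  print(sess.run(dense_sum))   # [1 5]
  print(sess.run(sparse_sum))  # SparseTensorValue with values [1 5]
```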