From e87e48dd8c5b84318122b5aa6d98945041c574d4 Mon Sep 17 00:00:00 2001 From: Yong Tang <yong.tang.github@outlook.com> Date: Fri, 10 May 2019 00:41:06 +0000 Subject: [PATCH] Add broadcasting support for `tf.where` This is a rework on 15982 Signed-off-by: Yong Tang <yong.tang.github@outlook.com> --- .../core/kernels/cwise_op_gpu_select.cu.cc | 29 ++- tensorflow/core/kernels/cwise_op_select.cc | 170 ++++++++++++++++-- tensorflow/core/kernels/cwise_ops.h | 12 ++ tensorflow/core/ops/math_ops.cc | 51 ++++++ .../python/kernel_tests/where_op_test.py | 157 +++++++++++++--- tensorflow/python/ops/array_ops.py | 48 ++++- 6 files changed, 425 insertions(+), 42 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc index 303d8e47913..ba872db2172 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc @@ -23,6 +23,22 @@ limitations under the License. namespace tensorflow { namespace functor { +template <typename T, int NDIMS> +struct BCastSelectFunctor<GPUDevice, T, NDIMS> { + void operator()(const GPUDevice& d, + typename TTypes<T, NDIMS>::Tensor output_tensor, + typename TTypes<bool, NDIMS>::ConstTensor cond_tensor, + typename TTypes<T, NDIMS>::ConstTensor then_tensor, + typename TTypes<T, NDIMS>::ConstTensor else_tensor, + typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast, + typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast, + typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); + } +}; + template <typename T> struct SelectFunctor<GPUDevice, T> { void operator()(const GPUDevice& d, typename TTypes<T>::Flat out, @@ -89,10 +105,15 @@ struct BatchSelectFunctor<GPUDevice, T> { } }; -#define SELECT_FUNCTOR(T) \ - template struct SelectFunctor<GPUDevice, T>; \ - template struct SelectScalarFunctor<GPUDevice, T>; \ - template struct BatchSelectFunctor<GPUDevice, T>; +#define SELECT_FUNCTOR(T) \ + template struct SelectFunctor<GPUDevice, T>; \ + template struct SelectScalarFunctor<GPUDevice, T>; \ + template struct BatchSelectFunctor<GPUDevice, T>; \ + template struct BCastSelectFunctor<GPUDevice, T, 1>; \ + template struct BCastSelectFunctor<GPUDevice, T, 2>; \ + template struct BCastSelectFunctor<GPUDevice, T, 3>; \ + template struct BCastSelectFunctor<GPUDevice, T, 4>; \ + template struct BCastSelectFunctor<GPUDevice, T, 5>; SELECT_FUNCTOR(bool); SELECT_FUNCTOR(Eigen::half); diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 3b51563ca28..c85c9d0599f 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -143,21 +143,141 @@ class SelectOp : public OpKernel { private: TF_DISALLOW_COPY_AND_ASSIGN(SelectOp); }; +template <typename Device, typename T> +class SelectV2Op : public OpKernel { + public: + explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {} -#define REGISTER_SELECT(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - SelectOp<CPUDevice, type>); + void Compute(OpKernelContext* ctx) override { + const Tensor* cond; + const Tensor* then; + const Tensor* else_; + OP_REQUIRES_OK(ctx, ctx->input("condition", &cond)); + OP_REQUIRES_OK(ctx, ctx->input("t", &then)); + OP_REQUIRES_OK(ctx, ctx->input("e", &else_)); + + // The 
`cond`, `then`, and `else` are broadcastable (bcast.IsValid()), + // This matches the behavior of numpy. + // TODO (yongtang): Consolidate into n-ary broadcast, instead of multiple + // 2-ary broadcast. + + // Combine `then` and `else`. + BCast then_else_bcast(BCast::FromShape(then->shape()), + BCast::FromShape(else_->shape()), false); + OP_REQUIRES(ctx, then_else_bcast.IsValid(), + errors::InvalidArgument( + "then ", then->shape().DebugString(), " and else ", + else_->shape().DebugString(), " must be broadcastable")); + // Combine `cond` with `then` and `else`. + BCast bcast( + BCast::FromShape(cond->shape()), + BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())), + false); + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument( + "condition ", cond->shape().DebugString(), ", then ", + then->shape().DebugString(), ", and else ", + else_->shape().DebugString(), " must be broadcastable")); + + // Broadcast `cond`, `then` and `else` to combined shape, + // in order to obtain the reshape. + BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), + BCast::FromShape(cond->shape()), false); + BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), + BCast::FromShape(then->shape()), false); + BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), + BCast::FromShape(else_->shape()), false); + OP_REQUIRES( + ctx, + cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(), + errors::InvalidArgument("condition ", cond->shape().DebugString(), + ", then ", then->shape().DebugString(), + ", and else ", else_->shape().DebugString(), + " must be broadcastable")); + + // Combined shape should be the final shape. + OP_REQUIRES( + ctx, + cond_bcast.output_shape() == bcast.output_shape() && + then_bcast.output_shape() == bcast.output_shape() && + else_bcast.output_shape() == bcast.output_shape(), + errors::InvalidArgument("condition ", cond->shape().DebugString(), + ", then ", then->shape().DebugString(), + ", and else ", else_->shape().DebugString(), + " must be broadcastable to the same shape")); + + Tensor* output = nullptr; + const TensorShape output_shape = BCast::ToShape(bcast.output_shape()); + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {"t", "e"}, "output", output_shape, &output)); + + if (output->NumElements() == 0) { + return; + } + +#define HANDLE_DIM(NDIMS) \ + { \ + functor::BCastSelectFunctor<Device, T, NDIMS> func; \ + func(ctx->eigen_device<Device>(), \ + output->shaped<T, NDIMS>(bcast.result_shape()), \ + cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \ + then->template shaped<T, NDIMS>(then_bcast.y_reshape()), \ + else_->template shaped<T, NDIMS>(else_bcast.y_reshape()), \ + BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()), \ + BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()), \ + BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast())); \ + } + + const int ndims = static_cast<int>(bcast.result_shape().size()); + switch (ndims) { + case 1: + HANDLE_DIM(1); + break; + case 2: + HANDLE_DIM(2); + break; + case 3: + HANDLE_DIM(3); + break; + case 4: + HANDLE_DIM(4); + break; + case 5: + HANDLE_DIM(5); + break; + default: + ctx->SetStatus(errors::Unimplemented( + "Broadcast between ", ctx->input(0).shape().DebugString(), " and ", + ctx->input(1).shape().DebugString(), " is not supported yet.")); + break; + } + return; + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op); +}; + +#define REGISTER_SELECT(type) \ + REGISTER_KERNEL_BUILDER( \ + 
Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + SelectOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + SelectV2Op<CPUDevice, type>); TF_CALL_ALL_TYPES(REGISTER_SELECT); #if GOOGLE_CUDA // Registration of the GPU implementations. -#define REGISTER_SELECT_GPU(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - SelectOp<GPUDevice, type>); +#define REGISTER_SELECT_GPU(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + SelectOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + SelectV2Op<GPUDevice, type>); REGISTER_SELECT_GPU(bool); REGISTER_SELECT_GPU(Eigen::half); @@ -174,9 +294,12 @@ REGISTER_SELECT_GPU(complex128); #ifdef TENSORFLOW_USE_SYCL // Registration of the SYCL implementations. -#define REGISTER_SELECT_SYCL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ +#define REGISTER_SELECT_SYCL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + SelectOp<SYCLDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("SelectV2").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ SelectOp<SYCLDevice, type>); REGISTER_SELECT_SYCL(float); @@ -324,10 +447,35 @@ struct BatchSelectFunctor<CPUDevice, T> { } }; +template <typename Device, typename T, int NDIMS> +struct BCastSelectFunctorBase { + void operator()(const Device& d, + typename TTypes<T, NDIMS>::Tensor output_tensor, + typename TTypes<bool, NDIMS>::ConstTensor cond_tensor, + typename TTypes<T, NDIMS>::ConstTensor then_tensor, + typename TTypes<T, NDIMS>::ConstTensor else_tensor, + typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast, + typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast, + typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); + } +}; + +template <typename T, int NDIMS> +struct BCastSelectFunctor<CPUDevice, T, NDIMS> + : BCastSelectFunctorBase<CPUDevice, T, NDIMS> {}; + #ifdef TENSORFLOW_USE_SYCL template <typename T> struct BatchSelectFunctor<SYCLDevice, T> : BatchSelectFunctorBase<SYCLDevice, T> {}; + +template <typename T, int NDIMS> +struct BCastSelectFunctor<SYCLDevice, T, NDIMS> + : BCastSelectFunctorBase<SYCLDevice, T, NDIMS> {}; + #endif // TENSORFLOW_USE_SYCL } // namespace functor diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 898147b3661..cc8eba69555 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -1208,6 +1208,18 @@ struct BatchSelectFunctor { typename TTypes<T>::ConstMatrix else_flat_outer_dims); }; +template <typename Device, typename T, int NDIMS> +struct BCastSelectFunctor { + void operator()(const Device& d, + typename TTypes<T, NDIMS>::Tensor output_tensor, + typename TTypes<bool, NDIMS>::ConstTensor cond_tensor, + typename TTypes<T, NDIMS>::ConstTensor then_tensor, + typename TTypes<T, NDIMS>::ConstTensor else_tensor, + typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast, + typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast, + typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast); +}; + } // end namespace functor } // end namespace tensorflow diff --git 
a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 29a93753e7d..3ff9bc09853 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -828,6 +828,57 @@ REGISTER_OP("Select") return Status::OK(); }); +REGISTER_OP("SelectV2") + .Input("condition: bool") + .Input("t: T") + .Input("e: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + auto* handle_data_1 = c->input_handle_shapes_and_types(1); + auto* handle_data_2 = c->input_handle_shapes_and_types(2); + // Merge handle shape and dtype if applicable. + if (handle_data_1 != nullptr && handle_data_2 != nullptr) { + const auto size = handle_data_1->size(); + std::vector<shape_inference::ShapeAndType> merged_handle_data(size); + if (size != handle_data_2->size()) { + return errors::InvalidArgument( + "Trying to merge handles pointing to different numbers of " + "tensors."); + } + + for (int i = 0; i < size; ++i) { + const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i]; + const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i]; + if (s1.dtype != s2.dtype) { + // TODO(apassos) resolve this in the manner of b/32476923 + return errors::InvalidArgument( + "Trying to merge handles pointing to different dtypes."); + } + merged_handle_data[i].dtype = s1.dtype; + TF_RETURN_IF_ERROR( + c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape)); + } + + c->set_output_handle_shapes_and_types(0, merged_handle_data); + } + + // The inputs 'cond', 'then', and 'else' must be broadcastable. + // TODO (yongtang): Consolidate 3-ary broadcast instead of + // multiple 2-ary broadcast. + ShapeHandle cond = c->input(0); + ShapeHandle then = c->input(1); + ShapeHandle else_ = c->input(2); + ShapeHandle other; + TF_RETURN_IF_ERROR( + BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other)); + ShapeHandle output; + TF_RETURN_IF_ERROR( + BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output)); + c->set_output(0, output); + return Status::OK(); + }); + // -------------------------------------------------------------------------- REGISTER_OP("MatMul") diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py index 56c13904113..00d916c06eb 100644 --- a/tensorflow/python/kernel_tests/where_op_test.py +++ b/tensorflow/python/kernel_tests/where_op_test.py @@ -37,10 +37,10 @@ from tensorflow.python.platform import test class WhereOpTest(test.TestCase): - def _testWhere(self, x, truth, expected_err_re=None): + def _testWhere(self, x, truth, expected_err_re=None, fn=array_ops.where): with self.cached_session(use_gpu=True): - ans = array_ops.where(x) - self.assertEqual([None, x.ndim], ans.get_shape().as_list()) + ans = fn(x) + self.assertTrue(ans.get_shape().is_compatible_with([None, x.ndim])) if expected_err_re is None: tf_ans = self.evaluate(ans) self.assertAllClose(tf_ans, truth, atol=1e-10) @@ -48,44 +48,40 @@ class WhereOpTest(test.TestCase): with self.assertRaisesOpError(expected_err_re): self.evaluate(ans) - def testWrongNumbers(self): + def _testWrongNumbers(self, fn=array_ops.where): with self.session(use_gpu=True): with self.assertRaises(ValueError): - array_ops.where([False, True], [1, 2], None) + fn([False, True], [1, 2], None) with self.assertRaises(ValueError): - array_ops.where([False, True], None, [1, 2]) + fn([False, True], None, [1, 2]) - @test_util.run_deprecated_v1 - def testBasicVec(self): + def _testBasicVec(self, fn=array_ops.where): x = np.asarray([True, False]) truth = np.asarray([[0]], 
dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) x = np.asarray([False, True, False]) truth = np.asarray([[1]], dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) x = np.asarray([False, False, True, False, True]) truth = np.asarray([[2], [4]], dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) - @test_util.run_deprecated_v1 - def testRandomVec(self): + def _testRandomVec(self, fn=array_ops.where): x = np.random.rand(1000000) > 0.5 truth = np.vstack([np.where(x)[0].astype(np.int64)]).T - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) - @test_util.run_deprecated_v1 - def testBasicMat(self): + def _testBasicMat(self, fn=array_ops.where): x = np.asarray([[True, False], [True, False]]) # Ensure RowMajor mode truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) - @test_util.run_deprecated_v1 - def testBasic3Tensor(self): + def _testBasic3Tensor(self, fn=array_ops.where): x = np.asarray([[[True, False], [True, False]], [[False, True], [False, True]], [[False, False], [False, True]]]) @@ -94,15 +90,41 @@ class WhereOpTest(test.TestCase): truth = np.asarray( [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]], dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) - def _testRandom(self, dtype, expected_err_re=None): + def _testRandom(self, dtype, expected_err_re=None, fn=array_ops.where): shape = [127, 33, 53] x = np.random.randn(*shape) + 1j * np.random.randn(*shape) x = (np.random.randn(*shape) > 0).astype(dtype) truth = np.where(np.abs(x) > 0) # Tuples of indices by axis. truth = np.vstack(truth).T # Convert to [num_true, indices]. - self._testWhere(x, truth, expected_err_re) + self._testWhere(x, truth, expected_err_re, fn) + + def _testThreeArgument(self, fn=array_ops.where): + x = np.array([[-2, 3, -1], [1, -3, -3]]) + np_val = np.where(x > 0, x * x, -x) + with self.test_session(use_gpu=True): + tf_val = self.evaluate(fn(constant_op.constant(x) > 0, x * x, -x)) + self.assertAllEqual(tf_val, np_val) + + def testWrongNumbers(self): + self._testWrongNumbers() + + @test_util.run_deprecated_v1 + def testBasicVec(self): + self._testBasicVec() + + @test_util.run_deprecated_v1 + def testRandomVec(self): + self._testRandomVec() + + @test_util.run_deprecated_v1 + def testBasicMat(self): + self._testBasicMat() + + @test_util.run_deprecated_v1 + def testBasic3Tensor(self): + self._testBasic3Tensor() @test_util.run_deprecated_v1 def testRandomBool(self): @@ -146,12 +168,95 @@ class WhereOpTest(test.TestCase): @test_util.run_deprecated_v1 def testThreeArgument(self): - x = np.array([[-2, 3, -1], [1, -3, -3]]) - np_val = np.where(x > 0, x * x, -x) - with self.session(use_gpu=True): - tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval() + self._testThreeArgument() + + def testV2WrongNumbers(self): + self._testWrongNumbers(array_ops.where_v2) + + def testV2BasicVec(self): + self._testBasicVec(array_ops.where_v2) + + def testV2RandomVec(self): + self._testRandomVec(array_ops.where_v2) + + def testV2BasicMat(self): + self._testBasicMat(array_ops.where_v2) + + def testV2Basic3Tensor(self): + self._testBasic3Tensor(array_ops.where_v2) + + def testV2RandomBool(self): + self._testRandom(np.bool, None, array_ops.where_v2) + + def testV2RandomInt32(self): + self._testRandom(np.int32, None, array_ops.where_v2) + + def testV2RandomInt64(self): + self._testRandom(np.int64, None, 
array_ops.where_v2) + + def testV2RandomFloat(self): + self._testRandom(np.float32, None, array_ops.where_v2) + + def testV2RandomDouble(self): + self._testRandom(np.float64, None, array_ops.where_v2) + + def testV2RandomComplex64(self): + self._testRandom(np.complex64, None, array_ops.where_v2) + + def testV2RandomComplex128(self): + self._testRandom(np.complex128, None, array_ops.where_v2) + + def testV2RandomUint8(self): + self._testRandom(np.uint8, None, array_ops.where_v2) + + def testV2RandomInt8(self): + self._testRandom(np.int8, None, array_ops.where_v2) + + def testV2RandomInt16(self): + self._testRandom(np.int16, None, array_ops.where_v2) + + def testV2ThreeArgument(self): + self._testThreeArgument(array_ops.where_v2) + + def testV2Broadcasting(self): + f = np.random.normal(0, 1, (3, 5, 1, 1)) + x = np.zeros((7, 11)) + y = np.ones((7, 11)) + np_val = np.where(f < 0, x, y) + with self.test_session(use_gpu=True): + tf_val = self.evaluate( + array_ops.where_v2(constant_op.constant(f) < 0, x, y)) self.assertAllEqual(tf_val, np_val) + def testV2ScalarBroadcasting(self): + x = np.zeros((7, 11)) + y = np.ones((7, 11)) + np_val = np.where(True, x, y) + with self.test_session(use_gpu=True): + tf_val = self.evaluate( + array_ops.where_v2( + constant_op.constant(True, dtype=dtypes.bool), x, y)) + self.assertAllEqual(tf_val, np_val) + + def testV2VectorBroadcasting(self): + x = np.zeros(7) + y = np.ones(7) + np_val = np.where([True], x, y) + with self.test_session(use_gpu=True): + tf_val = self.evaluate( + array_ops.where_v2( + constant_op.constant([True], dtype=dtypes.bool), x, y)) + self.assertAllEqual(tf_val, np_val) + + def testV2PredBroadcasting(self): + pred = np.array([1, 0, 0]).reshape((3, 1)) + x = np.random.randn(3, 4) + y = np.random.randn(3, 4) + np_val = np.where(pred, x, y) + with self.test_session(use_gpu=True): + tf_val = self.evaluate(array_ops.where_v2(pred, x, y)) + self.assertAllClose(tf_val, np_val) + @test_util.run_deprecated_v1 def testBatchSelect(self): x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192) # [16384, 192] diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 11ef64d3a4f..28c1cb3a772 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -3163,7 +3163,11 @@ def squeeze_v2(input, axis=None, name=None): return squeeze(input, axis, name) -@tf_export("where") +@tf_export(v1=["where"]) +@deprecation.deprecated( + date=None, + instructions="Use tf.where in 2.0, " + "which has the same broadcast rule as np.where") @dispatch.add_dispatch_support def where(condition, x=None, y=None, name=None): """Return the elements, either from `x` or `y`, depending on the `condition`. @@ -3217,6 +3221,48 @@ def where(condition, x=None, y=None, name=None): raise ValueError("x and y must both be non-None or both be None.") +@tf_export("where", v1=["where_v2"]) +def where_v2(condition, x=None, y=None, name=None): + """Return the elements, either from `x` or `y`, depending on the `condition`. + + If both `x` and `y` are None, then this operation returns the coordinates of + true elements of `condition`. The coordinates are returned in a 2-D tensor + where the first dimension (rows) represents the number of true elements, and + the second dimension (columns) represents the coordinates of the true + elements. Keep in mind, the shape of the output tensor can vary depending on + how many true values there are in input. Indices are output in row-major + order. 
+ If both non-None, `condition`, `x` and `y` must be broadcastable to the same + shape. + The `condition` tensor acts as a mask that chooses, based on the value at each + element, whether the corresponding element / row in the output should be taken + from `x` (if true) or `y` (if false). + Args: + condition: A `Tensor` of type `bool` + x: A Tensor which is of the same type as `y`, and may be broadcastable with + `condition` and `y`. + y: A Tensor which is of the same type as `x`, and may be broadcastable with + `condition` and `x`. + name: A name of the operation (optional). + + Returns: + A `Tensor` with the same type as `x` and `y`, and shape that + is broadcasted from `condition`, `x`, and `y`, if `x`, `y` are non-None. + A `Tensor` with shape `(num_true, dim_size(condition))`. + Raises: + ValueError: When exactly one of `x` or `y` is non-None. + """ + if x is None and y is None: + with ops.name_scope(name, "Where", [condition]) as name: + condition = ops.convert_to_tensor( + condition, preferred_dtype=dtypes.bool, name="condition") + return gen_array_ops.where(condition=condition, name=name) + elif x is not None and y is not None: + return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name) + else: + raise ValueError("x and y must both be non-None or both be None.") + + # pylint: disable=redefined-builtin @tf_export(v1=["reverse_sequence"]) @deprecation.deprecated_args(None,
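
A minimal usage sketch of the behavior this patch adds (not part of the diff itself), mirroring the testV2Broadcasting / testV2PredBroadcasting / testV2ScalarBroadcasting cases in where_op_test.py above. It assumes a TF 1.x build that contains this change, where the new kernel is exported as tf.where_v2 and graph-mode tf.Session is available; names and values are illustrative only.

    # Illustrative sketch only: broadcasting select, as exercised by
    # where_op_test.py above. Assumes a 1.x build with this patch applied,
    # which exports the new SelectV2-backed op as tf.where_v2.
    import numpy as np
    import tensorflow as tf

    pred = np.array([[1], [0], [0]], dtype=np.bool_)  # shape (3, 1)
    x = np.random.randn(3, 4)                         # shape (3, 4)
    y = np.random.randn(3, 4)                         # shape (3, 4)

    # The V1 tf.where requires `pred` to match x/y exactly (or be a vector
    # of length x.shape[0]); the V2 op broadcasts `pred` against x and y,
    # following the same rule as np.where.
    out = tf.where_v2(tf.constant(pred), x, y)        # result shape (3, 4)

    # A scalar condition also broadcasts.
    scalar_out = tf.where_v2(tf.constant(True), x, y)

    with tf.Session() as sess:
      tf_val, tf_scalar = sess.run([out, scalar_out])
    assert np.allclose(tf_val, np.where(pred, x, y))
    assert np.allclose(tf_scalar, np.where(True, x, y))

The same calls spell tf.where in 2.0, per the tf_export("where", v1=["where_v2"]) decorator in the array_ops.py hunk above.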