From eb856d1a09d2a5c3daa54a17395d51a25f8bb8d0 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 17:23:58 +0000 Subject: [PATCH 01/21] Add broadcasting support for `tf.where` This fix tries to address the issue raised in 9284 where there was no broadcasting support for `tf.where`. This fix adds the support so that the behavior of `tf.where` matches `np.where`. This fix fixes 9284. Signed-off-by: Yong Tang --- tensorflow/core/kernels/cwise_op_select.cc | 175 +++++++++++++++++++-- 1 file changed, 164 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index dd4e4ea547e..54159b88d6c 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -143,21 +143,138 @@ class SelectOp : public OpKernel { private: TF_DISALLOW_COPY_AND_ASSIGN(SelectOp); }; +template +class SelectV2Op : public OpKernel { + public: + explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {} -#define REGISTER_SELECT(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Select").Device(DEVICE_CPU).TypeConstraint("T"), \ - SelectOp); + void Compute(OpKernelContext* ctx) override { + const Tensor* cond; + const Tensor* then; + const Tensor* else_; + OP_REQUIRES_OK(ctx, ctx->input("condition", &cond)); + OP_REQUIRES_OK(ctx, ctx->input("t", &then)); + OP_REQUIRES_OK(ctx, ctx->input("e", &else_)); + + // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()), + // This matches the behavior of numpy. + // TODO (yongtang): Consolidate into n-ary broadcast, instead of multiple + // 2-ary broadcast. + + // Combine `then` and `else`. + BCast elem_bcast(BCast::FromShape(then->shape()), + BCast::FromShape(else_->shape()), false); + OP_REQUIRES(ctx, elem_bcast.IsValid(), + errors::InvalidArgument( + "then ", then->shape().DebugString(), " and else ", + else_->shape().DebugString(), " must be broadcastable")); + // Combine `cond` with `then` and `else`. + BCast bcast(BCast::FromShape(cond->shape()), + BCast::FromShape(BCast::ToShape(elem_bcast.output_shape())), + false); + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument( + "condition ", cond->shape().DebugString(), ", then ", + then->shape().DebugString(), ", and else ", + else_->shape().DebugString(), " must be broadcastable")); + + // Broadcast `cond`, `then` and `else` to combined shape, + // in order to obtain the reshape. + BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), + BCast::FromShape(cond->shape()), false); + BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), + BCast::FromShape(then->shape()), false); + BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), + BCast::FromShape(else_->shape()), false); + OP_REQUIRES(ctx, cond_bcast.IsValid() && then_bcast.IsValid() && + else_bcast.IsValid(), + errors::InvalidArgument( + "condition ", cond->shape().DebugString(), ", then ", + then->shape().DebugString(), ", and else ", + else_->shape().DebugString(), " must be broadcastable")); + + // Combined shape should be the final shape. + OP_REQUIRES( + ctx, cond_bcast.output_shape() == bcast.output_shape() && + then_bcast.output_shape() == bcast.output_shape() && + else_bcast.output_shape() == bcast.output_shape(), + errors::InvalidArgument("condition ", cond->shape().DebugString(), + ", then ", then->shape().DebugString(), + ", and else ", else_->shape().DebugString(), + " must be broadcastable to the same shape")); + + Tensor* output = nullptr; + const TensorShape output_shape = BCast::ToShape(bcast.output_shape()); + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {"t", "e"}, "output", output_shape, &output)); + + if (output->NumElements() == 0) { + return; + } + +#define HANDLE_DIM(NDIMS) \ + { \ + functor::BCastSelectFunctor func; \ + func(ctx->eigen_device(), \ + output->shaped(bcast.result_shape()), \ + cond->template shaped(cond_bcast.y_reshape()), \ + then->template shaped(then_bcast.y_reshape()), \ + else_->template shaped(else_bcast.y_reshape()), \ + BCast::ToIndexArray(cond_bcast.y_bcast()), \ + BCast::ToIndexArray(then_bcast.y_bcast()), \ + BCast::ToIndexArray(else_bcast.y_bcast())); \ + } + + const int ndims = static_cast(bcast.result_shape().size()); + switch (ndims) { + case 1: + HANDLE_DIM(1); + break; + case 2: + HANDLE_DIM(2); + break; + case 3: + HANDLE_DIM(3); + break; + case 4: + HANDLE_DIM(4); + break; + case 5: + HANDLE_DIM(5); + break; + default: + ctx->SetStatus(errors::Unimplemented( + "Broadcast between ", ctx->input(0).shape().DebugString(), " and ", + ctx->input(1).shape().DebugString(), " is not supported yet.")); + break; + } + return; + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op); +}; + +#define REGISTER_SELECT(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Select").Device(DEVICE_CPU).TypeConstraint("T"), \ + SelectOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SelectV2").Device(DEVICE_CPU).TypeConstraint("T"), \ + SelectV2Op); TF_CALL_ALL_TYPES(REGISTER_SELECT); #if GOOGLE_CUDA // Registration of the GPU implementations. -#define REGISTER_SELECT_GPU(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Select").Device(DEVICE_GPU).TypeConstraint("T"), \ - SelectOp); +#define REGISTER_SELECT_GPU(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Select").Device(DEVICE_GPU).TypeConstraint("T"), \ + SelectOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SelectV2").Device(DEVICE_GPU).TypeConstraint("T"), \ + SelectV2Op); REGISTER_SELECT_GPU(bool); REGISTER_SELECT_GPU(Eigen::half); @@ -174,9 +291,12 @@ REGISTER_SELECT_GPU(complex128); #ifdef TENSORFLOW_USE_SYCL // Registration of the SYCL implementations. -#define REGISTER_SELECT_SYCL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Select").Device(DEVICE_SYCL).TypeConstraint("T"), \ +#define REGISTER_SELECT_SYCL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Select").Device(DEVICE_SYCL).TypeConstraint("T"), \ + SelectOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SelectV2").Device(DEVICE_SYCL).TypeConstraint("T"), \ SelectOp); REGISTER_SELECT_SYCL(float); @@ -324,10 +444,43 @@ struct BatchSelectFunctor { } }; +template +struct BCastSelectFunctor { + void operator()(const CPUDevice& d, + typename TTypes::Tensor output_tensor, + typename TTypes::ConstTensor cond_tensor, + typename TTypes::ConstTensor then_tensor, + typename TTypes::ConstTensor else_tensor, + typename Eigen::array cond_bcast, + typename Eigen::array then_bcast, + typename Eigen::array else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); + } +}; + #ifdef TENSORFLOW_USE_SYCL template struct BatchSelectFunctor : BatchSelectFunctorBase {}; + +template +struct BCastSelectFunctor { + void operator()(const SYCLDevice& d, + typename TTypes::Tensor output_tensor, + typename TTypes::ConstTensor cond_tensor, + typename TTypes::ConstTensor then_tensor, + typename TTypes::ConstTensor else_tensor, + typename Eigen::array cond_bcast, + typename Eigen::array then_bcast, + typename Eigen::array else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); + } +}; + #endif // TENSORFLOW_USE_SYCL } // namespace functor From 5cb76656be44117be687f5f206dfe28b29598449 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 17:24:30 +0000 Subject: [PATCH 02/21] Update shape function for `tf.where` / `SelectOp` Signed-off-by: Yong Tang --- tensorflow/core/ops/math_ops.cc | 51 +++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 6f261dc1b18..8ec4665897b 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -758,6 +758,57 @@ REGISTER_OP("Select") return Status::OK(); }); +REGISTER_OP("SelectV2") + .Input("condition: bool") + .Input("t: T") + .Input("e: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + auto* handle_data_1 = c->input_handle_shapes_and_types(1); + auto* handle_data_2 = c->input_handle_shapes_and_types(2); + // Merge handle shape and dtype if applicable. + if (handle_data_1 != nullptr && handle_data_2 != nullptr) { + const auto size = handle_data_1->size(); + std::vector merged_handle_data(size); + if (size != handle_data_2->size()) { + return errors::InvalidArgument( + "Trying to merge handles pointing to different numbers of " + "tensors."); + } + + for (int i = 0; i < size; ++i) { + const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i]; + const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i]; + if (s1.dtype != s2.dtype) { + // TODO(apassos) resolve this in the manner of b/32476923 + return errors::InvalidArgument( + "Trying to merge handles pointing to different dtypes."); + } + merged_handle_data[i].dtype = s1.dtype; + TF_RETURN_IF_ERROR( + c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape)); + } + + c->set_output_handle_shapes_and_types(0, merged_handle_data); + } + + // The inputs 'cond', 'then', and 'else' must be broadcastable. + // TODO (yongtang): Consolidate 3-ary broadcast instead of + // multiple 2-ary broadcast. + ShapeHandle cond = c->input(0); + ShapeHandle then = c->input(1); + ShapeHandle else_ = c->input(2); + ShapeHandle other; + TF_RETURN_IF_ERROR( + BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other)); + ShapeHandle output; + TF_RETURN_IF_ERROR( + BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output)); + c->set_output(0, output); + return Status::OK(); + }); + // -------------------------------------------------------------------------- REGISTER_OP("MatMul") From 5ce52a0ea1cc3d237b6c133efc07e4b15d4531b3 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 17:26:35 +0000 Subject: [PATCH 03/21] Add template for BCastSelectFunctor Signed-off-by: Yong Tang --- tensorflow/core/kernels/cwise_ops.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 66ba827a901..ea22f18e48f 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -1078,6 +1078,18 @@ struct BatchSelectFunctor { typename TTypes::ConstMatrix else_flat_outer_dims); }; +template +struct BCastSelectFunctor { + void operator()(const Device& d, + typename TTypes::Tensor output_tensor, + typename TTypes::ConstTensor cond_tensor, + typename TTypes::ConstTensor then_tensor, + typename TTypes::ConstTensor else_tensor, + typename Eigen::array cond_bcast, + typename Eigen::array then_bcast, + typename Eigen::array else_bcast); +}; + } // end namespace functor } // end namespace tensorflow From 94b1e3e874847fe85141df62b59cae5637e0c17a Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 17:26:59 +0000 Subject: [PATCH 04/21] Add GPU support for where_v2 Signed-off-by: Yong Tang --- .../core/kernels/cwise_op_gpu_select.cu.cc | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc index 303d8e47913..b371b468cdf 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc @@ -23,6 +23,21 @@ limitations under the License. namespace tensorflow { namespace functor { +template +struct BCastSelectFunctor { + void operator()(const GPUDevice& d, + typename TTypes::Tensor output_tensor, + typename TTypes::ConstTensor cond_tensor, + typename TTypes::ConstTensor then_tensor, + typename TTypes::ConstTensor else_tensor, + typename Eigen::array cond_bcast, + typename Eigen::array then_bcast, + typename Eigen::array else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast).select( + then_tensor.broadcast(then_bcast), else_tensor.broadcast(else_bcast)); + } +}; + template struct SelectFunctor { void operator()(const GPUDevice& d, typename TTypes::Flat out, @@ -89,10 +104,15 @@ struct BatchSelectFunctor { } }; -#define SELECT_FUNCTOR(T) \ - template struct SelectFunctor; \ - template struct SelectScalarFunctor; \ - template struct BatchSelectFunctor; +#define SELECT_FUNCTOR(T) \ + template struct SelectFunctor; \ + template struct SelectScalarFunctor; \ + template struct BatchSelectFunctor; \ + template struct BCastSelectFunctor; \ + template struct BCastSelectFunctor; \ + template struct BCastSelectFunctor; \ + template struct BCastSelectFunctor; \ + template struct BCastSelectFunctor; SELECT_FUNCTOR(bool); SELECT_FUNCTOR(Eigen::half); From 2250c49e12e10b646badcb47b9c3066949749261 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 17:27:20 +0000 Subject: [PATCH 05/21] Add test case for broadcasting support of `where_v2` Signed-off-by: Yong Tang --- .../python/kernel_tests/where_op_test.py | 136 +++++++++++++++--- 1 file changed, 115 insertions(+), 21 deletions(-) diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py index fca45c3ece4..fd0a719ad00 100644 --- a/tensorflow/python/kernel_tests/where_op_test.py +++ b/tensorflow/python/kernel_tests/where_op_test.py @@ -36,9 +36,9 @@ from tensorflow.python.platform import test class WhereOpTest(test.TestCase): - def _testWhere(self, x, truth, expected_err_re=None): + def _testWhere(self, x, truth, expected_err_re=None, fn=array_ops.where): with self.cached_session(use_gpu=True): - ans = array_ops.where(x) + ans = fn(x) self.assertEqual([None, x.ndim], ans.get_shape().as_list()) if expected_err_re is None: tf_ans = ans.eval() @@ -47,40 +47,40 @@ class WhereOpTest(test.TestCase): with self.assertRaisesOpError(expected_err_re): ans.eval() - def testWrongNumbers(self): + def _testWrongNumbers(self, fn=array_ops.where): with self.session(use_gpu=True): with self.assertRaises(ValueError): - array_ops.where([False, True], [1, 2], None) + fn([False, True], [1, 2], None) with self.assertRaises(ValueError): - array_ops.where([False, True], None, [1, 2]) + fn([False, True], None, [1, 2]) - def testBasicVec(self): + def _testBasicVec(self, fn=array_ops.where): x = np.asarray([True, False]) truth = np.asarray([[0]], dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) x = np.asarray([False, True, False]) truth = np.asarray([[1]], dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) x = np.asarray([False, False, True, False, True]) truth = np.asarray([[2], [4]], dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) - def testRandomVec(self): + def _testRandomVec(self, fn=array_ops.where): x = np.random.rand(1000000) > 0.5 truth = np.vstack([np.where(x)[0].astype(np.int64)]).T - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) - def testBasicMat(self): + def _testBasicMat(self, fn=array_ops.where): x = np.asarray([[True, False], [True, False]]) # Ensure RowMajor mode truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) - def testBasic3Tensor(self): + def _testBasic3Tensor(self, fn=array_ops.where): x = np.asarray([[[True, False], [True, False]], [[False, True], [False, True]], [[False, False], [False, True]]]) @@ -89,15 +89,37 @@ class WhereOpTest(test.TestCase): truth = np.asarray( [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]], dtype=np.int64) - self._testWhere(x, truth) + self._testWhere(x, truth, None, fn) - def _testRandom(self, dtype, expected_err_re=None): + def _testRandom(self, dtype, expected_err_re=None, fn=array_ops.where): shape = [127, 33, 53] x = np.random.randn(*shape) + 1j * np.random.randn(*shape) x = (np.random.randn(*shape) > 0).astype(dtype) truth = np.where(np.abs(x) > 0) # Tuples of indices by axis. truth = np.vstack(truth).T # Convert to [num_true, indices]. - self._testWhere(x, truth, expected_err_re) + self._testWhere(x, truth, expected_err_re, fn) + + def _testThreeArgument(self, fn=array_ops.where): + x = np.array([[-2, 3, -1], [1, -3, -3]]) + np_val = np.where(x > 0, x * x, -x) + with self.test_session(use_gpu=True): + tf_val = fn(constant_op.constant(x) > 0, x * x, -x).eval() + self.assertAllEqual(tf_val, np_val) + + def testWrongNumbers(self): + self._testWrongNumbers() + + def testBasicVec(self): + self._testBasicVec() + + def testRandomVec(self): + self._testRandomVec() + + def testBasicMat(self): + self._testBasicMat() + + def testBasic3Tensor(self): + self._testBasic3Tensor() def testRandomBool(self): self._testRandom(np.bool) @@ -130,10 +152,82 @@ class WhereOpTest(test.TestCase): self._testRandom(np.int16) def testThreeArgument(self): - x = np.array([[-2, 3, -1], [1, -3, -3]]) - np_val = np.where(x > 0, x * x, -x) - with self.session(use_gpu=True): - tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval() + self._testThreeArgument() + + def testV2WrongNumbers(self): + self._testWrongNumbers(array_ops.where_v2) + + def testV2BasicVec(self): + self._testBasicVec(array_ops.where_v2) + + def testV2RandomVec(self): + self._testRandomVec(array_ops.where_v2) + + def testV2BasicMat(self): + self._testBasicMat(array_ops.where_v2) + + def testV2Basic3Tensor(self): + self._testBasic3Tensor(array_ops.where_v2) + + def testV2RandomBool(self): + self._testRandom(np.bool, None, array_ops.where_v2) + + def testV2RandomInt32(self): + self._testRandom(np.int32, None, array_ops.where_v2) + + def testV2RandomInt64(self): + self._testRandom(np.int64, None, array_ops.where_v2) + + def testV2RandomFloat(self): + self._testRandom(np.float32, None, array_ops.where_v2) + + def testV2RandomDouble(self): + self._testRandom(np.float64, None, array_ops.where_v2) + + def testV2RandomComplex64(self): + self._testRandom(np.complex64, None, array_ops.where_v2) + + def testV2RandomComplex128(self): + self._testRandom(np.complex128, None, array_ops.where_v2) + + def testV2RandomUint8(self): + self._testRandom(np.uint8, None, array_ops.where_v2) + + def testV2RandomInt8(self): + self._testRandom(np.int8, None, array_ops.where_v2) + + def testV2RandomInt16(self): + self._testRandom(np.int16, None, array_ops.where_v2) + + def testV2ThreeArgument(self): + self._testThreeArgument(array_ops.where_v2) + + def testV2Broadcasting(self): + f = np.random.normal(0, 1, (3, 5, 1, 1)) + x = np.zeros((7, 11)) + y = np.ones((7, 11)) + np_val = np.where(f < 0, x, y) + with self.test_session(use_gpu=True): + tf_val = self.evaluate( + array_ops.where_v2(constant_op.constant(f) < 0, x, y)) + self.assertAllEqual(tf_val, np_val) + + def testV2ScalarBroadcasting(self): + x = np.zeros((7, 11)) + y = np.ones((7, 11)) + np_val = np.where(True, x, y) + with self.test_session(use_gpu=True): + tf_val = self.evaluate(array_ops.where_v2( + constant_op.constant(True, dtype=dtypes.bool), x, y)) + self.assertAllEqual(tf_val, np_val) + + def testV2VectorBroadcasting(self): + x = np.zeros(7) + y = np.ones(7) + np_val = np.where([True], x, y) + with self.test_session(use_gpu=True): + tf_val = self.evaluate(array_ops.where_v2( + constant_op.constant([True], dtype=dtypes.bool), x, y)) self.assertAllEqual(tf_val, np_val) def testBatchSelect(self): From 399d5c20b3bbd1aefafbb3ce70669481ed6732d2 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 17:28:54 +0000 Subject: [PATCH 06/21] Define where_v2 in array_ops.py Signed-off-by: Yong Tang --- tensorflow/python/ops/array_ops.py | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 6fdc50733a1..288d930289e 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2631,6 +2631,45 @@ def where(condition, x=None, y=None, name=None): raise ValueError("x and y must both be non-None or both be None.") +def where_v2(condition, x=None, y=None, name=None): + """Return the elements, either from `x` or `y`, depending on the `condition`. + If both `x` and `y` are None, then this operation returns the coordinates of + true elements of `condition`. The coordinates are returned in a 2-D tensor + where the first dimension (rows) represents the number of true elements, and + the second dimension (columns) represents the coordinates of the true + elements. Keep in mind, the shape of the output tensor can vary depending on + how many true values there are in input. Indices are output in row-major + order. + If both non-None, `condition`, `x` and `y` must be broadcastable to the same + shape. + The `condition` tensor acts as a mask that chooses, based on the value at each + element, whether the corresponding element / row in the output should be taken + from `x` (if true) or `y` (if false). + Args: + condition: A `Tensor` of type `bool` + x: A Tensor which is of the same type as `y`, and may be + broadcastable with `condition` and `y`. + y: A Tensor which is of the same type as `x`, and may be + broadcastable with `condition` and `x`. + name: A name of the operation (optional). + Returns: + A `Tensor` with the same type as `x` and `y`, and shape that + is broadcasted from `condition`, `x`, and `y`, if `x`, `y` are non-None. + A `Tensor` with shape `(num_true, dim_size(condition))`. + Raises: + ValueError: When exactly one of `x` or `y` is non-None. + """ + if x is None and y is None: + with ops.name_scope(name, "Where", [condition]) as name: + condition = ops.convert_to_tensor( + condition, preferred_dtype=dtypes.bool, name="condition") + return gen_array_ops.where(condition=condition, name=name) + elif x is not None and y is not None: + return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name) + else: + raise ValueError("x and y must both be non-None or both be None.") + + # pylint: disable=redefined-builtin @tf_export("reverse_sequence") @deprecation.deprecated_args( From 75b4505d4239c65f3fd64d1ba26027cd40bc4a41 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 17:29:12 +0000 Subject: [PATCH 07/21] Add broadcasting support for `tf.where` Signed-off-by: Yong Tang --- tensorflow/contrib/framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index e72e50585a3..d457c170740 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -126,7 +126,7 @@ from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['nest'] +_allowed_symbols = ['nest', 'broadcast_to', 'where'] _nest_allowed_symbols = [ 'assert_same_structure', 'is_sequence', From 0ec97a70ef809d36a3e0385245f1d8b6165bd656 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 17:40:46 +0000 Subject: [PATCH 08/21] Fix `Experimental clang-format Check` Signed-off-by: Yong Tang --- .../core/kernels/cwise_op_gpu_select.cu.cc | 5 ++-- tensorflow/core/kernels/cwise_op_select.cc | 28 +++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc index b371b468cdf..ba872db2172 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc @@ -33,8 +33,9 @@ struct BCastSelectFunctor { typename Eigen::array cond_bcast, typename Eigen::array then_bcast, typename Eigen::array else_bcast) { - output_tensor.device(d) = cond_tensor.broadcast(cond_bcast).select( - then_tensor.broadcast(then_bcast), else_tensor.broadcast(else_bcast)); + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); } }; diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 54159b88d6c..dd6a9a17392 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -255,26 +255,26 @@ class SelectV2Op : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op); }; -#define REGISTER_SELECT(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Select").Device(DEVICE_CPU).TypeConstraint("T"), \ - SelectOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("SelectV2").Device(DEVICE_CPU).TypeConstraint("T"), \ - SelectV2Op); +#define REGISTER_SELECT(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Select").Device(DEVICE_CPU).TypeConstraint("T"), \ + SelectOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SelectV2").Device(DEVICE_CPU).TypeConstraint("T"), \ + SelectV2Op); TF_CALL_ALL_TYPES(REGISTER_SELECT); #if GOOGLE_CUDA // Registration of the GPU implementations. -#define REGISTER_SELECT_GPU(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Select").Device(DEVICE_GPU).TypeConstraint("T"), \ - SelectOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("SelectV2").Device(DEVICE_GPU).TypeConstraint("T"), \ - SelectV2Op); +#define REGISTER_SELECT_GPU(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Select").Device(DEVICE_GPU).TypeConstraint("T"), \ + SelectOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("SelectV2").Device(DEVICE_GPU).TypeConstraint("T"), \ + SelectV2Op); REGISTER_SELECT_GPU(bool); REGISTER_SELECT_GPU(Eigen::half); From e5edd77388b24f8ecf3ed09829b33324670523ab Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 17:57:37 +0000 Subject: [PATCH 09/21] Update api compatibility test Signed-off-by: Yong Tang --- tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 9597dd7684e..dc0be6d24ac 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1872,6 +1872,10 @@ tf_module { name: "segment_sum" argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "select_v2" + argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "self_adjoint_eig" argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " From ab4f60df26bba9163d7cb1b99fb37ab71a09b70c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 23 Sep 2018 18:37:55 +0000 Subject: [PATCH 10/21] Update api_def Signed-off-by: Yong Tang --- tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt new file mode 100644 index 00000000000..e567206d913 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "SelectV2" +} From 7a9fd72092d1d503c677f0492a08056f6f2ea84d Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 8 Oct 2018 15:47:55 +0000 Subject: [PATCH 11/21] Add additional test case for tf.where_v2, based on review feedback Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/where_op_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py index fd0a719ad00..0e9a40ac6bf 100644 --- a/tensorflow/python/kernel_tests/where_op_test.py +++ b/tensorflow/python/kernel_tests/where_op_test.py @@ -230,6 +230,15 @@ class WhereOpTest(test.TestCase): constant_op.constant([True], dtype=dtypes.bool), x, y)) self.assertAllEqual(tf_val, np_val) + def testV2PredBroadcasting(self): + pred = np.array([1, 0, 0]).reshape((3, 1)) + x = np.random.randn(3, 4) + y = np.random.randn(3, 4) + np_val = np.where(pred, x, y) + with self.test_session(use_gpu=True): + tf_val = self.evaluate(array_ops.where_v2(pred, x, y)) + self.assertAllClose(tf_val, np_val) + def testBatchSelect(self): x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192) # [16384, 192] c_mat = np.array([[False] * 192, [True] * 192] * 8192) # [16384, 192] From fec417da16cae270528fdf895c8b9e11fd72934a Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 6 Nov 2018 18:24:46 +0000 Subject: [PATCH 12/21] Fix broken tests Signed-off-by: Yong Tang --- tensorflow/tools/api/golden/v2/tensorflow.pbtxt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 7c865bb0022..43bcbcd312b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -1188,6 +1188,10 @@ tf_module { name: "searchsorted" argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"\", \'None\'], " } + member_method { + name: "select_v2" + argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "sequence_mask" argspec: "args=[\'lengths\', \'maxlen\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\'], " From e1f2b7c7252c8bb3231f724cfdf794a26caedd1c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 6 Dec 2018 17:00:15 +0000 Subject: [PATCH 13/21] Expose where_v2 as v1=["where_v2"], v2=["where"] Signed-off-by: Yong Tang --- tensorflow/python/ops/array_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index fd327e049c7..1e7b575a3aa 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -3177,6 +3177,7 @@ def where(condition, x=None, y=None, name=None): raise ValueError("x and y must both be non-None or both be None.") +@tf_export(v1=["where_v2"], v2=["where"]) def where_v2(condition, x=None, y=None, name=None): """Return the elements, either from `x` or `y`, depending on the `condition`. If both `x` and `y` are None, then this operation returns the coordinates of From ad970c30b35ce8cc97699e5c8e95db4df6f5cb6b Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 6 Dec 2018 19:43:47 +0000 Subject: [PATCH 14/21] Add deprecation to tf.where Signed-off-by: Yong Tang --- tensorflow/python/ops/array_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 1e7b575a3aa..259d108dacd 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -3124,6 +3124,10 @@ def squeeze_v2(input, axis=None, name=None): @tf_export("where") +@deprecation.deprecated( + date=None, + instructions="Use tf.where in 2.0, " + "which has the same broadcast rule as np.where") @dispatch.add_dispatch_support def where(condition, x=None, y=None, name=None): """Return the elements, either from `x` or `y`, depending on the `condition`. From 4059b951a0fd3136d238e57c52a68fbe09b9df0b Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 6 Dec 2018 19:59:08 +0000 Subject: [PATCH 15/21] Hide select_v2 API Signed-off-by: Yong Tang --- tensorflow/core/api_def/python_api/api_def_SelectV2.pbtxt | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tensorflow/core/api_def/python_api/api_def_SelectV2.pbtxt diff --git a/tensorflow/core/api_def/python_api/api_def_SelectV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_SelectV2.pbtxt new file mode 100644 index 00000000000..bf57ed7164d --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SelectV2.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SelectV2" + visibility: HIDDEN +} From bc3538f7e22c6ab3ca1615888907a7cfa8f34d6b Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 6 Dec 2018 19:59:22 +0000 Subject: [PATCH 16/21] Update api goldens Signed-off-by: Yong Tang --- tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 8 ++++---- tensorflow/tools/api/golden/v2/tensorflow.pbtxt | 4 ---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index eabf0a41acf..437a3b3f88d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1912,10 +1912,6 @@ tf_module { name: "segment_sum" argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "select_v2" - argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "self_adjoint_eig" argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -2340,6 +2336,10 @@ tf_module { name: "where" argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } + member_method { + name: "where_v2" + argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } member_method { name: "while_loop" argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 847c6569df1..7b1c96c2e87 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -892,10 +892,6 @@ tf_module { name: "searchsorted" argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"\", \'None\'], " } - member_method { - name: "select_v2" - argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "sequence_mask" argspec: "args=[\'lengths\', \'maxlen\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\'], " From 8d198f17e65ac59e3545d5d8eb23797ad2b33813 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 7 Dec 2018 20:02:50 +0000 Subject: [PATCH 17/21] Update API to use @tf_export(v1=["where"]) for legacy `where`, and @tf_export("where", v1=["where_v2"]) for new `where` Signed-off-by: Yong Tang --- tensorflow/python/ops/array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 259d108dacd..413e4d5a653 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -3123,7 +3123,7 @@ def squeeze_v2(input, axis=None, name=None): return squeeze(input, axis, name) -@tf_export("where") +@tf_export(v1=["where"]) @deprecation.deprecated( date=None, instructions="Use tf.where in 2.0, " @@ -3181,7 +3181,7 @@ def where(condition, x=None, y=None, name=None): raise ValueError("x and y must both be non-None or both be None.") -@tf_export(v1=["where_v2"], v2=["where"]) +@tf_export("where", v1=["where_v2"]) def where_v2(condition, x=None, y=None, name=None): """Return the elements, either from `x` or `y`, depending on the `condition`. If both `x` and `y` are None, then this operation returns the coordinates of From 70e39f214e48e01b32493440f220290a91979a18 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 30 Mar 2019 00:24:09 +0000 Subject: [PATCH 18/21] Update api compat and tf_upgrade_v2 Signed-off-by: Yong Tang --- tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt | 4 ++++ tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt | 4 ++++ tensorflow/tools/compatibility/renames_v2.py | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index df784767755..14810152a2f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -3268,6 +3268,10 @@ tf_module { name: "Select" argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "SelectV2" + argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SelfAdjointEig" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index df784767755..14810152a2f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -3268,6 +3268,10 @@ tf_module { name: "Select" argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "SelectV2" + argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SelfAdjointEig" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index f9ff6e2f668..6e44ed07cd2 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -1515,6 +1515,10 @@ renames = { 'tf.compat.v1.variance_scaling_initializer', 'tf.verify_tensor_all_finite': 'tf.compat.v1.verify_tensor_all_finite', + 'tf.where': + 'tf.compat.v1.where', + 'tf.where_v2': + 'tf.where', 'tf.wrap_function': 'tf.compat.v1.wrap_function', 'tf.write_file': From d9b98c4d16d4f4ef4165f15605f4c4d8b30c43bf Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 1 Apr 2019 01:08:53 +0000 Subject: [PATCH 19/21] Rename tf.where_v2 to tf.compat.v2.where, as rename script has to work in v1 and v2 so explicitly specifying tf.compat.v2.where (from review comment). Signed-off-by: Yong Tang --- tensorflow/tools/compatibility/renames_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index 6e44ed07cd2..aab14379ed7 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -1518,7 +1518,7 @@ renames = { 'tf.where': 'tf.compat.v1.where', 'tf.where_v2': - 'tf.where', + 'tf.compat.v2.where', 'tf.wrap_function': 'tf.compat.v1.wrap_function', 'tf.write_file': From e75409c2fe38b85ab858a60f470430eb76eb8556 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 2 May 2019 21:43:46 +0000 Subject: [PATCH 20/21] Rename elem_bcast to then_else_bcast, and remove duplicate template specification based on review comment. Signed-off-by: Yong Tang --- tensorflow/core/kernels/cwise_op_select.cc | 48 ++-------------------- tensorflow/core/kernels/cwise_ops.h | 6 ++- 2 files changed, 9 insertions(+), 45 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 749e7ea792b..e6eb2d9c6d4 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -162,15 +162,15 @@ class SelectV2Op : public OpKernel { // 2-ary broadcast. // Combine `then` and `else`. - BCast elem_bcast(BCast::FromShape(then->shape()), - BCast::FromShape(else_->shape()), false); - OP_REQUIRES(ctx, elem_bcast.IsValid(), + BCast then_else_bcast(BCast::FromShape(then->shape()), + BCast::FromShape(else_->shape()), false); + OP_REQUIRES(ctx, then_else_bcast.IsValid(), errors::InvalidArgument( "then ", then->shape().DebugString(), " and else ", else_->shape().DebugString(), " must be broadcastable")); // Combine `cond` with `then` and `else`. BCast bcast(BCast::FromShape(cond->shape()), - BCast::FromShape(BCast::ToShape(elem_bcast.output_shape())), + BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())), false); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( @@ -443,46 +443,6 @@ struct BatchSelectFunctor { d.parallelFor(batch, cost, work); } }; - -template -struct BCastSelectFunctor { - void operator()(const CPUDevice& d, - typename TTypes::Tensor output_tensor, - typename TTypes::ConstTensor cond_tensor, - typename TTypes::ConstTensor then_tensor, - typename TTypes::ConstTensor else_tensor, - typename Eigen::array cond_bcast, - typename Eigen::array then_bcast, - typename Eigen::array else_bcast) { - output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) - .select(then_tensor.broadcast(then_bcast), - else_tensor.broadcast(else_bcast)); - } -}; - -#ifdef TENSORFLOW_USE_SYCL -template -struct BatchSelectFunctor - : BatchSelectFunctorBase {}; - -template -struct BCastSelectFunctor { - void operator()(const SYCLDevice& d, - typename TTypes::Tensor output_tensor, - typename TTypes::ConstTensor cond_tensor, - typename TTypes::ConstTensor then_tensor, - typename TTypes::ConstTensor else_tensor, - typename Eigen::array cond_bcast, - typename Eigen::array then_bcast, - typename Eigen::array else_bcast) { - output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) - .select(then_tensor.broadcast(then_bcast), - else_tensor.broadcast(else_bcast)); - } -}; - -#endif // TENSORFLOW_USE_SYCL - } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index cc8eba69555..610ff07a4dc 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -1217,7 +1217,11 @@ struct BCastSelectFunctor { typename TTypes::ConstTensor else_tensor, typename Eigen::array cond_bcast, typename Eigen::array then_bcast, - typename Eigen::array else_bcast); + typename Eigen::array else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); + } }; } // end namespace functor From 33cd7b88acefb737e45634b38c26c55783deafdf Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 2 May 2019 23:39:59 +0000 Subject: [PATCH 21/21] Fix GPU build failure Signed-off-by: Yong Tang --- tensorflow/core/kernels/cwise_op_select.cc | 32 ++++++++++++++++++++++ tensorflow/core/kernels/cwise_ops.h | 6 +--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index e6eb2d9c6d4..402d24d5b5b 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -443,6 +443,38 @@ struct BatchSelectFunctor { d.parallelFor(batch, cost, work); } }; + +template +struct BCastSelectFunctorBase { + void operator()(const Device& d, + typename TTypes::Tensor output_tensor, + typename TTypes::ConstTensor cond_tensor, + typename TTypes::ConstTensor then_tensor, + typename TTypes::ConstTensor else_tensor, + typename Eigen::array cond_bcast, + typename Eigen::array then_bcast, + typename Eigen::array else_bcast) { + output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) + .select(then_tensor.broadcast(then_bcast), + else_tensor.broadcast(else_bcast)); + } +}; + +template +struct BCastSelectFunctor + : BCastSelectFunctorBase {}; + +#ifdef TENSORFLOW_USE_SYCL +template +struct BatchSelectFunctor + : BatchSelectFunctorBase {}; + +template +struct BCastSelectFunctor + : BCastSelectFunctorBase {}; + +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 610ff07a4dc..cc8eba69555 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -1217,11 +1217,7 @@ struct BCastSelectFunctor { typename TTypes::ConstTensor else_tensor, typename Eigen::array cond_bcast, typename Eigen::array then_bcast, - typename Eigen::array else_bcast) { - output_tensor.device(d) = cond_tensor.broadcast(cond_bcast) - .select(then_tensor.broadcast(then_bcast), - else_tensor.broadcast(else_bcast)); - } + typename Eigen::array else_bcast); }; } // end namespace functor