From e87e48dd8c5b84318122b5aa6d98945041c574d4 Mon Sep 17 00:00:00 2001
From: Yong Tang <yong.tang.github@outlook.com>
Date: Fri, 10 May 2019 00:41:06 +0000
Subject: [PATCH] Add broadcasting support for `tf.where`

This is a rework of #15982. It adds a new `SelectV2` op (exposed as
`tf.where` in TF 2.0 and as `tf.where_v2` in TF 1.x) whose `condition`,
`t`, and `e` inputs are broadcast to a common shape, matching the
behavior of `np.where`.

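A rough sketch of the intended Python-level behavior (shapes are only
illustrative, shown with the 1.x endpoint `tf.where_v2` added here):

    import tensorflow as tf

    cond = tf.constant([[True], [False], [True]])  # shape (3, 1)
    x = tf.constant([[1, 2], [3, 4], [5, 6]])      # shape (3, 2)
    y = tf.constant([100, 200])                    # shape (2,)

    # All three inputs broadcast to (3, 2), mirroring
    # np.where(cond, x, y); the existing Select op rejects these shapes.
    out = tf.where_v2(cond, x, y)
    # expected value: [[1, 2], [100, 200], [5, 6]]
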
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
---
 .../core/kernels/cwise_op_gpu_select.cu.cc    |  29 ++-
 tensorflow/core/kernels/cwise_op_select.cc    | 170 ++++++++++++++++--
 tensorflow/core/kernels/cwise_ops.h           |  12 ++
 tensorflow/core/ops/math_ops.cc               |  51 ++++++
 .../python/kernel_tests/where_op_test.py      | 157 +++++++++++++---
 tensorflow/python/ops/array_ops.py            |  48 ++++-
 6 files changed, 425 insertions(+), 42 deletions(-)

diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
index 303d8e47913..ba872db2172 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
@@ -23,6 +23,22 @@ limitations under the License.
 namespace tensorflow {
 namespace functor {
 
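+// GPU partial specialization of BCastSelectFunctor: broadcast `cond`,
+// `then`, and `else` to the output shape, then apply Eigen's select().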
+template <typename T, int NDIMS>
+struct BCastSelectFunctor<GPUDevice, T, NDIMS> {
+  void operator()(const GPUDevice& d,
+                  typename TTypes<T, NDIMS>::Tensor output_tensor,
+                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
+    output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
+                                  .select(then_tensor.broadcast(then_bcast),
+                                          else_tensor.broadcast(else_bcast));
+  }
+};
+
 template <typename T>
 struct SelectFunctor<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
@@ -89,10 +105,15 @@ struct BatchSelectFunctor<GPUDevice, T> {
   }
 };
 
-#define SELECT_FUNCTOR(T)                            \
-  template struct SelectFunctor<GPUDevice, T>;       \
-  template struct SelectScalarFunctor<GPUDevice, T>; \
-  template struct BatchSelectFunctor<GPUDevice, T>;
+#define SELECT_FUNCTOR(T)                              \
+  template struct SelectFunctor<GPUDevice, T>;         \
+  template struct SelectScalarFunctor<GPUDevice, T>;   \
+  template struct BatchSelectFunctor<GPUDevice, T>;    \
+  template struct BCastSelectFunctor<GPUDevice, T, 1>; \
+  template struct BCastSelectFunctor<GPUDevice, T, 2>; \
+  template struct BCastSelectFunctor<GPUDevice, T, 3>; \
+  template struct BCastSelectFunctor<GPUDevice, T, 4>; \
+  template struct BCastSelectFunctor<GPUDevice, T, 5>;
 
 SELECT_FUNCTOR(bool);
 SELECT_FUNCTOR(Eigen::half);
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 3b51563ca28..c85c9d0599f 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -143,21 +143,141 @@ class SelectOp : public OpKernel {
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
 };
+
+template <typename Device, typename T>
+class SelectV2Op : public OpKernel {
+ public:
+  explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}
 
-#define REGISTER_SELECT(type)                                      \
-  REGISTER_KERNEL_BUILDER(                                         \
-      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
-      SelectOp<CPUDevice, type>);
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* cond;
+    const Tensor* then;
+    const Tensor* else_;
+    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
+    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
+    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+
+    // `cond`, `then`, and `else` must be broadcastable (bcast.IsValid());
+    // this matches the behavior of numpy.
+    // TODO(yongtang): Consolidate into a single n-ary broadcast instead of
+    // multiple 2-ary broadcasts.
+
+    // Combine `then` and `else`.
+    BCast then_else_bcast(BCast::FromShape(then->shape()),
+                          BCast::FromShape(else_->shape()), false);
+    OP_REQUIRES(ctx, then_else_bcast.IsValid(),
+                errors::InvalidArgument(
+                    "then ", then->shape().DebugString(), " and else ",
+                    else_->shape().DebugString(), " must be broadcastable"));
+    // Combine `cond` with `then` and `else`.
+    BCast bcast(
+        BCast::FromShape(cond->shape()),
+        BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())),
+        false);
+    OP_REQUIRES(ctx, bcast.IsValid(),
+                errors::InvalidArgument(
+                    "condition ", cond->shape().DebugString(), ", then ",
+                    then->shape().DebugString(), ", and else ",
+                    else_->shape().DebugString(), " must be broadcastable"));
+
+    // Broadcast `cond`, `then`, and `else` against the combined shape to
+    // obtain the per-input reshapes and broadcast multipliers.
+    BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
+                     BCast::FromShape(cond->shape()), false);
+    BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
+                     BCast::FromShape(then->shape()), false);
+    BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
+                     BCast::FromShape(else_->shape()), false);
+    OP_REQUIRES(
+        ctx,
+        cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(),
+        errors::InvalidArgument("condition ", cond->shape().DebugString(),
+                                ", then ", then->shape().DebugString(),
+                                ", and else ", else_->shape().DebugString(),
+                                " must be broadcastable"));
+
+    // Combined shape should be the final shape.
+    OP_REQUIRES(
+        ctx,
+        cond_bcast.output_shape() == bcast.output_shape() &&
+            then_bcast.output_shape() == bcast.output_shape() &&
+            else_bcast.output_shape() == bcast.output_shape(),
+        errors::InvalidArgument("condition ", cond->shape().DebugString(),
+                                ", then ", then->shape().DebugString(),
+                                ", and else ", else_->shape().DebugString(),
+                                " must be broadcastable to the same shape"));
+
+    Tensor* output = nullptr;
+    const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
+    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
+                            {"t", "e"}, "output", output_shape, &output));
+
+    if (output->NumElements() == 0) {
+      return;
+    }
+
+#define HANDLE_DIM(NDIMS)                                            \
+  {                                                                  \
+    functor::BCastSelectFunctor<Device, T, NDIMS> func;              \
+    func(ctx->eigen_device<Device>(),                                \
+         output->shaped<T, NDIMS>(bcast.result_shape()),             \
+         cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
+         then->template shaped<T, NDIMS>(then_bcast.y_reshape()),    \
+         else_->template shaped<T, NDIMS>(else_bcast.y_reshape()),   \
+         BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()),           \
+         BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()),           \
+         BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast()));          \
+  }
+
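+    // Eigen broadcasting needs the rank at compile time, so dispatch on the
+    // rank of the broadcasted output; ranks 1 through 5 are supported.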
+    const int ndims = static_cast<int>(bcast.result_shape().size());
+    switch (ndims) {
+      case 1:
+        HANDLE_DIM(1);
+        break;
+      case 2:
+        HANDLE_DIM(2);
+        break;
+      case 3:
+        HANDLE_DIM(3);
+        break;
+      case 4:
+        HANDLE_DIM(4);
+        break;
+      case 5:
+        HANDLE_DIM(5);
+        break;
+      default:
+        ctx->SetStatus(errors::Unimplemented(
+            "Broadcast between ", cond->shape().DebugString(), ", ",
+            then->shape().DebugString(), ", and ",
+            else_->shape().DebugString(), " is not supported yet."));
+        break;
+    }
+    return;
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
+};
+
+#define REGISTER_SELECT(type)                                        \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
+      SelectOp<CPUDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      SelectV2Op<CPUDevice, type>);
 
 TF_CALL_ALL_TYPES(REGISTER_SELECT);
 
 #if GOOGLE_CUDA
 
 // Registration of the GPU implementations.
-#define REGISTER_SELECT_GPU(type)                                  \
-  REGISTER_KERNEL_BUILDER(                                         \
-      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
-      SelectOp<GPUDevice, type>);
+#define REGISTER_SELECT_GPU(type)                                    \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
+      SelectOp<GPUDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      SelectV2Op<GPUDevice, type>);
 
 REGISTER_SELECT_GPU(bool);
 REGISTER_SELECT_GPU(Eigen::half);
@@ -174,9 +294,12 @@ REGISTER_SELECT_GPU(complex128);
 
 #ifdef TENSORFLOW_USE_SYCL
 // Registration of the SYCL implementations.
-#define REGISTER_SELECT_SYCL(type)                                  \
-  REGISTER_KERNEL_BUILDER(                                          \
-      Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+#define REGISTER_SELECT_SYCL(type)                                    \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"),   \
+      SelectOp<SYCLDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("SelectV2").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
-      SelectOp<SYCLDevice, type>);
+      SelectV2Op<SYCLDevice, type>);
 
 REGISTER_SELECT_SYCL(float);
@@ -324,10 +447,35 @@ struct BatchSelectFunctor<CPUDevice, T> {
   }
 };
 
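+// Shared CPU/SYCL implementation of BCastSelectFunctor: broadcast all three
+// inputs to the output shape and apply Eigen's select() on the device.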
+template <typename Device, typename T, int NDIMS>
+struct BCastSelectFunctorBase {
+  void operator()(const Device& d,
+                  typename TTypes<T, NDIMS>::Tensor output_tensor,
+                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
+    output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
+                                  .select(then_tensor.broadcast(then_bcast),
+                                          else_tensor.broadcast(else_bcast));
+  }
+};
+
+template <typename T, int NDIMS>
+struct BCastSelectFunctor<CPUDevice, T, NDIMS>
+    : BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};
+
 #ifdef TENSORFLOW_USE_SYCL
 template <typename T>
 struct BatchSelectFunctor<SYCLDevice, T>
     : BatchSelectFunctorBase<SYCLDevice, T> {};
+
+template <typename T, int NDIMS>
+struct BCastSelectFunctor<SYCLDevice, T, NDIMS>
+    : BCastSelectFunctorBase<SYCLDevice, T, NDIMS> {};
+
 #endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace functor
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 898147b3661..cc8eba69555 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -1208,6 +1208,18 @@ struct BatchSelectFunctor {
                   typename TTypes<T>::ConstMatrix else_flat_outer_dims);
 };
 
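+// Broadcasting select functor. `cond_bcast`, `then_bcast`, and `else_bcast`
+// hold the per-dimension broadcast multipliers that map each input to the
+// rank-NDIMS output shape.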
+template <typename Device, typename T, int NDIMS>
+struct BCastSelectFunctor {
+  void operator()(const Device& d,
+                  typename TTypes<T, NDIMS>::Tensor output_tensor,
+                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
+};
+
 }  // end namespace functor
 }  // end namespace tensorflow
 
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 29a93753e7d..3ff9bc09853 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -828,6 +828,57 @@ REGISTER_OP("Select")
       return Status::OK();
     });
 
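+// Like "Select", except that "condition", "t", and "e" are broadcast to a
+// common shape following numpy broadcasting rules.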
+REGISTER_OP("SelectV2")
+    .Input("condition: bool")
+    .Input("t: T")
+    .Input("e: T")
+    .Output("output: T")
+    .Attr("T: type")
+    .SetShapeFn([](InferenceContext* c) {
+      auto* handle_data_1 = c->input_handle_shapes_and_types(1);
+      auto* handle_data_2 = c->input_handle_shapes_and_types(2);
+      // Merge handle shape and dtype if applicable.
+      if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
+        const auto size = handle_data_1->size();
+        std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
+        if (size != handle_data_2->size()) {
+          return errors::InvalidArgument(
+              "Trying to merge handles pointing to different numbers of "
+              "tensors.");
+        }
+
+        for (int i = 0; i < size; ++i) {
+          const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
+          const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
+          if (s1.dtype != s2.dtype) {
+            // TODO(apassos) resolve this in the manner of b/32476923
+            return errors::InvalidArgument(
+                "Trying to merge handles pointing to different dtypes.");
+          }
+          merged_handle_data[i].dtype = s1.dtype;
+          TF_RETURN_IF_ERROR(
+              c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
+        }
+
+        c->set_output_handle_shapes_and_types(0, merged_handle_data);
+      }
+
+      // The inputs 'condition', 't', and 'e' must be broadcastable.
+      // TODO(yongtang): Consolidate into a single 3-ary broadcast instead of
+      // multiple 2-ary broadcasts.
+      ShapeHandle cond = c->input(0);
+      ShapeHandle then = c->input(1);
+      ShapeHandle else_ = c->input(2);
+      ShapeHandle other;
+      TF_RETURN_IF_ERROR(
+          BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other));
+      ShapeHandle output;
+      TF_RETURN_IF_ERROR(
+          BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output));
+      c->set_output(0, output);
+      return Status::OK();
+    });
+
 // --------------------------------------------------------------------------
 
 REGISTER_OP("MatMul")
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index 56c13904113..00d916c06eb 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -37,10 +37,10 @@ from tensorflow.python.platform import test
 
 class WhereOpTest(test.TestCase):
 
-  def _testWhere(self, x, truth, expected_err_re=None):
+  def _testWhere(self, x, truth, expected_err_re=None, fn=array_ops.where):
     with self.cached_session(use_gpu=True):
-      ans = array_ops.where(x)
-      self.assertEqual([None, x.ndim], ans.get_shape().as_list())
+      ans = fn(x)
+      self.assertTrue(ans.get_shape().is_compatible_with([None, x.ndim]))
       if expected_err_re is None:
         tf_ans = self.evaluate(ans)
         self.assertAllClose(tf_ans, truth, atol=1e-10)
@@ -48,44 +48,40 @@ class WhereOpTest(test.TestCase):
         with self.assertRaisesOpError(expected_err_re):
           self.evaluate(ans)
 
-  def testWrongNumbers(self):
+  def _testWrongNumbers(self, fn=array_ops.where):
     with self.session(use_gpu=True):
       with self.assertRaises(ValueError):
-        array_ops.where([False, True], [1, 2], None)
+        fn([False, True], [1, 2], None)
       with self.assertRaises(ValueError):
-        array_ops.where([False, True], None, [1, 2])
+        fn([False, True], None, [1, 2])
 
-  @test_util.run_deprecated_v1
-  def testBasicVec(self):
+  def _testBasicVec(self, fn=array_ops.where):
     x = np.asarray([True, False])
     truth = np.asarray([[0]], dtype=np.int64)
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
     x = np.asarray([False, True, False])
     truth = np.asarray([[1]], dtype=np.int64)
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
     x = np.asarray([False, False, True, False, True])
     truth = np.asarray([[2], [4]], dtype=np.int64)
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
-  @test_util.run_deprecated_v1
-  def testRandomVec(self):
+  def _testRandomVec(self, fn=array_ops.where):
     x = np.random.rand(1000000) > 0.5
     truth = np.vstack([np.where(x)[0].astype(np.int64)]).T
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
-  @test_util.run_deprecated_v1
-  def testBasicMat(self):
+  def _testBasicMat(self, fn=array_ops.where):
     x = np.asarray([[True, False], [True, False]])
 
     # Ensure RowMajor mode
     truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64)
 
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
-  @test_util.run_deprecated_v1
-  def testBasic3Tensor(self):
+  def _testBasic3Tensor(self, fn=array_ops.where):
     x = np.asarray([[[True, False], [True, False]],
                     [[False, True], [False, True]],
                     [[False, False], [False, True]]])
@@ -94,15 +90,41 @@ class WhereOpTest(test.TestCase):
     truth = np.asarray(
         [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]], dtype=np.int64)
 
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
-  def _testRandom(self, dtype, expected_err_re=None):
+  def _testRandom(self, dtype, expected_err_re=None, fn=array_ops.where):
     shape = [127, 33, 53]
     x = np.random.randn(*shape) + 1j * np.random.randn(*shape)
     x = (np.random.randn(*shape) > 0).astype(dtype)
     truth = np.where(np.abs(x) > 0)  # Tuples of indices by axis.
     truth = np.vstack(truth).T  # Convert to [num_true, indices].
-    self._testWhere(x, truth, expected_err_re)
+    self._testWhere(x, truth, expected_err_re, fn)
+
+  def _testThreeArgument(self, fn=array_ops.where):
+    x = np.array([[-2, 3, -1], [1, -3, -3]])
+    np_val = np.where(x > 0, x * x, -x)
+    with self.session(use_gpu=True):
+      tf_val = self.evaluate(fn(constant_op.constant(x) > 0, x * x, -x))
+    self.assertAllEqual(tf_val, np_val)
+
+  def testWrongNumbers(self):
+    self._testWrongNumbers()
+
+  @test_util.run_deprecated_v1
+  def testBasicVec(self):
+    self._testBasicVec()
+
+  @test_util.run_deprecated_v1
+  def testRandomVec(self):
+    self._testRandomVec()
+
+  @test_util.run_deprecated_v1
+  def testBasicMat(self):
+    self._testBasicMat()
+
+  @test_util.run_deprecated_v1
+  def testBasic3Tensor(self):
+    self._testBasic3Tensor()
 
   @test_util.run_deprecated_v1
   def testRandomBool(self):
@@ -146,12 +168,95 @@ class WhereOpTest(test.TestCase):
 
   @test_util.run_deprecated_v1
   def testThreeArgument(self):
-    x = np.array([[-2, 3, -1], [1, -3, -3]])
-    np_val = np.where(x > 0, x * x, -x)
-    with self.session(use_gpu=True):
-      tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval()
+    self._testThreeArgument()
+
+  def testV2WrongNumbers(self):
+    self._testWrongNumbers(array_ops.where_v2)
+
+  def testV2BasicVec(self):
+    self._testBasicVec(array_ops.where_v2)
+
+  def testV2RandomVec(self):
+    self._testRandomVec(array_ops.where_v2)
+
+  def testV2BasicMat(self):
+    self._testBasicMat(array_ops.where_v2)
+
+  def testV2Basic3Tensor(self):
+    self._testBasic3Tensor(array_ops.where_v2)
+
+  def testV2RandomBool(self):
+    self._testRandom(np.bool, None, array_ops.where_v2)
+
+  def testV2RandomInt32(self):
+    self._testRandom(np.int32, None, array_ops.where_v2)
+
+  def testV2RandomInt64(self):
+    self._testRandom(np.int64, None, array_ops.where_v2)
+
+  def testV2RandomFloat(self):
+    self._testRandom(np.float32, None, array_ops.where_v2)
+
+  def testV2RandomDouble(self):
+    self._testRandom(np.float64, None, array_ops.where_v2)
+
+  def testV2RandomComplex64(self):
+    self._testRandom(np.complex64, None, array_ops.where_v2)
+
+  def testV2RandomComplex128(self):
+    self._testRandom(np.complex128, None, array_ops.where_v2)
+
+  def testV2RandomUint8(self):
+    self._testRandom(np.uint8, None, array_ops.where_v2)
+
+  def testV2RandomInt8(self):
+    self._testRandom(np.int8, None, array_ops.where_v2)
+
+  def testV2RandomInt16(self):
+    self._testRandom(np.int16, None, array_ops.where_v2)
+
+  def testV2ThreeArgument(self):
+    self._testThreeArgument(array_ops.where_v2)
+
+  def testV2Broadcasting(self):
+    f = np.random.normal(0, 1, (3, 5, 1, 1))
+    x = np.zeros((7, 11))
+    y = np.ones((7, 11))
+    np_val = np.where(f < 0, x, y)
+    with self.session(use_gpu=True):
+      tf_val = self.evaluate(
+          array_ops.where_v2(constant_op.constant(f) < 0, x, y))
     self.assertAllEqual(tf_val, np_val)
 
+  def testV2ScalarBroadcasting(self):
+    x = np.zeros((7, 11))
+    y = np.ones((7, 11))
+    np_val = np.where(True, x, y)
+    with self.session(use_gpu=True):
+      tf_val = self.evaluate(
+          array_ops.where_v2(
+              constant_op.constant(True, dtype=dtypes.bool), x, y))
+    self.assertAllEqual(tf_val, np_val)
+
+  def testV2VectorBroadcasting(self):
+    x = np.zeros(7)
+    y = np.ones(7)
+    np_val = np.where([True], x, y)
+    with self.session(use_gpu=True):
+      tf_val = self.evaluate(
+          array_ops.where_v2(
+              constant_op.constant([True], dtype=dtypes.bool), x, y))
+    self.assertAllEqual(tf_val, np_val)
+
+  def testV2PredBroadcasting(self):
+    pred = np.array([1, 0, 0]).reshape((3, 1))
+    x = np.random.randn(3, 4)
+    y = np.random.randn(3, 4)
+    np_val = np.where(pred, x, y)
+    with self.session(use_gpu=True):
+      tf_val = self.evaluate(array_ops.where_v2(pred, x, y))
+    self.assertAllClose(tf_val, np_val)
+
   @test_util.run_deprecated_v1
   def testBatchSelect(self):
     x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192)  # [16384, 192]
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 11ef64d3a4f..28c1cb3a772 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -3163,7 +3163,11 @@ def squeeze_v2(input, axis=None, name=None):
   return squeeze(input, axis, name)
 
 
-@tf_export("where")
+@tf_export(v1=["where"])
+@deprecation.deprecated(
+    date=None,
+    instructions="Use tf.where in 2.0, "
+    "which has the same broadcast rule as np.where")
 @dispatch.add_dispatch_support
 def where(condition, x=None, y=None, name=None):
   """Return the elements, either from `x` or `y`, depending on the `condition`.
@@ -3217,6 +3221,48 @@ def where(condition, x=None, y=None, name=None):
     raise ValueError("x and y must both be non-None or both be None.")
 
 
+@tf_export("where", v1=["where_v2"])
+def where_v2(condition, x=None, y=None, name=None):
+  """Return the elements, either from `x` or `y`, depending on the `condition`.
+
+  If both `x` and `y` are None, then this operation returns the coordinates
+  of true elements of `condition`. The coordinates are returned in a 2-D
+  tensor where the first dimension (rows) represents the number of true
+  elements, and the second dimension (columns) represents the coordinates of
+  the true elements. Note that the shape of the output tensor can vary
+  depending on how many true values there are in the input. Indices are
+  output in row-major order.
+
+  If both `x` and `y` are non-None, then `condition`, `x`, and `y` must be
+  broadcastable to the same shape. The `condition` tensor acts as a mask that
+  chooses, based on the value at each element, whether the corresponding
+  element / row in the output should be taken from `x` (if true) or `y`
+  (if false).
+
+  Args:
+    condition: A `Tensor` of type `bool`.
+    x: A `Tensor` of the same type as `y`, broadcastable with `condition` and
+      `y`.
+    y: A `Tensor` of the same type as `x`, broadcastable with `condition` and
+      `x`.
+    name: A name for the operation (optional).
+
+  Returns:
+    If `x` and `y` are non-None, a `Tensor` with the same type as `x` and `y`
+    whose shape is broadcast from `condition`, `x`, and `y`.
+    Otherwise, a `Tensor` with shape `(num_true, dim_size(condition))`.
+
+  Raises:
+    ValueError: When exactly one of `x` or `y` is non-None.
+  """
+  if x is None and y is None:
+    with ops.name_scope(name, "Where", [condition]) as name:
+      condition = ops.convert_to_tensor(
+          condition, preferred_dtype=dtypes.bool, name="condition")
+      return gen_array_ops.where(condition=condition, name=name)
+  elif x is not None and y is not None:
+    return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name)
+  else:
+    raise ValueError("x and y must both be non-None or both be None.")
+
+
 # pylint: disable=redefined-builtin
 @tf_export(v1=["reverse_sequence"])
 @deprecation.deprecated_args(None,