diff --git a/tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt
new file mode 100644
index 00000000000..e567206d913
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "SelectV2"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SelectV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_SelectV2.pbtxt
new file mode 100644
index 00000000000..bf57ed7164d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SelectV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "SelectV2"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
index 303d8e47913..ba872db2172 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
@@ -23,6 +23,22 @@ limitations under the License.
 namespace tensorflow {
 namespace functor {
 
+template <typename T, int NDIMS>
+struct BCastSelectFunctor<GPUDevice, T, NDIMS> {
+  void operator()(const GPUDevice& d,
+                  typename TTypes<T, NDIMS>::Tensor output_tensor,
+                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
+    output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
+                                  .select(then_tensor.broadcast(then_bcast),
+                                          else_tensor.broadcast(else_bcast));
+  }
+};
+
 template <typename T>
 struct SelectFunctor<GPUDevice, T> {
   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
@@ -89,10 +105,15 @@ struct BatchSelectFunctor<GPUDevice, T> {
   }
 };
 
-#define SELECT_FUNCTOR(T)                            \
-  template struct SelectFunctor<GPUDevice, T>;       \
-  template struct SelectScalarFunctor<GPUDevice, T>; \
-  template struct BatchSelectFunctor<GPUDevice, T>;
+#define SELECT_FUNCTOR(T)                              \
+  template struct SelectFunctor<GPUDevice, T>;         \
+  template struct SelectScalarFunctor<GPUDevice, T>;   \
+  template struct BatchSelectFunctor<GPUDevice, T>;    \
+  template struct BCastSelectFunctor<GPUDevice, T, 1>; \
+  template struct BCastSelectFunctor<GPUDevice, T, 2>; \
+  template struct BCastSelectFunctor<GPUDevice, T, 3>; \
+  template struct BCastSelectFunctor<GPUDevice, T, 4>; \
+  template struct BCastSelectFunctor<GPUDevice, T, 5>;
 
 SELECT_FUNCTOR(bool);
 SELECT_FUNCTOR(Eigen::half);
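For orientation, the new functor computes exactly the three-argument `np.where` under NumPy-style broadcasting: all three operands are broadcast to the common output shape, then selected elementwise. A minimal NumPy sketch of those semantics (illustrative only, not part of the patch):

```python
import numpy as np

# Broadcast cond/then/else to one shape, then pick elementwise; this is
# the computation BCastSelectFunctor expresses with Eigen's broadcast/select.
cond = np.array([[True], [False]])     # shape (2, 1)
then_ = np.arange(6.0).reshape(2, 3)   # shape (2, 3)
else_ = np.array([10.0, 20.0, 30.0])   # shape (3,), broadcasts to (2, 3)

out = np.where(cond, then_, else_)     # shape (2, 3)
print(out)  # row 0 comes from then_, row 1 from else_
```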
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 3b51563ca28..c85c9d0599f 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -143,21 +143,141 @@ class SelectOp : public OpKernel {
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
 };
+template <typename Device, typename T>
+class SelectV2Op : public OpKernel {
+ public:
+  explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}
 
-#define REGISTER_SELECT(type)                                      \
-  REGISTER_KERNEL_BUILDER(                                         \
-      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
-      SelectOp<CPUDevice, type>);
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* cond;
+    const Tensor* then;
+    const Tensor* else_;
+    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
+    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
+    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+
+    // `cond`, `then`, and `else` must be broadcastable (bcast.IsValid());
+    // this matches the behavior of numpy.
+    // TODO (yongtang): Consolidate into n-ary broadcast, instead of multiple
+    // 2-ary broadcast.
+
+    // Combine `then` and `else`.
+    BCast then_else_bcast(BCast::FromShape(then->shape()),
+                          BCast::FromShape(else_->shape()), false);
+    OP_REQUIRES(ctx, then_else_bcast.IsValid(),
+                errors::InvalidArgument(
+                    "then ", then->shape().DebugString(), " and else ",
+                    else_->shape().DebugString(), " must be broadcastable"));
+    // Combine `cond` with `then` and `else`.
+    BCast bcast(
+        BCast::FromShape(cond->shape()),
+        BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())),
+        false);
+    OP_REQUIRES(ctx, bcast.IsValid(),
+                errors::InvalidArgument(
+                    "condition ", cond->shape().DebugString(), ", then ",
+                    then->shape().DebugString(), ", and else ",
+                    else_->shape().DebugString(), " must be broadcastable"));
+
+    // Broadcast `cond`, `then` and `else` to the combined shape,
+    // in order to obtain the reshape.
+    BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
+                     BCast::FromShape(cond->shape()), false);
+    BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
+                     BCast::FromShape(then->shape()), false);
+    BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
+                     BCast::FromShape(else_->shape()), false);
+    OP_REQUIRES(
+        ctx,
+        cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(),
+        errors::InvalidArgument("condition ", cond->shape().DebugString(),
+                                ", then ", then->shape().DebugString(),
+                                ", and else ", else_->shape().DebugString(),
+                                " must be broadcastable"));
+
+    // The combined shape should be the final output shape.
+    OP_REQUIRES(
+        ctx,
+        cond_bcast.output_shape() == bcast.output_shape() &&
+            then_bcast.output_shape() == bcast.output_shape() &&
+            else_bcast.output_shape() == bcast.output_shape(),
+        errors::InvalidArgument("condition ", cond->shape().DebugString(),
+                                ", then ", then->shape().DebugString(),
+                                ", and else ", else_->shape().DebugString(),
+                                " must be broadcastable to the same shape"));
+
+    Tensor* output = nullptr;
+    const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
+    OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
+                            {"t", "e"}, "output", output_shape, &output));
+
+    if (output->NumElements() == 0) {
+      return;
+    }
+
+#define HANDLE_DIM(NDIMS)                                            \
+  {                                                                  \
+    functor::BCastSelectFunctor<Device, T, NDIMS> func;              \
+    func(ctx->eigen_device<Device>(),                                \
+         output->shaped<T, NDIMS>(bcast.result_shape()),             \
+         cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
+         then->template shaped<T, NDIMS>(then_bcast.y_reshape()),    \
+         else_->template shaped<T, NDIMS>(else_bcast.y_reshape()),   \
+         BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()),           \
+         BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()),           \
+         BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast()));          \
+  }
+
+    const int ndims = static_cast<int>(bcast.result_shape().size());
+    switch (ndims) {
+      case 1:
+        HANDLE_DIM(1);
+        break;
+      case 2:
+        HANDLE_DIM(2);
+        break;
+      case 3:
+        HANDLE_DIM(3);
+        break;
+      case 4:
+        HANDLE_DIM(4);
+        break;
+      case 5:
+        HANDLE_DIM(5);
+        break;
+      default:
+        ctx->SetStatus(errors::Unimplemented(
+            "Broadcast between ", ctx->input(0).shape().DebugString(), " and ",
+            ctx->input(1).shape().DebugString(), " is not supported yet."));
+        break;
+    }
+    return;
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
+};
+
+#define REGISTER_SELECT(type)                                        \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
+      SelectOp<CPUDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      SelectV2Op<CPUDevice, type>);
 
 TF_CALL_ALL_TYPES(REGISTER_SELECT);
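`HANDLE_DIM` feeds the functor each operand reshaped to rank `NDIMS` via `y_reshape()` (size-1 entries where that operand gets broadcast) plus per-dimension multipliers from `y_bcast()`. A rough NumPy analogue of that reshape-plus-tile decomposition, assuming right-aligned NumPy-style broadcasting and ignoring BCast's coalescing of adjacent dimensions (`reshape_and_multipliers` is a hypothetical helper, not part of the patch):

```python
import numpy as np

def reshape_and_multipliers(in_shape, out_shape):
    """Right-align in_shape against out_shape; return (reshape, mult) such
    that np.tile(x.reshape(reshape), mult).shape == out_shape."""
    pad = [1] * (len(out_shape) - len(in_shape)) + list(in_shape)
    mult = [o // p for o, p in zip(out_shape, pad)]  # each p is o or 1
    return pad, mult

x = np.arange(3.0)                                   # shape (3,)
reshape, mult = reshape_and_multipliers(x.shape, (2, 3))
expanded = np.tile(x.reshape(reshape), mult)         # shape (2, 3)
assert expanded.shape == (2, 3)
```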
 
 #if GOOGLE_CUDA
 
 // Registration of the GPU implementations.
-#define REGISTER_SELECT_GPU(type)                                  \
-  REGISTER_KERNEL_BUILDER(                                         \
-      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
-      SelectOp<GPUDevice, type>);
+#define REGISTER_SELECT_GPU(type)                                    \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
+      SelectOp<GPUDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      SelectV2Op<GPUDevice, type>);
 
 REGISTER_SELECT_GPU(bool);
 REGISTER_SELECT_GPU(Eigen::half);
@@ -174,9 +294,12 @@ REGISTER_SELECT_GPU(complex128);
 
 #ifdef TENSORFLOW_USE_SYCL
 // Registration of the SYCL implementations.
-#define REGISTER_SELECT_SYCL(type)                                  \
-  REGISTER_KERNEL_BUILDER(                                          \
-      Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
-      SelectOp<SYCLDevice, type>);
+#define REGISTER_SELECT_SYCL(type)                                    \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"),   \
+      SelectOp<SYCLDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("SelectV2").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+      SelectV2Op<SYCLDevice, type>);
 
 REGISTER_SELECT_SYCL(float);
@@ -324,10 +447,35 @@ struct BatchSelectFunctor<CPUDevice, T> {
   }
 };
 
+template <typename Device, typename T, int NDIMS>
+struct BCastSelectFunctorBase {
+  void operator()(const Device& d,
+                  typename TTypes<T, NDIMS>::Tensor output_tensor,
+                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
+    output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
+                                  .select(then_tensor.broadcast(then_bcast),
+                                          else_tensor.broadcast(else_bcast));
+  }
+};
+
+template <typename T, int NDIMS>
+struct BCastSelectFunctor<CPUDevice, T, NDIMS>
+    : BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};
+
 #ifdef TENSORFLOW_USE_SYCL
 template <typename T>
 struct BatchSelectFunctor<SYCLDevice, T>
     : BatchSelectFunctorBase<SYCLDevice, T> {};
+
+template <typename T, int NDIMS>
+struct BCastSelectFunctor<SYCLDevice, T, NDIMS>
+    : BCastSelectFunctorBase<SYCLDevice, T, NDIMS> {};
+
 #endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace functor
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 12968cf85ac..48297bc2848 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -1208,6 +1208,18 @@ struct BatchSelectFunctor {
                   typename TTypes<T>::ConstMatrix else_flat_outer_dims);
 };
 
+template <typename Device, typename T, int NDIMS>
+struct BCastSelectFunctor {
+  void operator()(const Device& d,
+                  typename TTypes<T, NDIMS>::Tensor output_tensor,
+                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
+                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
+                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
+};
+
 }  // end namespace functor
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 29a93753e7d..3ff9bc09853 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -828,6 +828,57 @@ REGISTER_OP("Select")
       return Status::OK();
     });
 
+REGISTER_OP("SelectV2")
+    .Input("condition: bool")
+    .Input("t: T")
+    .Input("e: T")
+    .Output("output: T")
+    .Attr("T: type")
+    .SetShapeFn([](InferenceContext* c) {
+      auto* handle_data_1 = c->input_handle_shapes_and_types(1);
+      auto* handle_data_2 = c->input_handle_shapes_and_types(2);
+      // Merge handle shape and dtype if applicable.
+      if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
+        const auto size = handle_data_1->size();
+        std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
+        if (size != handle_data_2->size()) {
+          return errors::InvalidArgument(
+              "Trying to merge handles pointing to different numbers of "
+              "tensors.");
+        }
+
+        for (int i = 0; i < size; ++i) {
+          const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
+          const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
+          if (s1.dtype != s2.dtype) {
+            // TODO(apassos) resolve this in the manner of b/32476923
+            return errors::InvalidArgument(
+                "Trying to merge handles pointing to different dtypes.");
+          }
+          merged_handle_data[i].dtype = s1.dtype;
+          TF_RETURN_IF_ERROR(
+              c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
+        }
+
+        c->set_output_handle_shapes_and_types(0, merged_handle_data);
+      }
+
+      // The inputs 'cond', 'then', and 'else' must be broadcastable.
+      // TODO (yongtang): Consolidate into an n-ary broadcast instead of
+      // multiple 2-ary broadcasts.
+      ShapeHandle cond = c->input(0);
+      ShapeHandle then = c->input(1);
+      ShapeHandle else_ = c->input(2);
+      ShapeHandle other;
+      TF_RETURN_IF_ERROR(
+          BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other));
+      ShapeHandle output;
+      TF_RETURN_IF_ERROR(
+          BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output));
+      c->set_output(0, output);
+      return Status::OK();
+    });
+
 // --------------------------------------------------------------------------
 
 REGISTER_OP("MatMul")
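The shape function folds the 3-ary broadcast into two 2-ary broadcasts, mirroring what the kernel does. For NumPy-style broadcasting this pairwise composition gives the same answer as a single n-ary broadcast, which is why the TODO's consolidation is a cleanup rather than a behavior change. A quick NumPy check (a sketch, not part of the patch; `np.broadcast_shapes` requires NumPy 1.20+):

```python
import numpy as np

cond_shape, then_shape, else_shape = (1, 3, 2), (1, 1, 2), (3, 1)

# Two 2-ary steps, mirroring the SetShapeFn above.
step1 = np.broadcast_shapes(then_shape, else_shape)  # (1, 3, 2)
step2 = np.broadcast_shapes(cond_shape, step1)       # (1, 3, 2)

# A single 3-ary broadcast agrees with the pairwise decomposition.
assert step2 == np.broadcast_shapes(cond_shape, then_shape, else_shape)
```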
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 1cb9cfa7479..39ae0ac0317 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -318,28 +318,38 @@ class LogicalOpTest(test.TestCase):
 
 class SelectOpTest(test.TestCase):
 
-  def _compare(self, c, x, y, use_gpu):
+  def _compare(self, fn, c, x, y, use_gpu):
     np_ans = np.where(c, x, y)
     with test_util.device(use_gpu=use_gpu):
-      out = array_ops.where(c, x, y)
+      out = fn(c, x, y)
       tf_ans = self.evaluate(out)
     self.assertAllEqual(np_ans, tf_ans)
     self.assertShapeEqual(np_ans, out)
 
-  def _compareGradientX(self, c, x, y, numeric_gradient_type=None):
+  def _compareGradientX(self,
+                        fn,
+                        c,
+                        x,
+                        y,
+                        numeric_gradient_type=None,
+                        x_init_value=None):
     with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
-      out = array_ops.where(c, inx, iny)
+      out = fn(c, inx, iny)
       s = list(np.shape(c))
+      if x_init_value is None:
+        x_init_value = x
+      if x.shape != y.shape:
+        x_init_value = np.broadcast_to(y, x.shape)
       jacob_t, jacob_n = gradient_checker.compute_gradient(
-          inx, s, out, s, x_init_value=x)
+          inx, s, out, s, x_init_value=x_init_value)
       if numeric_gradient_type is not None:
         xf = x.astype(numeric_gradient_type)
         yf = y.astype(numeric_gradient_type)
         inxf = ops.convert_to_tensor(xf)
         inyf = ops.convert_to_tensor(yf)
-        outf = array_ops.where(c, inxf, inyf)
+        outf = fn(c, inxf, inyf)
         _, jacob_n = gradient_checker.compute_gradient(
             inxf, s, outf, s, x_init_value=xf)
         jacob_n = jacob_n.astype(x.dtype)
@@ -350,20 +360,20 @@ class SelectOpTest(test.TestCase):
     elif x.dtype == np.float64:
       self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
 
-  def _compareGradientY(self, c, x, y, numeric_gradient_type=None):
+  def _compareGradientY(self, fn, c, x, y, numeric_gradient_type=None):
    with self.cached_session():
      inx = ops.convert_to_tensor(x)
      iny = ops.convert_to_tensor(y)
-      out = array_ops.where(c, inx, iny)
+      out = fn(c, inx, iny)
      s = list(np.shape(c))
      jacob_t, jacob_n = gradient_checker.compute_gradient(
-          iny, s, out, s, x_init_value=y, delta=1.0)
+          iny, s, out, s, x_init_value=x, delta=1.0)
      if numeric_gradient_type is not None:
        xf = x.astype(numeric_gradient_type)
        yf = y.astype(numeric_gradient_type)
        inxf = ops.convert_to_tensor(xf)
        inyf = ops.convert_to_tensor(yf)
-        outf = array_ops.where(c, inxf, inyf)
+        outf = fn(c, inxf, inyf)
        _, jacob_n = gradient_checker.compute_gradient(
            inyf, s, outf, s, x_init_value=yf)
        jacob_n = jacob_n.astype(x.dtype)
@@ -374,7 +384,7 @@ class SelectOpTest(test.TestCase):
     elif x.dtype == np.float64:
       self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
 
-  def testScalar(self):
+  def _testScalar(self, fn):
     c = True
     x = np.random.rand(1, 3, 2) * 100
     y = np.random.rand(1, 3, 2) * 100
@@ -384,11 +394,58 @@ class SelectOpTest(test.TestCase):
     ]:
       xt = x.astype(t)
       yt = y.astype(t)
-      self._compare(c, xt, yt, use_gpu=False)
+      self._compare(fn, c, xt, yt, use_gpu=False)
       if t in [np.float16, np.float32, np.float64]:
-        self._compare(c, xt, yt, use_gpu=True)
+        self._compare(fn, c, xt, yt, use_gpu=True)
 
-  def testBasic(self):
+  def testScalar(self):
+    self._testScalar(array_ops.where)
+    self._testScalar(array_ops.where_v2)
+
+  def _testScalarBroadcast(self, fn, c, x, y):
+    for t in [
+        np.float16, np.float32, np.float64, np.int32, np.int64, np.complex64,
+        np.complex128
+    ]:
+      xt = x.astype(t)
+      yt = y.astype(t)
+      self._compare(fn, c, xt, yt, use_gpu=False)
+      if t in [np.float16, np.float32, np.float64]:
+        self._compare(fn, c, xt, yt, use_gpu=True)
+
+  def testScalarBroadcast(self):
+    c = True
+    # where_v2 only
+    x = np.random.rand(1, 3, 2) * 100
+    y = np.random.rand(1, 1, 1) * 100
+    self._testScalarBroadcast(array_ops.where_v2, c, x, y)
+    self._testScalarBroadcast(array_ops.where_v2, c, y, x)
+    x = np.random.rand(1, 3, 2) * 100
+    y = np.random.rand(1, 3, 1) * 100
+    self._testScalarBroadcast(array_ops.where_v2, c, x, y)
+    self._testScalarBroadcast(array_ops.where_v2, c, y, x)
+    x = np.random.rand(1, 3, 2) * 100
+    y = np.random.rand(1, 1, 2) * 100
+    self._testScalarBroadcast(array_ops.where_v2, c, x, y)
+    self._testScalarBroadcast(array_ops.where_v2, c, y, x)
+    x = np.random.rand(1, 3, 2) * 100
+    y = np.random.rand(1, 1) * 100
+    self._testScalarBroadcast(array_ops.where_v2, c, x, y)
+    self._testScalarBroadcast(array_ops.where_v2, c, y, x)
+    x = np.random.rand(1, 3, 2) * 100
+    y = np.random.rand(1) * 100
+    self._testScalarBroadcast(array_ops.where_v2, c, x, y)
+    self._testScalarBroadcast(array_ops.where_v2, c, y, x)
+    x = np.random.rand(1, 3, 2) * 100
+    y = np.random.rand(1, 2) * 100
+    self._testScalarBroadcast(array_ops.where_v2, c, x, y)
+    self._testScalarBroadcast(array_ops.where_v2, c, y, x)
+    x = np.random.rand(1, 3, 2) * 100
+    y = np.random.rand(3, 2) * 100
+    self._testScalarBroadcast(array_ops.where_v2, c, x, y)
+    self._testScalarBroadcast(array_ops.where_v2, c, y, x)
+
+  def _testBasic(self, fn):
     c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
     x = np.random.rand(1, 3, 2) * 100
     y = np.random.rand(1, 3, 2) * 100
@@ -398,12 +455,62 @@ class SelectOpTest(test.TestCase):
     ]:
       xt = x.astype(t)
       yt = y.astype(t)
-      self._compare(c, xt, yt, use_gpu=False)
+      self._compare(fn, c, xt, yt, use_gpu=False)
       if t in [np.float16, np.float32, np.float64]:
-        self._compare(c, xt, yt, use_gpu=True)
+        self._compare(fn, c, xt, yt, use_gpu=True)
 
-  @test_util.run_deprecated_v1
-  def testGradients(self):
+  def testBasic(self):
+    self._testBasic(array_ops.where)
+    self._testBasic(array_ops.where_v2)
+
+  def _testBasicBroadcast(self, fn, c, x, y):
+    for t in [
+        np.float16, np.float32, np.float64, np.int32, np.int64, np.complex64,
+        np.complex128
+    ]:
+      xt = x.astype(t)
+      yt = y.astype(t)
+      self._compare(fn, c, xt, yt, use_gpu=False)
+      if t in [np.float16, np.float32, np.float64]:
+        self._compare(fn, c, xt, yt, use_gpu=True)
+
+  def testBasicBroadcast(self):
+    c0 = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
+    c1 = np.random.randint(0, 2, 2).astype(np.bool).reshape(1, 1, 2)
+    c2 = np.random.randint(0, 2, 3).astype(np.bool).reshape(1, 3, 1)
+    c3 = np.random.randint(0, 2, 1).astype(np.bool).reshape(1, 1, 1)
+    for c in [c0, c1, c2, c3]:
+      # where_v2 only
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 1, 1) * 100
+      self._testBasicBroadcast(array_ops.where_v2, c, x, y)
+      self._testBasicBroadcast(array_ops.where_v2, c, y, x)
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 3, 1) * 100
+      self._testBasicBroadcast(array_ops.where_v2, c, x, y)
+      self._testBasicBroadcast(array_ops.where_v2, c, y, x)
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 1, 2) * 100
+      self._testBasicBroadcast(array_ops.where_v2, c, x, y)
+      self._testBasicBroadcast(array_ops.where_v2, c, y, x)
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 1) * 100
+      self._testBasicBroadcast(array_ops.where_v2, c, x, y)
+      self._testBasicBroadcast(array_ops.where_v2, c, y, x)
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1) * 100
+      self._testBasicBroadcast(array_ops.where_v2, c, x, y)
+      self._testBasicBroadcast(array_ops.where_v2, c, y, x)
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 2) * 100
+      self._testBasicBroadcast(array_ops.where_v2, c, x, y)
+      self._testBasicBroadcast(array_ops.where_v2, c, y, x)
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(3, 2) * 100
+      self._testBasicBroadcast(array_ops.where_v2, c, x, y)
+      self._testBasicBroadcast(array_ops.where_v2, c, y, x)
+
+  def _testGradients(self, fn):
     c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
     x = np.random.rand(1, 3, 2) * 100
     y = np.random.rand(1, 3, 2) * 100
@@ -416,14 +523,45 @@ class SelectOpTest(test.TestCase):
       # care is taken with choosing the inputs and the delta. This is
       # a weaker check (in particular, it does not test the op itself,
       # only its gradient), but it's much better than nothing.
-        self._compareGradientX(c, xt, yt, np.float)
-        self._compareGradientY(c, xt, yt, np.float)
+        self._compareGradientX(fn, c, xt, yt, np.float)
+        self._compareGradientY(fn, c, xt, yt, np.float)
       else:
-        self._compareGradientX(c, xt, yt)
-        self._compareGradientY(c, xt, yt)
+        self._compareGradientX(fn, c, xt, yt)
+        self._compareGradientY(fn, c, xt, yt)
 
   @test_util.run_deprecated_v1
-  def testShapeMismatch(self):
+  def testGradients(self):
+    self._testGradients(array_ops.where)
+    self._testGradients(array_ops.where_v2)
+
+  @test_util.run_deprecated_v1
+  def testGradientsBroadcast(self):
+    c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
+    for t in [np.float32, np.float64]:
+      # where_v2 only
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 1, 1) * 100
+      self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 3, 1) * 100
+      self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 1, 2) * 100
+      self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 1) * 100
+      self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1) * 100
+      self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(1, 2) * 100
+      self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
+      x = np.random.rand(1, 3, 2) * 100
+      y = np.random.rand(3, 2) * 100
+      self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
+
+  def _testShapeMismatch(self, fn):
     c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
     x = np.random.rand(1, 3, 2) * 100
     y = np.random.rand(2, 5, 3) * 100
@@ -434,10 +572,14 @@ class SelectOpTest(test.TestCase):
       xt = x.astype(t)
       yt = y.astype(t)
       with self.assertRaises(ValueError):
-        array_ops.where(c, xt, yt)
+        fn(c, xt, yt)
 
   @test_util.run_deprecated_v1
-  def testEmptyTensor(self):
+  def testShapeMismatch(self):
+    self._testShapeMismatch(array_ops.where)
+    self._testShapeMismatch(array_ops.where_v2)
+
+  def _testEmptyTensor(self, fn):
     c = np.random.randint(0, 3, 0).astype(np.bool).reshape(1, 3, 0)
     x = np.random.rand(1, 3, 0) * 100
     y = np.random.rand(1, 3, 0) * 100
     z_expected = np.zeros((1, 3, 0), dtype=np.float32)
     with self.cached_session():
       xt = x.astype(np.float32)
       yt = y.astype(np.float32)
-      z = array_ops.where(c, xt, yt).eval()
+      z = fn(c, xt, yt).eval()
       self.assertAllEqual(z_expected, z)
 
   @test_util.run_deprecated_v1
-  def testNan(self):
-    """Verify that nans don't propagate where they shouldn't."""
+  def testEmptyTensor(self):
+    self._testEmptyTensor(array_ops.where)
+    self._testEmptyTensor(array_ops.where_v2)
+
+  def _testNan(self, fn):
     with self.cached_session():
       for c in False, True:
         for a in 7.0, np.nan:
           for b in 5.0, np.nan:
-            x = array_ops.where(c, a, b).eval()
+            x = fn(c, a, b).eval()
             y = a if c else b
             self.assertEqual(np.isnan(x), np.isnan(y))
 
+  @test_util.run_deprecated_v1
+  def testNan(self):
+    """Verify that nans don't propagate where they shouldn't."""
+    self._testNan(array_ops.where)
+    self._testNan(array_ops.where_v2)
+
 
 class BatchSelectOpTest(test.TestCase):
   """Test broadcasting of Select when 'c' is a vec and 't' &'e' are rank2+."""
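All of the broadcast cases above assert parity with `np.where`, whose three arguments broadcast against one another. A standalone sketch of the contract these tests lock in (hypothetical shapes, not taken from the patch):

```python
import numpy as np

c = np.array([[True], [False], [True]])  # shape (3, 1)
x = np.full((3, 4), 1.0)                 # shape (3, 4)
y = np.zeros(4)                          # shape (4,), broadcasts to (3, 4)

out = np.where(c, x, y)                  # shape (3, 4)
assert out.shape == (3, 4)
assert (out[0] == 1.0).all() and (out[1] == 0.0).all()
```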
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index 56c13904113..00d916c06eb 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -37,10 +37,10 @@ from tensorflow.python.platform import test
 
 class WhereOpTest(test.TestCase):
 
-  def _testWhere(self, x, truth, expected_err_re=None):
+  def _testWhere(self, x, truth, expected_err_re=None, fn=array_ops.where):
     with self.cached_session(use_gpu=True):
-      ans = array_ops.where(x)
-      self.assertEqual([None, x.ndim], ans.get_shape().as_list())
+      ans = fn(x)
+      self.assertTrue(ans.get_shape().is_compatible_with([None, x.ndim]))
       if expected_err_re is None:
         tf_ans = self.evaluate(ans)
         self.assertAllClose(tf_ans, truth, atol=1e-10)
@@ -48,44 +48,40 @@ class WhereOpTest(test.TestCase):
       with self.assertRaisesOpError(expected_err_re):
         self.evaluate(ans)
 
-  def testWrongNumbers(self):
+  def _testWrongNumbers(self, fn=array_ops.where):
     with self.session(use_gpu=True):
       with self.assertRaises(ValueError):
-        array_ops.where([False, True], [1, 2], None)
+        fn([False, True], [1, 2], None)
       with self.assertRaises(ValueError):
-        array_ops.where([False, True], None, [1, 2])
+        fn([False, True], None, [1, 2])
 
-  @test_util.run_deprecated_v1
-  def testBasicVec(self):
+  def _testBasicVec(self, fn=array_ops.where):
     x = np.asarray([True, False])
     truth = np.asarray([[0]], dtype=np.int64)
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
     x = np.asarray([False, True, False])
     truth = np.asarray([[1]], dtype=np.int64)
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
     x = np.asarray([False, False, True, False, True])
     truth = np.asarray([[2], [4]], dtype=np.int64)
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
-  @test_util.run_deprecated_v1
-  def testRandomVec(self):
+  def _testRandomVec(self, fn=array_ops.where):
     x = np.random.rand(1000000) > 0.5
     truth = np.vstack([np.where(x)[0].astype(np.int64)]).T
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
-  @test_util.run_deprecated_v1
-  def testBasicMat(self):
+  def _testBasicMat(self, fn=array_ops.where):
     x = np.asarray([[True, False], [True, False]])
 
     # Ensure RowMajor mode
     truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64)
 
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
-  @test_util.run_deprecated_v1
-  def testBasic3Tensor(self):
+  def _testBasic3Tensor(self, fn=array_ops.where):
     x = np.asarray([[[True, False], [True, False]],
                     [[False, True], [False, True]],
                     [[False, False], [False, True]]])
@@ -94,15 +90,41 @@ class WhereOpTest(test.TestCase):
     truth = np.asarray(
         [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]],
         dtype=np.int64)
-    self._testWhere(x, truth)
+    self._testWhere(x, truth, None, fn)
 
-  def _testRandom(self, dtype, expected_err_re=None):
+  def _testRandom(self, dtype, expected_err_re=None, fn=array_ops.where):
     shape = [127, 33, 53]
     x = np.random.randn(*shape) + 1j * np.random.randn(*shape)
     x = (np.random.randn(*shape) > 0).astype(dtype)
     truth = np.where(np.abs(x) > 0)  # Tuples of indices by axis.
     truth = np.vstack(truth).T  # Convert to [num_true, indices].
-    self._testWhere(x, truth, expected_err_re)
+    self._testWhere(x, truth, expected_err_re, fn)
+
+  def _testThreeArgument(self, fn=array_ops.where):
+    x = np.array([[-2, 3, -1], [1, -3, -3]])
+    np_val = np.where(x > 0, x * x, -x)
+    with self.test_session(use_gpu=True):
+      tf_val = self.evaluate(fn(constant_op.constant(x) > 0, x * x, -x))
+    self.assertAllEqual(tf_val, np_val)
+
+  def testWrongNumbers(self):
+    self._testWrongNumbers()
+
+  @test_util.run_deprecated_v1
+  def testBasicVec(self):
+    self._testBasicVec()
+
+  @test_util.run_deprecated_v1
+  def testRandomVec(self):
+    self._testRandomVec()
+
+  @test_util.run_deprecated_v1
+  def testBasicMat(self):
+    self._testBasicMat()
+
+  @test_util.run_deprecated_v1
+  def testBasic3Tensor(self):
+    self._testBasic3Tensor()
 
   @test_util.run_deprecated_v1
   def testRandomBool(self):
@@ -146,12 +168,95 @@ class WhereOpTest(test.TestCase):
 
   @test_util.run_deprecated_v1
   def testThreeArgument(self):
-    x = np.array([[-2, 3, -1], [1, -3, -3]])
-    np_val = np.where(x > 0, x * x, -x)
-    with self.session(use_gpu=True):
-      tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval()
+    self._testThreeArgument()
+
+  def testV2WrongNumbers(self):
+    self._testWrongNumbers(array_ops.where_v2)
+
+  def testV2BasicVec(self):
+    self._testBasicVec(array_ops.where_v2)
+
+  def testV2RandomVec(self):
+    self._testRandomVec(array_ops.where_v2)
+
+  def testV2BasicMat(self):
+    self._testBasicMat(array_ops.where_v2)
+
+  def testV2Basic3Tensor(self):
+    self._testBasic3Tensor(array_ops.where_v2)
+
+  def testV2RandomBool(self):
+    self._testRandom(np.bool, None, array_ops.where_v2)
+
+  def testV2RandomInt32(self):
+    self._testRandom(np.int32, None, array_ops.where_v2)
+
+  def testV2RandomInt64(self):
+    self._testRandom(np.int64, None, array_ops.where_v2)
+
+  def testV2RandomFloat(self):
+    self._testRandom(np.float32, None, array_ops.where_v2)
+
+  def testV2RandomDouble(self):
+    self._testRandom(np.float64, None, array_ops.where_v2)
+
+  def testV2RandomComplex64(self):
+    self._testRandom(np.complex64, None, array_ops.where_v2)
+
+  def testV2RandomComplex128(self):
+    self._testRandom(np.complex128, None, array_ops.where_v2)
+
+  def testV2RandomUint8(self):
+    self._testRandom(np.uint8, None, array_ops.where_v2)
+
+  def testV2RandomInt8(self):
+    self._testRandom(np.int8, None, array_ops.where_v2)
+
+  def testV2RandomInt16(self):
+    self._testRandom(np.int16, None, array_ops.where_v2)
+
+  def testV2ThreeArgument(self):
+    self._testThreeArgument(array_ops.where_v2)
+
+  def testV2Broadcasting(self):
+    f = np.random.normal(0, 1, (3, 5, 1, 1))
+    x = np.zeros((7, 11))
+    y = np.ones((7, 11))
+    np_val = np.where(f < 0, x, y)
+    with self.test_session(use_gpu=True):
+      tf_val = self.evaluate(
+          array_ops.where_v2(constant_op.constant(f) < 0, x, y))
     self.assertAllEqual(tf_val, np_val)
 
+  def testV2ScalarBroadcasting(self):
+    x = np.zeros((7, 11))
+    y = np.ones((7, 11))
+    np_val = np.where(True, x, y)
+    with self.test_session(use_gpu=True):
+      tf_val = self.evaluate(
+          array_ops.where_v2(
+              constant_op.constant(True, dtype=dtypes.bool), x, y))
+    self.assertAllEqual(tf_val, np_val)
+
+  def testV2VectorBroadcasting(self):
+    x = np.zeros(7)
+    y = np.ones(7)
+    np_val = np.where([True], x, y)
+    with self.test_session(use_gpu=True):
+      tf_val = self.evaluate(
+          array_ops.where_v2(
+              constant_op.constant([True], dtype=dtypes.bool), x, y))
+    self.assertAllEqual(tf_val, np_val)
+
+  def testV2PredBroadcasting(self):
+    pred = np.array([1, 0, 0]).reshape((3, 1))
+    x = np.random.randn(3, 4)
+    y = np.random.randn(3, 4)
+    np_val = np.where(pred, x, y)
+    with self.test_session(use_gpu=True):
+      tf_val = self.evaluate(array_ops.where_v2(pred, x, y))
+    self.assertAllClose(tf_val, np_val)
+
   @test_util.run_deprecated_v1
   def testBatchSelect(self):
     x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192)  # [16384, 192]
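At the Python level the split is: `tf.where` becomes `tf.compat.v1.where` (old Select semantics: shapes must match, or `condition` is a vector selecting rows), while the broadcasting version is exported as `tf.where` in 2.0 and `tf.where_v2` in 1.x. A hedged usage sketch (illustrative, assuming a TF build that contains this patch):

```python
import numpy as np
import tensorflow as tf

cond = tf.constant([[True], [False]])          # shape (2, 1)
x = tf.constant(np.ones((2, 3), np.float32))   # shape (2, 3)
y = tf.constant(np.zeros((2, 3), np.float32))  # shape (2, 3)

# Old semantics (Select): cond must match x/y exactly or be a vector over
# the first dimension, so a (2, 1) cond is expected to be rejected:
# tf.compat.v1.where(cond, x, y)  # expected to raise ValueError

# New semantics (SelectV2): cond broadcasts against x and y, like np.where.
out = tf.where_v2(cond, x, y)                  # shape (2, 3)
```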
+ """ + if x is None and y is None: + with ops.name_scope(name, "Where", [condition]) as name: + condition = ops.convert_to_tensor( + condition, preferred_dtype=dtypes.bool, name="condition") + return gen_array_ops.where(condition=condition, name=name) + elif x is not None and y is not None: + return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name) + else: + raise ValueError("x and y must both be non-None or both be None.") + + # pylint: disable=redefined-builtin @tf_export(v1=["reverse_sequence"]) @deprecation.deprecated_args(None, diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index d848fe49730..79594df029a 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -1309,6 +1309,39 @@ def _SelectGrad(op, grad): c, zeros, grad)) +@ops.RegisterGradient("SelectV2") +def _SelectGradV2(op, grad): + c = op.inputs[0] + x = op.inputs[1] + y = op.inputs[2] + zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype) + gx = array_ops.where_v2(c, grad, zeros) + gx_shape = array_ops.shape(gx) + x_shape = array_ops.shape(x) + rankdiff_x = array_ops.rank(gx) - array_ops.rank(x) + # Reduce away broadcasted leading dims. + gx = math_ops.reduce_sum(gx, axis=math_ops.range(rankdiff_x)) + # Reduce but keep x's 1-valued dims which were broadcast. + axis = array_ops.where_v2(gx_shape[rankdiff_x:] > x_shape) + # tf.where returns 2D so squeeze. + axis = array_ops.squeeze(axis) + gx = math_ops.reduce_sum(gx, keepdims=True, axis=axis) + + gy = array_ops.where_v2(c, zeros, grad) + gy_shape = array_ops.shape(gy) + y_shape = array_ops.shape(y) + rankdiff_y = array_ops.rank(gy) - array_ops.rank(y) + # Reduce away broadcasted leading dims. + gy = math_ops.reduce_sum(gy, axis=math_ops.range(rankdiff_y)) + # Reduce but keep y's 1-valued dims which were broadcast. + axis = array_ops.where_v2(gy_shape[rankdiff_y:] > y_shape) + # tf.where returns 2D so squeeze. 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index ae66ee8febd..3212deeefb2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -2436,6 +2436,10 @@ tf_module {
     name: "where"
     argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "where_v2"
+    argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
   member_method {
     name: "while_loop"
     argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], "
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 12e668952bc..9208b8dc384 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -3352,6 +3352,10 @@ tf_module {
     name: "Select"
     argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "SelectV2"
+    argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "SelfAdjointEig"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 12e668952bc..9208b8dc384 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -3352,6 +3352,10 @@ tf_module {
     name: "Select"
     argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "SelectV2"
+    argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "SelfAdjointEig"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py
index 201925889ef..3c9129b5b6e 100644
--- a/tensorflow/tools/compatibility/renames_v2.py
+++ b/tensorflow/tools/compatibility/renames_v2.py
@@ -1543,6 +1543,10 @@ renames = {
         'tf.compat.v1.variables_initializer',
     'tf.verify_tensor_all_finite':
         'tf.compat.v1.verify_tensor_all_finite',
+    'tf.where':
+        'tf.compat.v1.where',
+    'tf.where_v2':
+        'tf.compat.v2.where',
     'tf.wrap_function':
         'tf.compat.v1.wrap_function',
     'tf.write_file':