Merge pull request #15982 from yongtang:9284-tf.where-broadcasting
PiperOrigin-RevId: 246953640
This commit is contained in:
commit
60524c4167
@ -124,7 +124,7 @@ from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d
|
|||||||
from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d
|
from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
_allowed_symbols = ['nest']
|
_allowed_symbols = ['nest', 'broadcast_to', 'where']
|
||||||
_nest_allowed_symbols = [
|
_nest_allowed_symbols = [
|
||||||
'assert_same_structure',
|
'assert_same_structure',
|
||||||
'is_nested',
|
'is_nested',
|
||||||
|
3
tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt
Normal file
3
tensorflow/core/api_def/base_api/api_def_SelectV2.pbtxt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SelectV2"
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SelectV2"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -23,6 +23,22 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
|
template <typename T, int NDIMS>
|
||||||
|
struct BCastSelectFunctor<GPUDevice, T, NDIMS> {
|
||||||
|
void operator()(const GPUDevice& d,
|
||||||
|
typename TTypes<T, NDIMS>::Tensor output_tensor,
|
||||||
|
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
|
||||||
|
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
|
||||||
|
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
||||||
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
||||||
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
||||||
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
|
||||||
|
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
|
||||||
|
.select(then_tensor.broadcast(then_bcast),
|
||||||
|
else_tensor.broadcast(else_bcast));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct SelectFunctor<GPUDevice, T> {
|
struct SelectFunctor<GPUDevice, T> {
|
||||||
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
|
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
|
||||||
@ -89,10 +105,15 @@ struct BatchSelectFunctor<GPUDevice, T> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#define SELECT_FUNCTOR(T) \
|
#define SELECT_FUNCTOR(T) \
|
||||||
template struct SelectFunctor<GPUDevice, T>; \
|
template struct SelectFunctor<GPUDevice, T>; \
|
||||||
template struct SelectScalarFunctor<GPUDevice, T>; \
|
template struct SelectScalarFunctor<GPUDevice, T>; \
|
||||||
template struct BatchSelectFunctor<GPUDevice, T>;
|
template struct BatchSelectFunctor<GPUDevice, T>; \
|
||||||
|
template struct BCastSelectFunctor<GPUDevice, T, 1>; \
|
||||||
|
template struct BCastSelectFunctor<GPUDevice, T, 2>; \
|
||||||
|
template struct BCastSelectFunctor<GPUDevice, T, 3>; \
|
||||||
|
template struct BCastSelectFunctor<GPUDevice, T, 4>; \
|
||||||
|
template struct BCastSelectFunctor<GPUDevice, T, 5>;
|
||||||
|
|
||||||
SELECT_FUNCTOR(bool);
|
SELECT_FUNCTOR(bool);
|
||||||
SELECT_FUNCTOR(Eigen::half);
|
SELECT_FUNCTOR(Eigen::half);
|
||||||
|
@ -143,21 +143,141 @@ class SelectOp : public OpKernel {
|
|||||||
private:
|
private:
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
|
||||||
};
|
};
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class SelectV2Op : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
|
|
||||||
#define REGISTER_SELECT(type) \
|
void Compute(OpKernelContext* ctx) override {
|
||||||
REGISTER_KERNEL_BUILDER( \
|
const Tensor* cond;
|
||||||
Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
const Tensor* then;
|
||||||
SelectOp<CPUDevice, type>);
|
const Tensor* else_;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->input("t", &then));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
|
||||||
|
|
||||||
|
// The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()),
|
||||||
|
// This matches the behavior of numpy.
|
||||||
|
// TODO (yongtang): Consolidate into n-ary broadcast, instead of multiple
|
||||||
|
// 2-ary broadcast.
|
||||||
|
|
||||||
|
// Combine `then` and `else`.
|
||||||
|
BCast then_else_bcast(BCast::FromShape(then->shape()),
|
||||||
|
BCast::FromShape(else_->shape()), false);
|
||||||
|
OP_REQUIRES(ctx, then_else_bcast.IsValid(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"then ", then->shape().DebugString(), " and else ",
|
||||||
|
else_->shape().DebugString(), " must be broadcastable"));
|
||||||
|
// Combine `cond` with `then` and `else`.
|
||||||
|
BCast bcast(
|
||||||
|
BCast::FromShape(cond->shape()),
|
||||||
|
BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())),
|
||||||
|
false);
|
||||||
|
OP_REQUIRES(ctx, bcast.IsValid(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"condition ", cond->shape().DebugString(), ", then ",
|
||||||
|
then->shape().DebugString(), ", and else ",
|
||||||
|
else_->shape().DebugString(), " must be broadcastable"));
|
||||||
|
|
||||||
|
// Broadcast `cond`, `then` and `else` to combined shape,
|
||||||
|
// in order to obtain the reshape.
|
||||||
|
BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
|
||||||
|
BCast::FromShape(cond->shape()), false);
|
||||||
|
BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
|
||||||
|
BCast::FromShape(then->shape()), false);
|
||||||
|
BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
|
||||||
|
BCast::FromShape(else_->shape()), false);
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx,
|
||||||
|
cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(),
|
||||||
|
errors::InvalidArgument("condition ", cond->shape().DebugString(),
|
||||||
|
", then ", then->shape().DebugString(),
|
||||||
|
", and else ", else_->shape().DebugString(),
|
||||||
|
" must be broadcastable"));
|
||||||
|
|
||||||
|
// Combined shape should be the final shape.
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx,
|
||||||
|
cond_bcast.output_shape() == bcast.output_shape() &&
|
||||||
|
then_bcast.output_shape() == bcast.output_shape() &&
|
||||||
|
else_bcast.output_shape() == bcast.output_shape(),
|
||||||
|
errors::InvalidArgument("condition ", cond->shape().DebugString(),
|
||||||
|
", then ", then->shape().DebugString(),
|
||||||
|
", and else ", else_->shape().DebugString(),
|
||||||
|
" must be broadcastable to the same shape"));
|
||||||
|
|
||||||
|
Tensor* output = nullptr;
|
||||||
|
const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
|
||||||
|
{"t", "e"}, "output", output_shape, &output));
|
||||||
|
|
||||||
|
if (output->NumElements() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define HANDLE_DIM(NDIMS) \
|
||||||
|
{ \
|
||||||
|
functor::BCastSelectFunctor<Device, T, NDIMS> func; \
|
||||||
|
func(ctx->eigen_device<Device>(), \
|
||||||
|
output->shaped<T, NDIMS>(bcast.result_shape()), \
|
||||||
|
cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
|
||||||
|
then->template shaped<T, NDIMS>(then_bcast.y_reshape()), \
|
||||||
|
else_->template shaped<T, NDIMS>(else_bcast.y_reshape()), \
|
||||||
|
BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()), \
|
||||||
|
BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()), \
|
||||||
|
BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast())); \
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ndims = static_cast<int>(bcast.result_shape().size());
|
||||||
|
switch (ndims) {
|
||||||
|
case 1:
|
||||||
|
HANDLE_DIM(1);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
HANDLE_DIM(2);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
HANDLE_DIM(3);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
HANDLE_DIM(4);
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
HANDLE_DIM(5);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
ctx->SetStatus(errors::Unimplemented(
|
||||||
|
"Broadcast between ", ctx->input(0).shape().DebugString(), " and ",
|
||||||
|
ctx->input(1).shape().DebugString(), " is not supported yet."));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_SELECT(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||||
|
SelectOp<CPUDevice, type>); \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||||
|
SelectV2Op<CPUDevice, type>);
|
||||||
|
|
||||||
TF_CALL_ALL_TYPES(REGISTER_SELECT);
|
TF_CALL_ALL_TYPES(REGISTER_SELECT);
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
// Registration of the GPU implementations.
|
// Registration of the GPU implementations.
|
||||||
#define REGISTER_SELECT_GPU(type) \
|
#define REGISTER_SELECT_GPU(type) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||||
SelectOp<GPUDevice, type>);
|
SelectOp<GPUDevice, type>); \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||||
|
SelectV2Op<GPUDevice, type>);
|
||||||
|
|
||||||
REGISTER_SELECT_GPU(bool);
|
REGISTER_SELECT_GPU(bool);
|
||||||
REGISTER_SELECT_GPU(Eigen::half);
|
REGISTER_SELECT_GPU(Eigen::half);
|
||||||
@ -174,9 +294,12 @@ REGISTER_SELECT_GPU(complex128);
|
|||||||
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
// Registration of the SYCL implementations.
|
// Registration of the SYCL implementations.
|
||||||
#define REGISTER_SELECT_SYCL(type) \
|
#define REGISTER_SELECT_SYCL(type) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||||
|
SelectOp<SYCLDevice, type>); \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("SelectV2").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||||
SelectOp<SYCLDevice, type>);
|
SelectOp<SYCLDevice, type>);
|
||||||
|
|
||||||
REGISTER_SELECT_SYCL(float);
|
REGISTER_SELECT_SYCL(float);
|
||||||
@ -324,10 +447,35 @@ struct BatchSelectFunctor<CPUDevice, T> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename T, int NDIMS>
|
||||||
|
struct BCastSelectFunctorBase {
|
||||||
|
void operator()(const Device& d,
|
||||||
|
typename TTypes<T, NDIMS>::Tensor output_tensor,
|
||||||
|
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
|
||||||
|
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
|
||||||
|
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
||||||
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
||||||
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
||||||
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
|
||||||
|
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
|
||||||
|
.select(then_tensor.broadcast(then_bcast),
|
||||||
|
else_tensor.broadcast(else_bcast));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int NDIMS>
|
||||||
|
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
|
||||||
|
: BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};
|
||||||
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct BatchSelectFunctor<SYCLDevice, T>
|
struct BatchSelectFunctor<SYCLDevice, T>
|
||||||
: BatchSelectFunctorBase<SYCLDevice, T> {};
|
: BatchSelectFunctorBase<SYCLDevice, T> {};
|
||||||
|
|
||||||
|
template <typename T, int NDIMS>
|
||||||
|
struct BCastSelectFunctor<SYCLDevice, T, NDIMS>
|
||||||
|
: BCastSelectFunctorBase<SYCLDevice, T, NDIMS> {};
|
||||||
|
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
@ -1208,6 +1208,18 @@ struct BatchSelectFunctor {
|
|||||||
typename TTypes<T>::ConstMatrix else_flat_outer_dims);
|
typename TTypes<T>::ConstMatrix else_flat_outer_dims);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename T, int NDIMS>
|
||||||
|
struct BCastSelectFunctor {
|
||||||
|
void operator()(const Device& d,
|
||||||
|
typename TTypes<T, NDIMS>::Tensor output_tensor,
|
||||||
|
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
|
||||||
|
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
|
||||||
|
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
||||||
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
||||||
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
||||||
|
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace functor
|
} // end namespace functor
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
@ -828,6 +828,57 @@ REGISTER_OP("Select")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("SelectV2")
|
||||||
|
.Input("condition: bool")
|
||||||
|
.Input("t: T")
|
||||||
|
.Input("e: T")
|
||||||
|
.Output("output: T")
|
||||||
|
.Attr("T: type")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
auto* handle_data_1 = c->input_handle_shapes_and_types(1);
|
||||||
|
auto* handle_data_2 = c->input_handle_shapes_and_types(2);
|
||||||
|
// Merge handle shape and dtype if applicable.
|
||||||
|
if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
|
||||||
|
const auto size = handle_data_1->size();
|
||||||
|
std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
|
||||||
|
if (size != handle_data_2->size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Trying to merge handles pointing to different numbers of "
|
||||||
|
"tensors.");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
|
||||||
|
const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
|
||||||
|
if (s1.dtype != s2.dtype) {
|
||||||
|
// TODO(apassos) resolve this in the manner of b/32476923
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Trying to merge handles pointing to different dtypes.");
|
||||||
|
}
|
||||||
|
merged_handle_data[i].dtype = s1.dtype;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
c->set_output_handle_shapes_and_types(0, merged_handle_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The inputs 'cond', 'then', and 'else' must be broadcastable.
|
||||||
|
// TODO (yongtang): Consolidate 3-ary broadcast instead of
|
||||||
|
// multiple 2-ary broadcast.
|
||||||
|
ShapeHandle cond = c->input(0);
|
||||||
|
ShapeHandle then = c->input(1);
|
||||||
|
ShapeHandle else_ = c->input(2);
|
||||||
|
ShapeHandle other;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other));
|
||||||
|
ShapeHandle output;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output));
|
||||||
|
c->set_output(0, output);
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
REGISTER_OP("MatMul")
|
REGISTER_OP("MatMul")
|
||||||
|
@ -37,10 +37,10 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class WhereOpTest(test.TestCase):
|
class WhereOpTest(test.TestCase):
|
||||||
|
|
||||||
def _testWhere(self, x, truth, expected_err_re=None):
|
def _testWhere(self, x, truth, expected_err_re=None, fn=array_ops.where):
|
||||||
with self.cached_session(use_gpu=True):
|
with self.cached_session(use_gpu=True):
|
||||||
ans = array_ops.where(x)
|
ans = fn(x)
|
||||||
self.assertEqual([None, x.ndim], ans.get_shape().as_list())
|
self.assertTrue(ans.get_shape().is_compatible_with([None, x.ndim]))
|
||||||
if expected_err_re is None:
|
if expected_err_re is None:
|
||||||
tf_ans = self.evaluate(ans)
|
tf_ans = self.evaluate(ans)
|
||||||
self.assertAllClose(tf_ans, truth, atol=1e-10)
|
self.assertAllClose(tf_ans, truth, atol=1e-10)
|
||||||
@ -48,44 +48,40 @@ class WhereOpTest(test.TestCase):
|
|||||||
with self.assertRaisesOpError(expected_err_re):
|
with self.assertRaisesOpError(expected_err_re):
|
||||||
self.evaluate(ans)
|
self.evaluate(ans)
|
||||||
|
|
||||||
def testWrongNumbers(self):
|
def _testWrongNumbers(self, fn=array_ops.where):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
array_ops.where([False, True], [1, 2], None)
|
fn([False, True], [1, 2], None)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
array_ops.where([False, True], None, [1, 2])
|
fn([False, True], None, [1, 2])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
def _testBasicVec(self, fn=array_ops.where):
|
||||||
def testBasicVec(self):
|
|
||||||
x = np.asarray([True, False])
|
x = np.asarray([True, False])
|
||||||
truth = np.asarray([[0]], dtype=np.int64)
|
truth = np.asarray([[0]], dtype=np.int64)
|
||||||
self._testWhere(x, truth)
|
self._testWhere(x, truth, None, fn)
|
||||||
|
|
||||||
x = np.asarray([False, True, False])
|
x = np.asarray([False, True, False])
|
||||||
truth = np.asarray([[1]], dtype=np.int64)
|
truth = np.asarray([[1]], dtype=np.int64)
|
||||||
self._testWhere(x, truth)
|
self._testWhere(x, truth, None, fn)
|
||||||
|
|
||||||
x = np.asarray([False, False, True, False, True])
|
x = np.asarray([False, False, True, False, True])
|
||||||
truth = np.asarray([[2], [4]], dtype=np.int64)
|
truth = np.asarray([[2], [4]], dtype=np.int64)
|
||||||
self._testWhere(x, truth)
|
self._testWhere(x, truth, None, fn)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
def _testRandomVec(self, fn=array_ops.where):
|
||||||
def testRandomVec(self):
|
|
||||||
x = np.random.rand(1000000) > 0.5
|
x = np.random.rand(1000000) > 0.5
|
||||||
truth = np.vstack([np.where(x)[0].astype(np.int64)]).T
|
truth = np.vstack([np.where(x)[0].astype(np.int64)]).T
|
||||||
self._testWhere(x, truth)
|
self._testWhere(x, truth, None, fn)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
def _testBasicMat(self, fn=array_ops.where):
|
||||||
def testBasicMat(self):
|
|
||||||
x = np.asarray([[True, False], [True, False]])
|
x = np.asarray([[True, False], [True, False]])
|
||||||
|
|
||||||
# Ensure RowMajor mode
|
# Ensure RowMajor mode
|
||||||
truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64)
|
truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64)
|
||||||
|
|
||||||
self._testWhere(x, truth)
|
self._testWhere(x, truth, None, fn)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
def _testBasic3Tensor(self, fn=array_ops.where):
|
||||||
def testBasic3Tensor(self):
|
|
||||||
x = np.asarray([[[True, False], [True, False]],
|
x = np.asarray([[[True, False], [True, False]],
|
||||||
[[False, True], [False, True]],
|
[[False, True], [False, True]],
|
||||||
[[False, False], [False, True]]])
|
[[False, False], [False, True]]])
|
||||||
@ -94,15 +90,41 @@ class WhereOpTest(test.TestCase):
|
|||||||
truth = np.asarray(
|
truth = np.asarray(
|
||||||
[[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]], dtype=np.int64)
|
[[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]], dtype=np.int64)
|
||||||
|
|
||||||
self._testWhere(x, truth)
|
self._testWhere(x, truth, None, fn)
|
||||||
|
|
||||||
def _testRandom(self, dtype, expected_err_re=None):
|
def _testRandom(self, dtype, expected_err_re=None, fn=array_ops.where):
|
||||||
shape = [127, 33, 53]
|
shape = [127, 33, 53]
|
||||||
x = np.random.randn(*shape) + 1j * np.random.randn(*shape)
|
x = np.random.randn(*shape) + 1j * np.random.randn(*shape)
|
||||||
x = (np.random.randn(*shape) > 0).astype(dtype)
|
x = (np.random.randn(*shape) > 0).astype(dtype)
|
||||||
truth = np.where(np.abs(x) > 0) # Tuples of indices by axis.
|
truth = np.where(np.abs(x) > 0) # Tuples of indices by axis.
|
||||||
truth = np.vstack(truth).T # Convert to [num_true, indices].
|
truth = np.vstack(truth).T # Convert to [num_true, indices].
|
||||||
self._testWhere(x, truth, expected_err_re)
|
self._testWhere(x, truth, expected_err_re, fn)
|
||||||
|
|
||||||
|
def _testThreeArgument(self, fn=array_ops.where):
|
||||||
|
x = np.array([[-2, 3, -1], [1, -3, -3]])
|
||||||
|
np_val = np.where(x > 0, x * x, -x)
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
tf_val = self.evaluate(fn(constant_op.constant(x) > 0, x * x, -x))
|
||||||
|
self.assertAllEqual(tf_val, np_val)
|
||||||
|
|
||||||
|
def testWrongNumbers(self):
|
||||||
|
self._testWrongNumbers()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testBasicVec(self):
|
||||||
|
self._testBasicVec()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testRandomVec(self):
|
||||||
|
self._testRandomVec()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testBasicMat(self):
|
||||||
|
self._testBasicMat()
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testBasic3Tensor(self):
|
||||||
|
self._testBasic3Tensor()
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testRandomBool(self):
|
def testRandomBool(self):
|
||||||
@ -146,12 +168,95 @@ class WhereOpTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testThreeArgument(self):
|
def testThreeArgument(self):
|
||||||
x = np.array([[-2, 3, -1], [1, -3, -3]])
|
self._testThreeArgument()
|
||||||
np_val = np.where(x > 0, x * x, -x)
|
|
||||||
with self.session(use_gpu=True):
|
def testV2WrongNumbers(self):
|
||||||
tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval()
|
self._testWrongNumbers(array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2BasicVec(self):
|
||||||
|
self._testBasicVec(array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomVec(self):
|
||||||
|
self._testRandomVec(array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2BasicMat(self):
|
||||||
|
self._testBasicMat(array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2Basic3Tensor(self):
|
||||||
|
self._testBasic3Tensor(array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomBool(self):
|
||||||
|
self._testRandom(np.bool, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomInt32(self):
|
||||||
|
self._testRandom(np.int32, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomInt64(self):
|
||||||
|
self._testRandom(np.int64, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomFloat(self):
|
||||||
|
self._testRandom(np.float32, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomDouble(self):
|
||||||
|
self._testRandom(np.float64, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomComplex64(self):
|
||||||
|
self._testRandom(np.complex64, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomComplex128(self):
|
||||||
|
self._testRandom(np.complex128, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomUint8(self):
|
||||||
|
self._testRandom(np.uint8, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomInt8(self):
|
||||||
|
self._testRandom(np.int8, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2RandomInt16(self):
|
||||||
|
self._testRandom(np.int16, None, array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2ThreeArgument(self):
|
||||||
|
self._testThreeArgument(array_ops.where_v2)
|
||||||
|
|
||||||
|
def testV2Broadcasting(self):
|
||||||
|
f = np.random.normal(0, 1, (3, 5, 1, 1))
|
||||||
|
x = np.zeros((7, 11))
|
||||||
|
y = np.ones((7, 11))
|
||||||
|
np_val = np.where(f < 0, x, y)
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
tf_val = self.evaluate(
|
||||||
|
array_ops.where_v2(constant_op.constant(f) < 0, x, y))
|
||||||
self.assertAllEqual(tf_val, np_val)
|
self.assertAllEqual(tf_val, np_val)
|
||||||
|
|
||||||
|
def testV2ScalarBroadcasting(self):
|
||||||
|
x = np.zeros((7, 11))
|
||||||
|
y = np.ones((7, 11))
|
||||||
|
np_val = np.where(True, x, y)
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
tf_val = self.evaluate(
|
||||||
|
array_ops.where_v2(
|
||||||
|
constant_op.constant(True, dtype=dtypes.bool), x, y))
|
||||||
|
self.assertAllEqual(tf_val, np_val)
|
||||||
|
|
||||||
|
def testV2VectorBroadcasting(self):
|
||||||
|
x = np.zeros(7)
|
||||||
|
y = np.ones(7)
|
||||||
|
np_val = np.where([True], x, y)
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
tf_val = self.evaluate(
|
||||||
|
array_ops.where_v2(
|
||||||
|
constant_op.constant([True], dtype=dtypes.bool), x, y))
|
||||||
|
self.assertAllEqual(tf_val, np_val)
|
||||||
|
|
||||||
|
def testV2PredBroadcasting(self):
|
||||||
|
pred = np.array([1, 0, 0]).reshape((3, 1))
|
||||||
|
x = np.random.randn(3, 4)
|
||||||
|
y = np.random.randn(3, 4)
|
||||||
|
np_val = np.where(pred, x, y)
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
tf_val = self.evaluate(array_ops.where_v2(pred, x, y))
|
||||||
|
self.assertAllClose(tf_val, np_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testBatchSelect(self):
|
def testBatchSelect(self):
|
||||||
x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192) # [16384, 192]
|
x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192) # [16384, 192]
|
||||||
|
@ -3163,7 +3163,11 @@ def squeeze_v2(input, axis=None, name=None):
|
|||||||
return squeeze(input, axis, name)
|
return squeeze(input, axis, name)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("where")
|
@tf_export(v1=["where"])
|
||||||
|
@deprecation.deprecated(
|
||||||
|
date=None,
|
||||||
|
instructions="Use tf.where in 2.0, "
|
||||||
|
"which has the same broadcast rule as np.where")
|
||||||
@dispatch.add_dispatch_support
|
@dispatch.add_dispatch_support
|
||||||
def where(condition, x=None, y=None, name=None):
|
def where(condition, x=None, y=None, name=None):
|
||||||
"""Return the elements, either from `x` or `y`, depending on the `condition`.
|
"""Return the elements, either from `x` or `y`, depending on the `condition`.
|
||||||
@ -3217,6 +3221,48 @@ def where(condition, x=None, y=None, name=None):
|
|||||||
raise ValueError("x and y must both be non-None or both be None.")
|
raise ValueError("x and y must both be non-None or both be None.")
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("where", v1=["where_v2"])
|
||||||
|
def where_v2(condition, x=None, y=None, name=None):
|
||||||
|
"""Return the elements, either from `x` or `y`, depending on the `condition`.
|
||||||
|
|
||||||
|
If both `x` and `y` are None, then this operation returns the coordinates of
|
||||||
|
true elements of `condition`. The coordinates are returned in a 2-D tensor
|
||||||
|
where the first dimension (rows) represents the number of true elements, and
|
||||||
|
the second dimension (columns) represents the coordinates of the true
|
||||||
|
elements. Keep in mind, the shape of the output tensor can vary depending on
|
||||||
|
how many true values there are in input. Indices are output in row-major
|
||||||
|
order.
|
||||||
|
If both non-None, `condition`, `x` and `y` must be broadcastable to the same
|
||||||
|
shape.
|
||||||
|
The `condition` tensor acts as a mask that chooses, based on the value at each
|
||||||
|
element, whether the corresponding element / row in the output should be taken
|
||||||
|
from `x` (if true) or `y` (if false).
|
||||||
|
Args:
|
||||||
|
condition: A `Tensor` of type `bool`
|
||||||
|
x: A Tensor which is of the same type as `y`, and may be broadcastable with
|
||||||
|
`condition` and `y`.
|
||||||
|
y: A Tensor which is of the same type as `x`, and may be broadcastable with
|
||||||
|
`condition` and `x`.
|
||||||
|
name: A name of the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` with the same type as `x` and `y`, and shape that
|
||||||
|
is broadcasted from `condition`, `x`, and `y`, if `x`, `y` are non-None.
|
||||||
|
A `Tensor` with shape `(num_true, dim_size(condition))`.
|
||||||
|
Raises:
|
||||||
|
ValueError: When exactly one of `x` or `y` is non-None.
|
||||||
|
"""
|
||||||
|
if x is None and y is None:
|
||||||
|
with ops.name_scope(name, "Where", [condition]) as name:
|
||||||
|
condition = ops.convert_to_tensor(
|
||||||
|
condition, preferred_dtype=dtypes.bool, name="condition")
|
||||||
|
return gen_array_ops.where(condition=condition, name=name)
|
||||||
|
elif x is not None and y is not None:
|
||||||
|
return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name)
|
||||||
|
else:
|
||||||
|
raise ValueError("x and y must both be non-None or both be None.")
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=redefined-builtin
|
# pylint: disable=redefined-builtin
|
||||||
@tf_export(v1=["reverse_sequence"])
|
@tf_export(v1=["reverse_sequence"])
|
||||||
@deprecation.deprecated_args(None,
|
@deprecation.deprecated_args(None,
|
||||||
|
@ -2424,6 +2424,10 @@ tf_module {
|
|||||||
name: "where"
|
name: "where"
|
||||||
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "where_v2"
|
||||||
|
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "while_loop"
|
name: "while_loop"
|
||||||
argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], "
|
argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], "
|
||||||
|
@ -3348,6 +3348,10 @@ tf_module {
|
|||||||
name: "Select"
|
name: "Select"
|
||||||
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "SelectV2"
|
||||||
|
argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "SelfAdjointEig"
|
name: "SelfAdjointEig"
|
||||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -3348,6 +3348,10 @@ tf_module {
|
|||||||
name: "Select"
|
name: "Select"
|
||||||
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "SelectV2"
|
||||||
|
argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "SelfAdjointEig"
|
name: "SelfAdjointEig"
|
||||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -1543,6 +1543,10 @@ renames = {
|
|||||||
'tf.compat.v1.variables_initializer',
|
'tf.compat.v1.variables_initializer',
|
||||||
'tf.verify_tensor_all_finite':
|
'tf.verify_tensor_all_finite':
|
||||||
'tf.compat.v1.verify_tensor_all_finite',
|
'tf.compat.v1.verify_tensor_all_finite',
|
||||||
|
'tf.where':
|
||||||
|
'tf.compat.v1.where',
|
||||||
|
'tf.where_v2':
|
||||||
|
'tf.compat.v2.where',
|
||||||
'tf.wrap_function':
|
'tf.wrap_function':
|
||||||
'tf.compat.v1.wrap_function',
|
'tf.compat.v1.wrap_function',
|
||||||
'tf.write_file':
|
'tf.write_file':
|
||||||
|
Loading…
Reference in New Issue
Block a user