Automated rollback of commit 60524c4167
. Revert #15982.
PiperOrigin-RevId: 247032095
This commit is contained in:
parent
4caf8b10bb
commit
14ee9008e5
@ -124,7 +124,7 @@ from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d
|
||||
from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = ['nest', 'broadcast_to', 'where']
|
||||
_allowed_symbols = ['nest']
|
||||
_nest_allowed_symbols = [
|
||||
'assert_same_structure',
|
||||
'is_nested',
|
||||
|
@ -1,3 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "SelectV2"
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "SelectV2"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -23,22 +23,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
template <typename T, int NDIMS>
|
||||
struct BCastSelectFunctor<GPUDevice, T, NDIMS> {
|
||||
void operator()(const GPUDevice& d,
|
||||
typename TTypes<T, NDIMS>::Tensor output_tensor,
|
||||
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
|
||||
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
|
||||
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
|
||||
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
|
||||
.select(then_tensor.broadcast(then_bcast),
|
||||
else_tensor.broadcast(else_bcast));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SelectFunctor<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
|
||||
@ -105,15 +89,10 @@ struct BatchSelectFunctor<GPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
#define SELECT_FUNCTOR(T) \
|
||||
template struct SelectFunctor<GPUDevice, T>; \
|
||||
template struct SelectScalarFunctor<GPUDevice, T>; \
|
||||
template struct BatchSelectFunctor<GPUDevice, T>; \
|
||||
template struct BCastSelectFunctor<GPUDevice, T, 1>; \
|
||||
template struct BCastSelectFunctor<GPUDevice, T, 2>; \
|
||||
template struct BCastSelectFunctor<GPUDevice, T, 3>; \
|
||||
template struct BCastSelectFunctor<GPUDevice, T, 4>; \
|
||||
template struct BCastSelectFunctor<GPUDevice, T, 5>;
|
||||
#define SELECT_FUNCTOR(T) \
|
||||
template struct SelectFunctor<GPUDevice, T>; \
|
||||
template struct SelectScalarFunctor<GPUDevice, T>; \
|
||||
template struct BatchSelectFunctor<GPUDevice, T>;
|
||||
|
||||
SELECT_FUNCTOR(bool);
|
||||
SELECT_FUNCTOR(Eigen::half);
|
||||
|
@ -143,141 +143,21 @@ class SelectOp : public OpKernel {
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
|
||||
};
|
||||
template <typename Device, typename T>
|
||||
class SelectV2Op : public OpKernel {
|
||||
public:
|
||||
explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor* cond;
|
||||
const Tensor* then;
|
||||
const Tensor* else_;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
|
||||
OP_REQUIRES_OK(ctx, ctx->input("t", &then));
|
||||
OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
|
||||
|
||||
// The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()),
|
||||
// This matches the behavior of numpy.
|
||||
// TODO (yongtang): Consolidate into n-ary broadcast, instead of multiple
|
||||
// 2-ary broadcast.
|
||||
|
||||
// Combine `then` and `else`.
|
||||
BCast then_else_bcast(BCast::FromShape(then->shape()),
|
||||
BCast::FromShape(else_->shape()), false);
|
||||
OP_REQUIRES(ctx, then_else_bcast.IsValid(),
|
||||
errors::InvalidArgument(
|
||||
"then ", then->shape().DebugString(), " and else ",
|
||||
else_->shape().DebugString(), " must be broadcastable"));
|
||||
// Combine `cond` with `then` and `else`.
|
||||
BCast bcast(
|
||||
BCast::FromShape(cond->shape()),
|
||||
BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())),
|
||||
false);
|
||||
OP_REQUIRES(ctx, bcast.IsValid(),
|
||||
errors::InvalidArgument(
|
||||
"condition ", cond->shape().DebugString(), ", then ",
|
||||
then->shape().DebugString(), ", and else ",
|
||||
else_->shape().DebugString(), " must be broadcastable"));
|
||||
|
||||
// Broadcast `cond`, `then` and `else` to combined shape,
|
||||
// in order to obtain the reshape.
|
||||
BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
|
||||
BCast::FromShape(cond->shape()), false);
|
||||
BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
|
||||
BCast::FromShape(then->shape()), false);
|
||||
BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
|
||||
BCast::FromShape(else_->shape()), false);
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(),
|
||||
errors::InvalidArgument("condition ", cond->shape().DebugString(),
|
||||
", then ", then->shape().DebugString(),
|
||||
", and else ", else_->shape().DebugString(),
|
||||
" must be broadcastable"));
|
||||
|
||||
// Combined shape should be the final shape.
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
cond_bcast.output_shape() == bcast.output_shape() &&
|
||||
then_bcast.output_shape() == bcast.output_shape() &&
|
||||
else_bcast.output_shape() == bcast.output_shape(),
|
||||
errors::InvalidArgument("condition ", cond->shape().DebugString(),
|
||||
", then ", then->shape().DebugString(),
|
||||
", and else ", else_->shape().DebugString(),
|
||||
" must be broadcastable to the same shape"));
|
||||
|
||||
Tensor* output = nullptr;
|
||||
const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
|
||||
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
|
||||
{"t", "e"}, "output", output_shape, &output));
|
||||
|
||||
if (output->NumElements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
#define HANDLE_DIM(NDIMS) \
|
||||
{ \
|
||||
functor::BCastSelectFunctor<Device, T, NDIMS> func; \
|
||||
func(ctx->eigen_device<Device>(), \
|
||||
output->shaped<T, NDIMS>(bcast.result_shape()), \
|
||||
cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
|
||||
then->template shaped<T, NDIMS>(then_bcast.y_reshape()), \
|
||||
else_->template shaped<T, NDIMS>(else_bcast.y_reshape()), \
|
||||
BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()), \
|
||||
BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()), \
|
||||
BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast())); \
|
||||
}
|
||||
|
||||
const int ndims = static_cast<int>(bcast.result_shape().size());
|
||||
switch (ndims) {
|
||||
case 1:
|
||||
HANDLE_DIM(1);
|
||||
break;
|
||||
case 2:
|
||||
HANDLE_DIM(2);
|
||||
break;
|
||||
case 3:
|
||||
HANDLE_DIM(3);
|
||||
break;
|
||||
case 4:
|
||||
HANDLE_DIM(4);
|
||||
break;
|
||||
case 5:
|
||||
HANDLE_DIM(5);
|
||||
break;
|
||||
default:
|
||||
ctx->SetStatus(errors::Unimplemented(
|
||||
"Broadcast between ", ctx->input(0).shape().DebugString(), " and ",
|
||||
ctx->input(1).shape().DebugString(), " is not supported yet."));
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
|
||||
};
|
||||
|
||||
#define REGISTER_SELECT(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
SelectOp<CPUDevice, type>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
SelectV2Op<CPUDevice, type>);
|
||||
#define REGISTER_SELECT(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
SelectOp<CPUDevice, type>);
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER_SELECT);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
// Registration of the GPU implementations.
|
||||
#define REGISTER_SELECT_GPU(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
SelectOp<GPUDevice, type>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
SelectV2Op<GPUDevice, type>);
|
||||
#define REGISTER_SELECT_GPU(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
SelectOp<GPUDevice, type>);
|
||||
|
||||
REGISTER_SELECT_GPU(bool);
|
||||
REGISTER_SELECT_GPU(Eigen::half);
|
||||
@ -294,12 +174,9 @@ REGISTER_SELECT_GPU(complex128);
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
// Registration of the SYCL implementations.
|
||||
#define REGISTER_SELECT_SYCL(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||
SelectOp<SYCLDevice, type>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("SelectV2").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||
#define REGISTER_SELECT_SYCL(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||
SelectOp<SYCLDevice, type>);
|
||||
|
||||
REGISTER_SELECT_SYCL(float);
|
||||
@ -447,35 +324,10 @@ struct BatchSelectFunctor<CPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T, int NDIMS>
|
||||
struct BCastSelectFunctorBase {
|
||||
void operator()(const Device& d,
|
||||
typename TTypes<T, NDIMS>::Tensor output_tensor,
|
||||
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
|
||||
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
|
||||
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
|
||||
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
|
||||
.select(then_tensor.broadcast(then_bcast),
|
||||
else_tensor.broadcast(else_bcast));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int NDIMS>
|
||||
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
|
||||
: BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
template <typename T>
|
||||
struct BatchSelectFunctor<SYCLDevice, T>
|
||||
: BatchSelectFunctorBase<SYCLDevice, T> {};
|
||||
|
||||
template <typename T, int NDIMS>
|
||||
struct BCastSelectFunctor<SYCLDevice, T, NDIMS>
|
||||
: BCastSelectFunctorBase<SYCLDevice, T, NDIMS> {};
|
||||
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // namespace functor
|
||||
|
@ -1208,18 +1208,6 @@ struct BatchSelectFunctor {
|
||||
typename TTypes<T>::ConstMatrix else_flat_outer_dims);
|
||||
};
|
||||
|
||||
template <typename Device, typename T, int NDIMS>
|
||||
struct BCastSelectFunctor {
|
||||
void operator()(const Device& d,
|
||||
typename TTypes<T, NDIMS>::Tensor output_tensor,
|
||||
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
|
||||
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
|
||||
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
|
||||
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
|
||||
};
|
||||
|
||||
} // end namespace functor
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -828,57 +828,6 @@ REGISTER_OP("Select")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("SelectV2")
|
||||
.Input("condition: bool")
|
||||
.Input("t: T")
|
||||
.Input("e: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
auto* handle_data_1 = c->input_handle_shapes_and_types(1);
|
||||
auto* handle_data_2 = c->input_handle_shapes_and_types(2);
|
||||
// Merge handle shape and dtype if applicable.
|
||||
if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
|
||||
const auto size = handle_data_1->size();
|
||||
std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
|
||||
if (size != handle_data_2->size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Trying to merge handles pointing to different numbers of "
|
||||
"tensors.");
|
||||
}
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
|
||||
const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
|
||||
if (s1.dtype != s2.dtype) {
|
||||
// TODO(apassos) resolve this in the manner of b/32476923
|
||||
return errors::InvalidArgument(
|
||||
"Trying to merge handles pointing to different dtypes.");
|
||||
}
|
||||
merged_handle_data[i].dtype = s1.dtype;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
|
||||
}
|
||||
|
||||
c->set_output_handle_shapes_and_types(0, merged_handle_data);
|
||||
}
|
||||
|
||||
// The inputs 'cond', 'then', and 'else' must be broadcastable.
|
||||
// TODO (yongtang): Consolidate 3-ary broadcast instead of
|
||||
// multiple 2-ary broadcast.
|
||||
ShapeHandle cond = c->input(0);
|
||||
ShapeHandle then = c->input(1);
|
||||
ShapeHandle else_ = c->input(2);
|
||||
ShapeHandle other;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other));
|
||||
ShapeHandle output;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output));
|
||||
c->set_output(0, output);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
REGISTER_OP("MatMul")
|
||||
|
@ -37,10 +37,10 @@ from tensorflow.python.platform import test
|
||||
|
||||
class WhereOpTest(test.TestCase):
|
||||
|
||||
def _testWhere(self, x, truth, expected_err_re=None, fn=array_ops.where):
|
||||
def _testWhere(self, x, truth, expected_err_re=None):
|
||||
with self.cached_session(use_gpu=True):
|
||||
ans = fn(x)
|
||||
self.assertTrue(ans.get_shape().is_compatible_with([None, x.ndim]))
|
||||
ans = array_ops.where(x)
|
||||
self.assertEqual([None, x.ndim], ans.get_shape().as_list())
|
||||
if expected_err_re is None:
|
||||
tf_ans = self.evaluate(ans)
|
||||
self.assertAllClose(tf_ans, truth, atol=1e-10)
|
||||
@ -48,40 +48,44 @@ class WhereOpTest(test.TestCase):
|
||||
with self.assertRaisesOpError(expected_err_re):
|
||||
self.evaluate(ans)
|
||||
|
||||
def _testWrongNumbers(self, fn=array_ops.where):
|
||||
def testWrongNumbers(self):
|
||||
with self.session(use_gpu=True):
|
||||
with self.assertRaises(ValueError):
|
||||
fn([False, True], [1, 2], None)
|
||||
array_ops.where([False, True], [1, 2], None)
|
||||
with self.assertRaises(ValueError):
|
||||
fn([False, True], None, [1, 2])
|
||||
array_ops.where([False, True], None, [1, 2])
|
||||
|
||||
def _testBasicVec(self, fn=array_ops.where):
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicVec(self):
|
||||
x = np.asarray([True, False])
|
||||
truth = np.asarray([[0]], dtype=np.int64)
|
||||
self._testWhere(x, truth, None, fn)
|
||||
self._testWhere(x, truth)
|
||||
|
||||
x = np.asarray([False, True, False])
|
||||
truth = np.asarray([[1]], dtype=np.int64)
|
||||
self._testWhere(x, truth, None, fn)
|
||||
self._testWhere(x, truth)
|
||||
|
||||
x = np.asarray([False, False, True, False, True])
|
||||
truth = np.asarray([[2], [4]], dtype=np.int64)
|
||||
self._testWhere(x, truth, None, fn)
|
||||
self._testWhere(x, truth)
|
||||
|
||||
def _testRandomVec(self, fn=array_ops.where):
|
||||
@test_util.run_deprecated_v1
|
||||
def testRandomVec(self):
|
||||
x = np.random.rand(1000000) > 0.5
|
||||
truth = np.vstack([np.where(x)[0].astype(np.int64)]).T
|
||||
self._testWhere(x, truth, None, fn)
|
||||
self._testWhere(x, truth)
|
||||
|
||||
def _testBasicMat(self, fn=array_ops.where):
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicMat(self):
|
||||
x = np.asarray([[True, False], [True, False]])
|
||||
|
||||
# Ensure RowMajor mode
|
||||
truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64)
|
||||
|
||||
self._testWhere(x, truth, None, fn)
|
||||
self._testWhere(x, truth)
|
||||
|
||||
def _testBasic3Tensor(self, fn=array_ops.where):
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasic3Tensor(self):
|
||||
x = np.asarray([[[True, False], [True, False]],
|
||||
[[False, True], [False, True]],
|
||||
[[False, False], [False, True]]])
|
||||
@ -90,41 +94,15 @@ class WhereOpTest(test.TestCase):
|
||||
truth = np.asarray(
|
||||
[[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]], dtype=np.int64)
|
||||
|
||||
self._testWhere(x, truth, None, fn)
|
||||
self._testWhere(x, truth)
|
||||
|
||||
def _testRandom(self, dtype, expected_err_re=None, fn=array_ops.where):
|
||||
def _testRandom(self, dtype, expected_err_re=None):
|
||||
shape = [127, 33, 53]
|
||||
x = np.random.randn(*shape) + 1j * np.random.randn(*shape)
|
||||
x = (np.random.randn(*shape) > 0).astype(dtype)
|
||||
truth = np.where(np.abs(x) > 0) # Tuples of indices by axis.
|
||||
truth = np.vstack(truth).T # Convert to [num_true, indices].
|
||||
self._testWhere(x, truth, expected_err_re, fn)
|
||||
|
||||
def _testThreeArgument(self, fn=array_ops.where):
|
||||
x = np.array([[-2, 3, -1], [1, -3, -3]])
|
||||
np_val = np.where(x > 0, x * x, -x)
|
||||
with self.test_session(use_gpu=True):
|
||||
tf_val = self.evaluate(fn(constant_op.constant(x) > 0, x * x, -x))
|
||||
self.assertAllEqual(tf_val, np_val)
|
||||
|
||||
def testWrongNumbers(self):
|
||||
self._testWrongNumbers()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicVec(self):
|
||||
self._testBasicVec()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRandomVec(self):
|
||||
self._testRandomVec()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicMat(self):
|
||||
self._testBasicMat()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasic3Tensor(self):
|
||||
self._testBasic3Tensor()
|
||||
self._testWhere(x, truth, expected_err_re)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRandomBool(self):
|
||||
@ -168,95 +146,12 @@ class WhereOpTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testThreeArgument(self):
|
||||
self._testThreeArgument()
|
||||
|
||||
def testV2WrongNumbers(self):
|
||||
self._testWrongNumbers(array_ops.where_v2)
|
||||
|
||||
def testV2BasicVec(self):
|
||||
self._testBasicVec(array_ops.where_v2)
|
||||
|
||||
def testV2RandomVec(self):
|
||||
self._testRandomVec(array_ops.where_v2)
|
||||
|
||||
def testV2BasicMat(self):
|
||||
self._testBasicMat(array_ops.where_v2)
|
||||
|
||||
def testV2Basic3Tensor(self):
|
||||
self._testBasic3Tensor(array_ops.where_v2)
|
||||
|
||||
def testV2RandomBool(self):
|
||||
self._testRandom(np.bool, None, array_ops.where_v2)
|
||||
|
||||
def testV2RandomInt32(self):
|
||||
self._testRandom(np.int32, None, array_ops.where_v2)
|
||||
|
||||
def testV2RandomInt64(self):
|
||||
self._testRandom(np.int64, None, array_ops.where_v2)
|
||||
|
||||
def testV2RandomFloat(self):
|
||||
self._testRandom(np.float32, None, array_ops.where_v2)
|
||||
|
||||
def testV2RandomDouble(self):
|
||||
self._testRandom(np.float64, None, array_ops.where_v2)
|
||||
|
||||
def testV2RandomComplex64(self):
|
||||
self._testRandom(np.complex64, None, array_ops.where_v2)
|
||||
|
||||
def testV2RandomComplex128(self):
|
||||
self._testRandom(np.complex128, None, array_ops.where_v2)
|
||||
|
||||
def testV2RandomUint8(self):
|
||||
self._testRandom(np.uint8, None, array_ops.where_v2)
|
||||
|
||||
def testV2RandomInt8(self):
|
||||
self._testRandom(np.int8, None, array_ops.where_v2)
|
||||
|
||||
def testV2RandomInt16(self):
|
||||
self._testRandom(np.int16, None, array_ops.where_v2)
|
||||
|
||||
def testV2ThreeArgument(self):
|
||||
self._testThreeArgument(array_ops.where_v2)
|
||||
|
||||
def testV2Broadcasting(self):
|
||||
f = np.random.normal(0, 1, (3, 5, 1, 1))
|
||||
x = np.zeros((7, 11))
|
||||
y = np.ones((7, 11))
|
||||
np_val = np.where(f < 0, x, y)
|
||||
with self.test_session(use_gpu=True):
|
||||
tf_val = self.evaluate(
|
||||
array_ops.where_v2(constant_op.constant(f) < 0, x, y))
|
||||
x = np.array([[-2, 3, -1], [1, -3, -3]])
|
||||
np_val = np.where(x > 0, x * x, -x)
|
||||
with self.session(use_gpu=True):
|
||||
tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval()
|
||||
self.assertAllEqual(tf_val, np_val)
|
||||
|
||||
def testV2ScalarBroadcasting(self):
|
||||
x = np.zeros((7, 11))
|
||||
y = np.ones((7, 11))
|
||||
np_val = np.where(True, x, y)
|
||||
with self.test_session(use_gpu=True):
|
||||
tf_val = self.evaluate(
|
||||
array_ops.where_v2(
|
||||
constant_op.constant(True, dtype=dtypes.bool), x, y))
|
||||
self.assertAllEqual(tf_val, np_val)
|
||||
|
||||
def testV2VectorBroadcasting(self):
|
||||
x = np.zeros(7)
|
||||
y = np.ones(7)
|
||||
np_val = np.where([True], x, y)
|
||||
with self.test_session(use_gpu=True):
|
||||
tf_val = self.evaluate(
|
||||
array_ops.where_v2(
|
||||
constant_op.constant([True], dtype=dtypes.bool), x, y))
|
||||
self.assertAllEqual(tf_val, np_val)
|
||||
|
||||
def testV2PredBroadcasting(self):
|
||||
pred = np.array([1, 0, 0]).reshape((3, 1))
|
||||
x = np.random.randn(3, 4)
|
||||
y = np.random.randn(3, 4)
|
||||
np_val = np.where(pred, x, y)
|
||||
with self.test_session(use_gpu=True):
|
||||
tf_val = self.evaluate(array_ops.where_v2(pred, x, y))
|
||||
self.assertAllClose(tf_val, np_val)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchSelect(self):
|
||||
x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192) # [16384, 192]
|
||||
|
@ -3163,11 +3163,7 @@ def squeeze_v2(input, axis=None, name=None):
|
||||
return squeeze(input, axis, name)
|
||||
|
||||
|
||||
@tf_export(v1=["where"])
|
||||
@deprecation.deprecated(
|
||||
date=None,
|
||||
instructions="Use tf.where in 2.0, "
|
||||
"which has the same broadcast rule as np.where")
|
||||
@tf_export("where")
|
||||
@dispatch.add_dispatch_support
|
||||
def where(condition, x=None, y=None, name=None):
|
||||
"""Return the elements, either from `x` or `y`, depending on the `condition`.
|
||||
@ -3221,48 +3217,6 @@ def where(condition, x=None, y=None, name=None):
|
||||
raise ValueError("x and y must both be non-None or both be None.")
|
||||
|
||||
|
||||
@tf_export("where", v1=["where_v2"])
|
||||
def where_v2(condition, x=None, y=None, name=None):
|
||||
"""Return the elements, either from `x` or `y`, depending on the `condition`.
|
||||
|
||||
If both `x` and `y` are None, then this operation returns the coordinates of
|
||||
true elements of `condition`. The coordinates are returned in a 2-D tensor
|
||||
where the first dimension (rows) represents the number of true elements, and
|
||||
the second dimension (columns) represents the coordinates of the true
|
||||
elements. Keep in mind, the shape of the output tensor can vary depending on
|
||||
how many true values there are in input. Indices are output in row-major
|
||||
order.
|
||||
If both non-None, `condition`, `x` and `y` must be broadcastable to the same
|
||||
shape.
|
||||
The `condition` tensor acts as a mask that chooses, based on the value at each
|
||||
element, whether the corresponding element / row in the output should be taken
|
||||
from `x` (if true) or `y` (if false).
|
||||
Args:
|
||||
condition: A `Tensor` of type `bool`
|
||||
x: A Tensor which is of the same type as `y`, and may be broadcastable with
|
||||
`condition` and `y`.
|
||||
y: A Tensor which is of the same type as `x`, and may be broadcastable with
|
||||
`condition` and `x`.
|
||||
name: A name of the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` with the same type as `x` and `y`, and shape that
|
||||
is broadcasted from `condition`, `x`, and `y`, if `x`, `y` are non-None.
|
||||
A `Tensor` with shape `(num_true, dim_size(condition))`.
|
||||
Raises:
|
||||
ValueError: When exactly one of `x` or `y` is non-None.
|
||||
"""
|
||||
if x is None and y is None:
|
||||
with ops.name_scope(name, "Where", [condition]) as name:
|
||||
condition = ops.convert_to_tensor(
|
||||
condition, preferred_dtype=dtypes.bool, name="condition")
|
||||
return gen_array_ops.where(condition=condition, name=name)
|
||||
elif x is not None and y is not None:
|
||||
return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name)
|
||||
else:
|
||||
raise ValueError("x and y must both be non-None or both be None.")
|
||||
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
@tf_export(v1=["reverse_sequence"])
|
||||
@deprecation.deprecated_args(None,
|
||||
|
@ -2424,10 +2424,6 @@ tf_module {
|
||||
name: "where"
|
||||
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "where_v2"
|
||||
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "while_loop"
|
||||
argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], "
|
||||
|
@ -3348,10 +3348,6 @@ tf_module {
|
||||
name: "Select"
|
||||
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SelectV2"
|
||||
argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SelfAdjointEig"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -3348,10 +3348,6 @@ tf_module {
|
||||
name: "Select"
|
||||
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SelectV2"
|
||||
argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SelfAdjointEig"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -1543,10 +1543,6 @@ renames = {
|
||||
'tf.compat.v1.variables_initializer',
|
||||
'tf.verify_tensor_all_finite':
|
||||
'tf.compat.v1.verify_tensor_all_finite',
|
||||
'tf.where':
|
||||
'tf.compat.v1.where',
|
||||
'tf.where_v2':
|
||||
'tf.compat.v2.where',
|
||||
'tf.wrap_function':
|
||||
'tf.compat.v1.wrap_function',
|
||||
'tf.write_file':
|
||||
|
Loading…
Reference in New Issue
Block a user