Merge pull request #28616 from yongtang:15982-where-v2

PiperOrigin-RevId: 248603583
commit 01207a94c6
Author: TensorFlower Gardener
Date:   2019-05-16 14:50:26 -07:00

14 changed files with 661 additions and 71 deletions


@@ -0,0 +1,3 @@
op {
graph_op_name: "SelectV2"
}


@@ -0,0 +1,4 @@
op {
graph_op_name: "SelectV2"
visibility: HIDDEN
}


@@ -23,6 +23,22 @@ limitations under the License.
namespace tensorflow {
namespace functor {
template <typename T, int NDIMS>
struct BCastSelectFunctor<GPUDevice, T, NDIMS> {
void operator()(const GPUDevice& d,
typename TTypes<T, NDIMS>::Tensor output_tensor,
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
.select(then_tensor.broadcast(then_bcast),
else_tensor.broadcast(else_bcast));
}
};
template <typename T>
struct SelectFunctor<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
@@ -89,10 +105,15 @@ struct BatchSelectFunctor<GPUDevice, T> {
}
};
#define SELECT_FUNCTOR(T) \
template struct SelectFunctor<GPUDevice, T>; \
template struct SelectScalarFunctor<GPUDevice, T>; \
template struct BatchSelectFunctor<GPUDevice, T>; \
template struct BCastSelectFunctor<GPUDevice, T, 1>; \
template struct BCastSelectFunctor<GPUDevice, T, 2>; \
template struct BCastSelectFunctor<GPUDevice, T, 3>; \
template struct BCastSelectFunctor<GPUDevice, T, 4>; \
template struct BCastSelectFunctor<GPUDevice, T, 5>;
SELECT_FUNCTOR(bool);
SELECT_FUNCTOR(Eigen::half);


@@ -143,21 +143,141 @@ class SelectOp : public OpKernel {
private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
};
template <typename Device, typename T>
class SelectV2Op : public OpKernel {
public:
explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
const Tensor* cond;
const Tensor* then;
const Tensor* else_;
OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
OP_REQUIRES_OK(ctx, ctx->input("t", &then));
OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
// The `cond`, `then`, and `else` inputs must be broadcastable (bcast.IsValid());
// this matches the behavior of numpy.
// TODO (yongtang): Consolidate into n-ary broadcast, instead of multiple
// 2-ary broadcast.
// Combine `then` and `else`.
BCast then_else_bcast(BCast::FromShape(then->shape()),
BCast::FromShape(else_->shape()), false);
OP_REQUIRES(ctx, then_else_bcast.IsValid(),
errors::InvalidArgument(
"then ", then->shape().DebugString(), " and else ",
else_->shape().DebugString(), " must be broadcastable"));
// Combine `cond` with `then` and `else`.
BCast bcast(
BCast::FromShape(cond->shape()),
BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())),
false);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
"condition ", cond->shape().DebugString(), ", then ",
then->shape().DebugString(), ", and else ",
else_->shape().DebugString(), " must be broadcastable"));
// Broadcast `cond`, `then` and `else` to combined shape,
// in order to obtain the reshape.
BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
BCast::FromShape(cond->shape()), false);
BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
BCast::FromShape(then->shape()), false);
BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())),
BCast::FromShape(else_->shape()), false);
OP_REQUIRES(
ctx,
cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(),
errors::InvalidArgument("condition ", cond->shape().DebugString(),
", then ", then->shape().DebugString(),
", and else ", else_->shape().DebugString(),
" must be broadcastable"));
// Combined shape should be the final shape.
OP_REQUIRES(
ctx,
cond_bcast.output_shape() == bcast.output_shape() &&
then_bcast.output_shape() == bcast.output_shape() &&
else_bcast.output_shape() == bcast.output_shape(),
errors::InvalidArgument("condition ", cond->shape().DebugString(),
", then ", then->shape().DebugString(),
", and else ", else_->shape().DebugString(),
" must be broadcastable to the same shape"));
Tensor* output = nullptr;
const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
{"t", "e"}, "output", output_shape, &output));
if (output->NumElements() == 0) {
return;
}
#define HANDLE_DIM(NDIMS) \
{ \
functor::BCastSelectFunctor<Device, T, NDIMS> func; \
func(ctx->eigen_device<Device>(), \
output->shaped<T, NDIMS>(bcast.result_shape()), \
cond->template shaped<bool, NDIMS>(cond_bcast.y_reshape()), \
then->template shaped<T, NDIMS>(then_bcast.y_reshape()), \
else_->template shaped<T, NDIMS>(else_bcast.y_reshape()), \
BCast::ToIndexArray<NDIMS>(cond_bcast.y_bcast()), \
BCast::ToIndexArray<NDIMS>(then_bcast.y_bcast()), \
BCast::ToIndexArray<NDIMS>(else_bcast.y_bcast())); \
}
const int ndims = static_cast<int>(bcast.result_shape().size());
switch (ndims) {
case 1:
HANDLE_DIM(1);
break;
case 2:
HANDLE_DIM(2);
break;
case 3:
HANDLE_DIM(3);
break;
case 4:
HANDLE_DIM(4);
break;
case 5:
HANDLE_DIM(5);
break;
default:
ctx->SetStatus(errors::Unimplemented(
"Broadcast between ", ctx->input(0).shape().DebugString(), " and ",
ctx->input(1).shape().DebugString(), " is not supported yet."));
break;
}
return;
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(SelectV2Op);
};
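For orientation, the net effect of SelectV2Op once the shape checks pass is an np.where-style broadcast select. A rough numpy sketch (illustrative shapes only; the kernel itself works on Eigen tensors through BCastSelectFunctor):

import numpy as np

# Hypothetical inputs; only their broadcastability matters.
cond = np.array([[True], [False], [True]])   # shape (3, 1)
then_t = np.arange(6.0).reshape(3, 2)        # shape (3, 2)
else_t = np.array([100.0, 200.0])            # shape (2,)

# Broadcast every input to the common output shape, then select elementwise.
out_shape = np.broadcast_shapes(cond.shape, then_t.shape, else_t.shape)
out = np.where(np.broadcast_to(cond, out_shape),
               np.broadcast_to(then_t, out_shape),
               np.broadcast_to(else_t, out_shape))
print(out)  # rows 0 and 2 come from then_t, row 1 from else_t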
#define REGISTER_SELECT(type) \
REGISTER_KERNEL_BUILDER( \
Name("Select").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SelectOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SelectV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SelectV2Op<CPUDevice, type>);
TF_CALL_ALL_TYPES(REGISTER_SELECT);
#if GOOGLE_CUDA
// Registration of the GPU implementations.
#define REGISTER_SELECT_GPU(type) \
REGISTER_KERNEL_BUILDER( \
Name("Select").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SelectOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SelectV2").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SelectV2Op<GPUDevice, type>);
REGISTER_SELECT_GPU(bool);
REGISTER_SELECT_GPU(Eigen::half);
@@ -174,9 +294,12 @@ REGISTER_SELECT_GPU(complex128);
#ifdef TENSORFLOW_USE_SYCL
// Registration of the SYCL implementations.
#define REGISTER_SELECT_SYCL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SelectOp<SYCLDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("SelectV2").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SelectOp<SYCLDevice, type>);
REGISTER_SELECT_SYCL(float);
@@ -324,10 +447,35 @@ struct BatchSelectFunctor<CPUDevice, T> {
}
};
template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctorBase {
void operator()(const Device& d,
typename TTypes<T, NDIMS>::Tensor output_tensor,
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast) {
output_tensor.device(d) = cond_tensor.broadcast(cond_bcast)
.select(then_tensor.broadcast(then_bcast),
else_tensor.broadcast(else_bcast));
}
};
template <typename T, int NDIMS>
struct BCastSelectFunctor<CPUDevice, T, NDIMS>
: BCastSelectFunctorBase<CPUDevice, T, NDIMS> {};
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct BatchSelectFunctor<SYCLDevice, T>
: BatchSelectFunctorBase<SYCLDevice, T> {};
template <typename T, int NDIMS>
struct BCastSelectFunctor<SYCLDevice, T, NDIMS>
: BCastSelectFunctorBase<SYCLDevice, T, NDIMS> {};
#endif // TENSORFLOW_USE_SYCL
} // namespace functor


@@ -1208,6 +1208,18 @@ struct BatchSelectFunctor {
typename TTypes<T>::ConstMatrix else_flat_outer_dims);
};
template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctor {
void operator()(const Device& d,
typename TTypes<T, NDIMS>::Tensor output_tensor,
typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
typename TTypes<T, NDIMS>::ConstTensor then_tensor,
typename TTypes<T, NDIMS>::ConstTensor else_tensor,
typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
};
} // end namespace functor
} // end namespace tensorflow


@@ -828,6 +828,57 @@ REGISTER_OP("Select")
return Status::OK();
});
REGISTER_OP("SelectV2")
.Input("condition: bool")
.Input("t: T")
.Input("e: T")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
auto* handle_data_1 = c->input_handle_shapes_and_types(1);
auto* handle_data_2 = c->input_handle_shapes_and_types(2);
// Merge handle shape and dtype if applicable.
if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
const auto size = handle_data_1->size();
std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
if (size != handle_data_2->size()) {
return errors::InvalidArgument(
"Trying to merge handles pointing to different numbers of "
"tensors.");
}
for (int i = 0; i < size; ++i) {
const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
if (s1.dtype != s2.dtype) {
// TODO(apassos) resolve this in the manner of b/32476923
return errors::InvalidArgument(
"Trying to merge handles pointing to different dtypes.");
}
merged_handle_data[i].dtype = s1.dtype;
TF_RETURN_IF_ERROR(
c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
}
c->set_output_handle_shapes_and_types(0, merged_handle_data);
}
// The inputs 'cond', 'then', and 'else' must be broadcastable.
// TODO (yongtang): Consolidate 3-ary broadcast instead of
// multiple 2-ary broadcast.
ShapeHandle cond = c->input(0);
ShapeHandle then = c->input(1);
ShapeHandle else_ = c->input(2);
ShapeHandle other;
TF_RETURN_IF_ERROR(
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other));
ShapeHandle output;
TF_RETURN_IF_ERROR(
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output));
c->set_output(0, output);
return Status::OK();
});
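As the TODO above notes, the shape function composes two binary broadcasts rather than a single 3-ary one. Assuming fully known static shapes, the effect is equivalent to the following sketch, with numpy's broadcasting rules standing in for BroadcastBinaryOpOutputShapeFnHelper (the helper name below is made up for illustration):

import numpy as np

def select_v2_output_shape(cond_shape, then_shape, else_shape):
  # Combine `then` with `else` first, then combine the result with `cond`,
  # mirroring the two helper calls in the shape function above.
  then_else = np.broadcast_shapes(then_shape, else_shape)
  return np.broadcast_shapes(cond_shape, then_else)

print(select_v2_output_shape((1, 3, 1), (1, 3, 2), (1,)))  # (1, 3, 2)
# Incompatible shapes raise ValueError, analogous to the InvalidArgument
# error reported by the kernel.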
// --------------------------------------------------------------------------
REGISTER_OP("MatMul")


@@ -318,28 +318,38 @@ class LogicalOpTest(test.TestCase):
class SelectOpTest(test.TestCase):
def _compare(self, fn, c, x, y, use_gpu):
np_ans = np.where(c, x, y)
with test_util.device(use_gpu=use_gpu):
out = fn(c, x, y)
tf_ans = self.evaluate(out)
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, out)
def _compareGradientX(self,
fn,
c,
x,
y,
numeric_gradient_type=None,
x_init_value=None):
with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = fn(c, inx, iny)
s = list(np.shape(c))
if x_init_value is None:
x_init_value = x
if x.shape != y.shape:
x_init_value = np.broadcast_to(y, x.shape)
jacob_t, jacob_n = gradient_checker.compute_gradient(
inx, s, out, s, x_init_value=x_init_value)
if numeric_gradient_type is not None:
xf = x.astype(numeric_gradient_type)
yf = y.astype(numeric_gradient_type)
inxf = ops.convert_to_tensor(xf)
inyf = ops.convert_to_tensor(yf)
outf = fn(c, inxf, inyf)
_, jacob_n = gradient_checker.compute_gradient(
inxf, s, outf, s, x_init_value=xf)
jacob_n = jacob_n.astype(x.dtype)
@@ -350,20 +360,20 @@ class SelectOpTest(test.TestCase):
elif x.dtype == np.float64:
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _compareGradientY(self, fn, c, x, y, numeric_gradient_type=None):
with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = fn(c, inx, iny)
s = list(np.shape(c))
jacob_t, jacob_n = gradient_checker.compute_gradient(
iny, s, out, s, x_init_value=x, delta=1.0)
if numeric_gradient_type is not None:
xf = x.astype(numeric_gradient_type)
yf = y.astype(numeric_gradient_type)
inxf = ops.convert_to_tensor(xf)
inyf = ops.convert_to_tensor(yf)
outf = fn(c, inxf, inyf)
_, jacob_n = gradient_checker.compute_gradient(
inyf, s, outf, s, x_init_value=yf)
jacob_n = jacob_n.astype(x.dtype)
@@ -374,7 +384,7 @@ class SelectOpTest(test.TestCase):
elif x.dtype == np.float64:
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _testScalar(self, fn):
c = True
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 3, 2) * 100
@@ -384,11 +394,58 @@ class SelectOpTest(test.TestCase):
]:
xt = x.astype(t)
yt = y.astype(t)
self._compare(fn, c, xt, yt, use_gpu=False)
if t in [np.float16, np.float32, np.float64]:
self._compare(fn, c, xt, yt, use_gpu=True)
def testScalar(self):
self._testScalar(array_ops.where)
self._testScalar(array_ops.where_v2)
def _testScalarBroadcast(self, fn, c, x, y):
for t in [
np.float16, np.float32, np.float64, np.int32, np.int64, np.complex64,
np.complex128
]:
xt = x.astype(t)
yt = y.astype(t)
self._compare(fn, c, xt, yt, use_gpu=False)
if t in [np.float16, np.float32, np.float64]:
self._compare(fn, c, xt, yt, use_gpu=True)
def testScalarBroadcast(self):
c = True
# where_v2 only
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 1, 1) * 100
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 3, 1) * 100
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 1, 2) * 100
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 1) * 100
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1) * 100
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 2) * 100
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(3, 2) * 100
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
def _testBasic(self, fn):
c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 3, 2) * 100
@@ -398,12 +455,62 @@ class SelectOpTest(test.TestCase):
]:
xt = x.astype(t)
yt = y.astype(t)
self._compare(fn, c, xt, yt, use_gpu=False)
if t in [np.float16, np.float32, np.float64]:
self._compare(fn, c, xt, yt, use_gpu=True)
def testBasic(self):
self._testBasic(array_ops.where)
self._testBasic(array_ops.where_v2)
def _testBasicBroadcast(self, fn, c, x, y):
for t in [
np.float16, np.float32, np.float64, np.int32, np.int64, np.complex64,
np.complex128
]:
xt = x.astype(t)
yt = y.astype(t)
self._compare(fn, c, xt, yt, use_gpu=False)
if t in [np.float16, np.float32, np.float64]:
self._compare(fn, c, xt, yt, use_gpu=True)
def testBasicBroadcast(self):
c0 = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
c1 = np.random.randint(0, 2, 2).astype(np.bool).reshape(1, 1, 2)
c2 = np.random.randint(0, 2, 3).astype(np.bool).reshape(1, 3, 1)
c3 = np.random.randint(0, 2, 1).astype(np.bool).reshape(1, 1, 1)
for c in [c0, c1, c2, c3]:
# where_v2 only
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 1, 1) * 100
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 3, 1) * 100
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 1, 2) * 100
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 1) * 100
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1) * 100
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 2) * 100
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(3, 2) * 100
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
def _testGradients(self, fn):
c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 3, 2) * 100
@@ -416,14 +523,45 @@ class SelectOpTest(test.TestCase):
# care is taken with choosing the inputs and the delta. This is
# a weaker check (in particular, it does not test the op itself,
# only its gradient), but it's much better than nothing.
self._compareGradientX(fn, c, xt, yt, np.float)
self._compareGradientY(fn, c, xt, yt, np.float)
else:
self._compareGradientX(fn, c, xt, yt)
self._compareGradientY(fn, c, xt, yt)
@test_util.run_deprecated_v1
def testGradients(self):
self._testGradients(array_ops.where)
self._testGradients(array_ops.where_v2)
@test_util.run_deprecated_v1
def testGradientsBroadcast(self):
c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
for t in [np.float32, np.float64]:
# where_v2 only
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 1, 1) * 100
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 3, 1) * 100
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 1, 2) * 100
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 1) * 100
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1) * 100
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(1, 2) * 100
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(3, 2) * 100
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
def _testShapeMismatch(self, fn):
c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
x = np.random.rand(1, 3, 2) * 100
y = np.random.rand(2, 5, 3) * 100
@@ -434,10 +572,14 @@ class SelectOpTest(test.TestCase):
xt = x.astype(t)
yt = y.astype(t)
with self.assertRaises(ValueError):
fn(c, xt, yt)
@test_util.run_deprecated_v1
def testShapeMismatch(self):
self._testShapeMismatch(array_ops.where)
self._testShapeMismatch(array_ops.where_v2)
def _testEmptyTensor(self, fn):
c = np.random.randint(0, 3, 0).astype(np.bool).reshape(1, 3, 0)
x = np.random.rand(1, 3, 0) * 100
y = np.random.rand(1, 3, 0) * 100
@@ -445,20 +587,29 @@ class SelectOpTest(test.TestCase):
with self.cached_session():
xt = x.astype(np.float32)
yt = y.astype(np.float32)
z = fn(c, xt, yt).eval()
self.assertAllEqual(z_expected, z)
@test_util.run_deprecated_v1
def testEmptyTensor(self):
self._testEmptyTensor(array_ops.where)
self._testEmptyTensor(array_ops.where_v2)
def _testNan(self, fn):
with self.cached_session():
for c in False, True:
for a in 7.0, np.nan:
for b in 5.0, np.nan:
x = fn(c, a, b).eval()
y = a if c else b
self.assertEqual(np.isnan(x), np.isnan(y))
@test_util.run_deprecated_v1
def testNan(self):
"""Verify that nans don't propagate where they shouldn't."""
self._testNan(array_ops.where)
self._testNan(array_ops.where_v2)
class BatchSelectOpTest(test.TestCase):
"""Test broadcasting of Select when 'c' is a vec and 't' &'e' are rank2+."""


@@ -37,10 +37,10 @@ from tensorflow.python.platform import test
class WhereOpTest(test.TestCase):
def _testWhere(self, x, truth, expected_err_re=None, fn=array_ops.where):
with self.cached_session(use_gpu=True):
ans = fn(x)
self.assertTrue(ans.get_shape().is_compatible_with([None, x.ndim]))
if expected_err_re is None:
tf_ans = self.evaluate(ans)
self.assertAllClose(tf_ans, truth, atol=1e-10)
@@ -48,44 +48,40 @@ class WhereOpTest(test.TestCase):
with self.assertRaisesOpError(expected_err_re):
self.evaluate(ans)
def _testWrongNumbers(self, fn=array_ops.where):
with self.session(use_gpu=True):
with self.assertRaises(ValueError):
fn([False, True], [1, 2], None)
with self.assertRaises(ValueError):
fn([False, True], None, [1, 2])
def _testBasicVec(self, fn=array_ops.where):
x = np.asarray([True, False])
truth = np.asarray([[0]], dtype=np.int64)
self._testWhere(x, truth, None, fn)
x = np.asarray([False, True, False])
truth = np.asarray([[1]], dtype=np.int64)
self._testWhere(x, truth, None, fn)
x = np.asarray([False, False, True, False, True])
truth = np.asarray([[2], [4]], dtype=np.int64)
self._testWhere(x, truth, None, fn)
def _testRandomVec(self, fn=array_ops.where):
x = np.random.rand(1000000) > 0.5
truth = np.vstack([np.where(x)[0].astype(np.int64)]).T
self._testWhere(x, truth, None, fn)
def _testBasicMat(self, fn=array_ops.where):
x = np.asarray([[True, False], [True, False]])
# Ensure RowMajor mode
truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64)
self._testWhere(x, truth, None, fn)
def _testBasic3Tensor(self, fn=array_ops.where):
x = np.asarray([[[True, False], [True, False]],
[[False, True], [False, True]],
[[False, False], [False, True]]])
@@ -94,15 +90,41 @@ class WhereOpTest(test.TestCase):
truth = np.asarray(
[[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]], dtype=np.int64)
self._testWhere(x, truth, None, fn)
def _testRandom(self, dtype, expected_err_re=None, fn=array_ops.where):
shape = [127, 33, 53]
x = np.random.randn(*shape) + 1j * np.random.randn(*shape)
x = (np.random.randn(*shape) > 0).astype(dtype)
truth = np.where(np.abs(x) > 0)  # Tuples of indices by axis.
truth = np.vstack(truth).T  # Convert to [num_true, indices].
self._testWhere(x, truth, expected_err_re, fn)
def _testThreeArgument(self, fn=array_ops.where):
x = np.array([[-2, 3, -1], [1, -3, -3]])
np_val = np.where(x > 0, x * x, -x)
with self.test_session(use_gpu=True):
tf_val = self.evaluate(fn(constant_op.constant(x) > 0, x * x, -x))
self.assertAllEqual(tf_val, np_val)
def testWrongNumbers(self):
self._testWrongNumbers()
@test_util.run_deprecated_v1
def testBasicVec(self):
self._testBasicVec()
@test_util.run_deprecated_v1
def testRandomVec(self):
self._testRandomVec()
@test_util.run_deprecated_v1
def testBasicMat(self):
self._testBasicMat()
@test_util.run_deprecated_v1
def testBasic3Tensor(self):
self._testBasic3Tensor()
@test_util.run_deprecated_v1
def testRandomBool(self):
@@ -146,12 +168,95 @@ class WhereOpTest(test.TestCase):
@test_util.run_deprecated_v1
def testThreeArgument(self):
self._testThreeArgument()
def testV2WrongNumbers(self):
self._testWrongNumbers(array_ops.where_v2)
def testV2BasicVec(self):
self._testBasicVec(array_ops.where_v2)
def testV2RandomVec(self):
self._testRandomVec(array_ops.where_v2)
def testV2BasicMat(self):
self._testBasicMat(array_ops.where_v2)
def testV2Basic3Tensor(self):
self._testBasic3Tensor(array_ops.where_v2)
def testV2RandomBool(self):
self._testRandom(np.bool, None, array_ops.where_v2)
def testV2RandomInt32(self):
self._testRandom(np.int32, None, array_ops.where_v2)
def testV2RandomInt64(self):
self._testRandom(np.int64, None, array_ops.where_v2)
def testV2RandomFloat(self):
self._testRandom(np.float32, None, array_ops.where_v2)
def testV2RandomDouble(self):
self._testRandom(np.float64, None, array_ops.where_v2)
def testV2RandomComplex64(self):
self._testRandom(np.complex64, None, array_ops.where_v2)
def testV2RandomComplex128(self):
self._testRandom(np.complex128, None, array_ops.where_v2)
def testV2RandomUint8(self):
self._testRandom(np.uint8, None, array_ops.where_v2)
def testV2RandomInt8(self):
self._testRandom(np.int8, None, array_ops.where_v2)
def testV2RandomInt16(self):
self._testRandom(np.int16, None, array_ops.where_v2)
def testV2ThreeArgument(self):
self._testThreeArgument(array_ops.where_v2)
def testV2Broadcasting(self):
f = np.random.normal(0, 1, (3, 5, 1, 1))
x = np.zeros((7, 11))
y = np.ones((7, 11))
np_val = np.where(f < 0, x, y)
with self.test_session(use_gpu=True):
tf_val = self.evaluate(
array_ops.where_v2(constant_op.constant(f) < 0, x, y))
self.assertAllEqual(tf_val, np_val)
def testV2ScalarBroadcasting(self):
x = np.zeros((7, 11))
y = np.ones((7, 11))
np_val = np.where(True, x, y)
with self.test_session(use_gpu=True):
tf_val = self.evaluate(
array_ops.where_v2(
constant_op.constant(True, dtype=dtypes.bool), x, y))
self.assertAllEqual(tf_val, np_val)
def testV2VectorBroadcasting(self):
x = np.zeros(7)
y = np.ones(7)
np_val = np.where([True], x, y)
with self.test_session(use_gpu=True):
tf_val = self.evaluate(
array_ops.where_v2(
constant_op.constant([True], dtype=dtypes.bool), x, y))
self.assertAllEqual(tf_val, np_val)
def testV2PredBroadcasting(self):
pred = np.array([1, 0, 0]).reshape((3, 1))
x = np.random.randn(3, 4)
y = np.random.randn(3, 4)
np_val = np.where(pred, x, y)
with self.test_session(use_gpu=True):
tf_val = self.evaluate(array_ops.where_v2(pred, x, y))
self.assertAllClose(tf_val, np_val)
@test_util.run_deprecated_v1
def testBatchSelect(self):
x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192)  # [16384, 192]


@@ -3214,7 +3214,11 @@ def squeeze_v2(input, axis=None, name=None):
return squeeze(input, axis, name)
@tf_export(v1=["where"])
@deprecation.deprecated(
date=None,
instructions="Use tf.where in 2.0, "
"which has the same broadcast rule as np.where")
@dispatch.add_dispatch_support
def where(condition, x=None, y=None, name=None):
"""Return the elements, either from `x` or `y`, depending on the `condition`.
@@ -3268,6 +3272,48 @@ def where(condition, x=None, y=None, name=None):
raise ValueError("x and y must both be non-None or both be None.")
@tf_export("where", v1=["where_v2"])
def where_v2(condition, x=None, y=None, name=None):
"""Return the elements, either from `x` or `y`, depending on the `condition`.
If both `x` and `y` are None, then this operation returns the coordinates of
true elements of `condition`. The coordinates are returned in a 2-D tensor
where the first dimension (rows) represents the number of true elements, and
the second dimension (columns) represents the coordinates of the true
elements. Keep in mind, the shape of the output tensor can vary depending on
how many true values there are in input. Indices are output in row-major
order.
If `x` and `y` are both non-None, then `condition`, `x`, and `y` must all be
broadcastable to the same shape.
The `condition` tensor acts as a mask that chooses, based on the value at each
element, whether the corresponding element / row in the output should be taken
from `x` (if true) or `y` (if false).
Args:
condition: A `Tensor` of type `bool`
x: A Tensor which is of the same type as `y`, and may be broadcastable with
`condition` and `y`.
y: A Tensor which is of the same type as `x`, and may be broadcastable with
`condition` and `x`.
name: A name of the operation (optional).
Returns:
A `Tensor` with the same type as `x` and `y`, and shape that
is broadcasted from `condition`, `x`, and `y`, if `x`, `y` are non-None.
Otherwise, a `Tensor` with shape `(num_true, dim_size(condition))`.
Raises:
ValueError: When exactly one of `x` or `y` is non-None.
"""
if x is None and y is None:
with ops.name_scope(name, "Where", [condition]) as name:
condition = ops.convert_to_tensor(
condition, preferred_dtype=dtypes.bool, name="condition")
return gen_array_ops.where(condition=condition, name=name)
elif x is not None and y is not None:
return gen_math_ops.select_v2(condition=condition, t=x, e=y, name=name)
else:
raise ValueError("x and y must both be non-None or both be None.")
# pylint: disable=redefined-builtin
@tf_export(v1=["reverse_sequence"])
@deprecation.deprecated_args(None,


@@ -1309,6 +1309,39 @@ def _SelectGrad(op, grad):
c, zeros, grad))
@ops.RegisterGradient("SelectV2")
def _SelectGradV2(op, grad):
c = op.inputs[0]
x = op.inputs[1]
y = op.inputs[2]
zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype)
gx = array_ops.where_v2(c, grad, zeros)
gx_shape = array_ops.shape(gx)
x_shape = array_ops.shape(x)
rankdiff_x = array_ops.rank(gx) - array_ops.rank(x)
# Reduce away broadcasted leading dims.
gx = math_ops.reduce_sum(gx, axis=math_ops.range(rankdiff_x))
# Reduce but keep x's 1-valued dims which were broadcast.
axis = array_ops.where_v2(gx_shape[rankdiff_x:] > x_shape)
# tf.where returns 2D so squeeze.
axis = array_ops.squeeze(axis)
gx = math_ops.reduce_sum(gx, keepdims=True, axis=axis)
gy = array_ops.where_v2(c, zeros, grad)
gy_shape = array_ops.shape(gy)
y_shape = array_ops.shape(y)
rankdiff_y = array_ops.rank(gy) - array_ops.rank(y)
# Reduce away broadcasted leading dims.
gy = math_ops.reduce_sum(gy, axis=math_ops.range(rankdiff_y))
# Reduce but keep y's 1-valued dims which were broadcast.
axis = array_ops.where_v2(gy_shape[rankdiff_y:] > y_shape)
# tf.where returns 2D so squeeze.
axis = array_ops.squeeze(axis)
gy = math_ops.reduce_sum(gy, keepdims=True, axis=axis)
return (None, gx, gy)
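The two reduce_sum steps above fold the broadcast gradient back to each input's shape: first sum away the extra leading dimensions, then sum (keeping dims) over the axes where the input had size 1. A numpy sketch of that reduction, assuming fully known static shapes (this is not the graph code the gradient actually builds):

import numpy as np

def reduce_to_shape(grad, target_shape):
  # Sum away leading dims introduced by broadcasting ...
  rank_diff = grad.ndim - len(target_shape)
  grad = grad.sum(axis=tuple(range(rank_diff)))
  # ... then sum, keeping dims, over axes where the input had size 1.
  axes = tuple(i for i, n in enumerate(target_shape) if n == 1)
  if axes:
    grad = grad.sum(axis=axes, keepdims=True)
  return grad

g = np.ones((4, 3, 2))                   # upstream gradient on the output
print(reduce_to_shape(g, (3, 1)).shape)  # (3, 1)
print(reduce_to_shape(g, (2,)).shape)    # (2,)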
def _MatMulGradAgainstFirstOnly(op, grad):
"""Gradient for MatMul, only for the first input."""
t_a = op.get_attr("transpose_a")


@@ -2436,6 +2436,10 @@ tf_module {
name: "where"
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "where_v2"
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "while_loop"
argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], "


@@ -3352,6 +3352,10 @@ tf_module {
name: "Select"
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "SelectV2"
argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "SelfAdjointEig"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@@ -3352,6 +3352,10 @@ tf_module {
name: "Select"
argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "SelectV2"
argspec: "args=[\'condition\', \'t\', \'e\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "SelfAdjointEig"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@@ -1543,6 +1543,10 @@ renames = {
'tf.compat.v1.variables_initializer',
'tf.verify_tensor_all_finite':
'tf.compat.v1.verify_tensor_all_finite',
'tf.where':
'tf.compat.v1.where',
'tf.where_v2':
'tf.compat.v2.where',
'tf.wrap_function':
'tf.compat.v1.wrap_function',
'tf.write_file':