Add incompatible_shape_error attribute to equal op
When tensor equality is enabled, if there is an incompatible shape we currently throw and exception. Ideally we'd like to return False when calling __eq__ and True when calling __ne__. We thus modify the Equal and NotEqual ops to return a boolean upon a shape incompatibility. Due to this change the shape inference logic needs to be changed to either return a scalar bool if the shapes are incompatible, or else return an unknown shape to allow for either a boolean Tensor or scalar to be returned. Note the behavior of tf.math.equal & tf.math.not_equal is unchanged as they both use optimistic shape inference logic when dealing with unknown dimensions which allows for more efficient graphs rather than inserting Rank operations. This distinction between __eq__ & tf.math.equal is also found in numpy and as a result the tf.debugging.assert_equal and tf.debugging.assert_none_equal APIs needed to be change to utilize the numpy operations. PiperOrigin-RevId: 267466043
This commit is contained in:
parent
792abd2eaf
commit
e0e1efbe08
@ -1,9 +1,4 @@
|
|||||||
op {
|
op {
|
||||||
graph_op_name: "Equal"
|
graph_op_name: "Equal"
|
||||||
endpoint {
|
visibility: HIDDEN
|
||||||
name: "math.equal"
|
|
||||||
}
|
|
||||||
endpoint {
|
|
||||||
name: "equal"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,4 @@
|
|||||||
op {
|
op {
|
||||||
graph_op_name: "NotEqual"
|
graph_op_name: "NotEqual"
|
||||||
endpoint {
|
visibility: HIDDEN
|
||||||
name: "math.not_equal"
|
|
||||||
}
|
|
||||||
endpoint {
|
|
||||||
name: "not_equal"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -366,7 +366,8 @@ Status EinsumShape(shape_inference::InferenceContext* c) {
|
|||||||
output_bcast_shape = input_bcast_shapes[0];
|
output_bcast_shape = input_bcast_shapes[0];
|
||||||
} else if (input_bcast_shapes.size() == 2) {
|
} else if (input_bcast_shapes.size() == 2) {
|
||||||
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
|
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
|
||||||
c, input_bcast_shapes[0], input_bcast_shapes[1], &output_bcast_shape));
|
c, input_bcast_shapes[0], input_bcast_shapes[1], true,
|
||||||
|
&output_bcast_shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool output_has_ellipsis = false;
|
bool output_has_ellipsis = false;
|
||||||
@ -441,7 +442,7 @@ Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
|
|||||||
TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
|
TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
|
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
|
||||||
c, a_batch_shape, b_batch_shape, &output_batch_shape));
|
c, a_batch_shape, b_batch_shape, true, &output_batch_shape));
|
||||||
|
|
||||||
ShapeHandle output_shape;
|
ShapeHandle output_shape;
|
||||||
TF_RETURN_IF_ERROR(c->Concatenate(
|
TF_RETURN_IF_ERROR(c->Concatenate(
|
||||||
@ -1633,6 +1634,7 @@ Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
|
|||||||
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||||
ShapeHandle shape_x,
|
ShapeHandle shape_x,
|
||||||
ShapeHandle shape_y,
|
ShapeHandle shape_y,
|
||||||
|
bool incompatible_shape_error,
|
||||||
ShapeHandle* out) {
|
ShapeHandle* out) {
|
||||||
CHECK_NOTNULL(out);
|
CHECK_NOTNULL(out);
|
||||||
if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
|
if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
|
||||||
@ -1666,8 +1668,16 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
|||||||
// or the same as the known dim.
|
// or the same as the known dim.
|
||||||
// - If either dimension is 1, the other dimension is the output.
|
// - If either dimension is 1, the other dimension is the output.
|
||||||
if (c->Value(dim_x) > 1) {
|
if (c->Value(dim_x) > 1) {
|
||||||
|
if (!incompatible_shape_error) {
|
||||||
|
*out = c->UnknownShape();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
dims.push_back(dim_x);
|
dims.push_back(dim_x);
|
||||||
} else if (c->Value(dim_y) > 1) {
|
} else if (c->Value(dim_y) > 1) {
|
||||||
|
if (!incompatible_shape_error) {
|
||||||
|
*out = c->UnknownShape();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
dims.push_back(dim_y);
|
dims.push_back(dim_y);
|
||||||
} else if (c->Value(dim_x) == 1) {
|
} else if (c->Value(dim_x) == 1) {
|
||||||
dims.push_back(dim_y);
|
dims.push_back(dim_y);
|
||||||
@ -1676,6 +1686,10 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
|||||||
} else if (dim_y.SameHandle(dim_x)) {
|
} else if (dim_y.SameHandle(dim_x)) {
|
||||||
dims.push_back(dim_x);
|
dims.push_back(dim_x);
|
||||||
} else {
|
} else {
|
||||||
|
if (!incompatible_shape_error) {
|
||||||
|
*out = c->UnknownShape();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
dims.push_back(c->UnknownDim());
|
dims.push_back(c->UnknownDim());
|
||||||
}
|
}
|
||||||
} else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
|
} else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
|
||||||
@ -1689,7 +1703,14 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
DimensionHandle dim;
|
DimensionHandle dim;
|
||||||
TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
|
Status s = c->Merge(dim_x, dim_y, &dim);
|
||||||
|
if (!s.ok()) {
|
||||||
|
if (!incompatible_shape_error) {
|
||||||
|
*out = c->MakeShape({});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return s;
|
||||||
|
}
|
||||||
dims.push_back(dim);
|
dims.push_back(dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -306,6 +306,7 @@ Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
|
|||||||
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||||
ShapeHandle shape_x,
|
ShapeHandle shape_x,
|
||||||
ShapeHandle shape_y,
|
ShapeHandle shape_y,
|
||||||
|
bool incompatible_shape_error,
|
||||||
ShapeHandle* out);
|
ShapeHandle* out);
|
||||||
|
|
||||||
// Shape function for binary operators that broadcast their inputs
|
// Shape function for binary operators that broadcast their inputs
|
||||||
@ -313,8 +314,8 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
|||||||
inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
|
inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
|
||||||
int output_index) {
|
int output_index) {
|
||||||
ShapeHandle out;
|
ShapeHandle out;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
|
||||||
BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
|
c, c->input(0), c->input(1), true, &out));
|
||||||
c->set_output(output_index, out);
|
c->set_output(output_index, out);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -57,11 +57,23 @@ BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
|
|||||||
in1(ctx->input(1)),
|
in1(ctx->input(1)),
|
||||||
bcast(BCast::FromShape(in0.shape()), BCast::FromShape(in1.shape())) {
|
bcast(BCast::FromShape(in0.shape()), BCast::FromShape(in1.shape())) {
|
||||||
if (!bcast.IsValid()) {
|
if (!bcast.IsValid()) {
|
||||||
|
bool incompatible_shape_error;
|
||||||
|
bool has_attr =
|
||||||
|
TryGetNodeAttr(ctx->op_kernel().def(), "incompatible_shape_error",
|
||||||
|
&(incompatible_shape_error));
|
||||||
|
if (has_attr && !incompatible_shape_error) {
|
||||||
|
const string& op = ctx->op_kernel().type_string();
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
|
||||||
|
result = (op == "NotEqual");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ctx->SetStatus(errors::InvalidArgument(
|
ctx->SetStatus(errors::InvalidArgument(
|
||||||
"Incompatible shapes: ", in0.shape().DebugString(), " vs. ",
|
"Incompatible shapes: ", in0.shape().DebugString(), " vs. ",
|
||||||
in1.shape().DebugString()));
|
in1.shape().DebugString()));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
|
const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
|
||||||
out_num_elements = output_shape.num_elements();
|
out_num_elements = output_shape.num_elements();
|
||||||
in0_num_elements = in0.NumElements();
|
in0_num_elements = in0.NumElements();
|
||||||
|
@ -26,13 +26,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
|
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/cwise_ops.h"
|
|
||||||
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/cwise_ops.h"
|
||||||
|
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/util/bcast.h"
|
#include "tensorflow/core/util/bcast.h"
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ class BinaryOpShared : public OpKernel {
|
|||||||
// in-place computation.
|
// in-place computation.
|
||||||
// Caller must check ctx->status() upon return for non-ok status.
|
// Caller must check ctx->status() upon return for non-ok status.
|
||||||
// If ctx->status().ok() is true, then out is guaranteed to be allocated.
|
// If ctx->status().ok() is true, then out is guaranteed to be allocated.
|
||||||
BinaryOpState(OpKernelContext* ctx);
|
explicit BinaryOpState(OpKernelContext* ctx);
|
||||||
|
|
||||||
const Tensor& in0;
|
const Tensor& in0;
|
||||||
const Tensor& in1;
|
const Tensor& in1;
|
||||||
@ -69,6 +69,7 @@ class BinaryOpShared : public OpKernel {
|
|||||||
int64 in1_num_elements;
|
int64 in1_num_elements;
|
||||||
|
|
||||||
int ndims;
|
int ndims;
|
||||||
|
bool result;
|
||||||
};
|
};
|
||||||
|
|
||||||
void SetUnimplementedError(OpKernelContext* ctx);
|
void SetUnimplementedError(OpKernelContext* ctx);
|
||||||
@ -91,16 +92,29 @@ class BinaryOp : public BinaryOpShared {
|
|||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
// 'state': Shared helper not dependent on T to reduce code size
|
// 'state': Shared helper not dependent on T to reduce code size
|
||||||
BinaryOpState state(ctx);
|
BinaryOpState state(ctx);
|
||||||
if (!ctx->status().ok()) return;
|
auto& bcast = state.bcast;
|
||||||
|
const Device& eigen_device = ctx->eigen_device<Device>();
|
||||||
Tensor* out = state.out;
|
Tensor* out = state.out;
|
||||||
BCast* bcast = &state.bcast;
|
if (!bcast.IsValid()) {
|
||||||
|
if (ctx->status().ok()) {
|
||||||
|
if (state.result) {
|
||||||
|
functor::SetOneFunctor<Device, bool>()(eigen_device,
|
||||||
|
out->flat<bool>());
|
||||||
|
} else {
|
||||||
|
functor::SetZeroFunctor<Device, bool>()(eigen_device,
|
||||||
|
out->flat<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
auto& in0 = state.in0;
|
auto& in0 = state.in0;
|
||||||
auto& in1 = state.in1;
|
auto& in1 = state.in1;
|
||||||
if (state.out_num_elements == 0) {
|
if (state.out_num_elements == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int ndims = state.ndims;
|
const int ndims = state.ndims;
|
||||||
const Device& eigen_device = ctx->eigen_device<Device>();
|
|
||||||
bool error = false;
|
bool error = false;
|
||||||
bool* const error_ptr = Functor::has_errors ? &error : nullptr;
|
bool* const error_ptr = Functor::has_errors ? &error : nullptr;
|
||||||
if (ndims <= 1) {
|
if (ndims <= 1) {
|
||||||
@ -122,32 +136,32 @@ class BinaryOp : public BinaryOpShared {
|
|||||||
}
|
}
|
||||||
} else if (ndims == 2) {
|
} else if (ndims == 2) {
|
||||||
functor::BinaryFunctor<Device, Functor, 2>().BCast(
|
functor::BinaryFunctor<Device, Functor, 2>().BCast(
|
||||||
eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
|
eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
|
||||||
in0.template shaped<Tin, 2>(bcast->x_reshape()),
|
in0.template shaped<Tin, 2>(bcast.x_reshape()),
|
||||||
BCast::ToIndexArray<2>(bcast->x_bcast()),
|
BCast::ToIndexArray<2>(bcast.x_bcast()),
|
||||||
in1.template shaped<Tin, 2>(bcast->y_reshape()),
|
in1.template shaped<Tin, 2>(bcast.y_reshape()),
|
||||||
BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
|
BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
|
||||||
} else if (ndims == 3) {
|
} else if (ndims == 3) {
|
||||||
functor::BinaryFunctor<Device, Functor, 3>().BCast(
|
functor::BinaryFunctor<Device, Functor, 3>().BCast(
|
||||||
eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
|
eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
|
||||||
in0.template shaped<Tin, 3>(bcast->x_reshape()),
|
in0.template shaped<Tin, 3>(bcast.x_reshape()),
|
||||||
BCast::ToIndexArray<3>(bcast->x_bcast()),
|
BCast::ToIndexArray<3>(bcast.x_bcast()),
|
||||||
in1.template shaped<Tin, 3>(bcast->y_reshape()),
|
in1.template shaped<Tin, 3>(bcast.y_reshape()),
|
||||||
BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
|
BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
|
||||||
} else if (ndims == 4) {
|
} else if (ndims == 4) {
|
||||||
functor::BinaryFunctor<Device, Functor, 4>().BCast(
|
functor::BinaryFunctor<Device, Functor, 4>().BCast(
|
||||||
eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
|
eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
|
||||||
in0.template shaped<Tin, 4>(bcast->x_reshape()),
|
in0.template shaped<Tin, 4>(bcast.x_reshape()),
|
||||||
BCast::ToIndexArray<4>(bcast->x_bcast()),
|
BCast::ToIndexArray<4>(bcast.x_bcast()),
|
||||||
in1.template shaped<Tin, 4>(bcast->y_reshape()),
|
in1.template shaped<Tin, 4>(bcast.y_reshape()),
|
||||||
BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
|
BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
|
||||||
} else if (ndims == 5) {
|
} else if (ndims == 5) {
|
||||||
functor::BinaryFunctor<Device, Functor, 5>().BCast(
|
functor::BinaryFunctor<Device, Functor, 5>().BCast(
|
||||||
eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
|
eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
|
||||||
in0.template shaped<Tin, 5>(bcast->x_reshape()),
|
in0.template shaped<Tin, 5>(bcast.x_reshape()),
|
||||||
BCast::ToIndexArray<5>(bcast->x_bcast()),
|
BCast::ToIndexArray<5>(bcast.x_bcast()),
|
||||||
in1.template shaped<Tin, 5>(bcast->y_reshape()),
|
in1.template shaped<Tin, 5>(bcast.y_reshape()),
|
||||||
BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
|
BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
|
||||||
} else {
|
} else {
|
||||||
SetUnimplementedError(ctx);
|
SetUnimplementedError(ctx);
|
||||||
}
|
}
|
||||||
|
@ -700,7 +700,19 @@ REGISTER_OP("GreaterEqual").COMPARISON();
|
|||||||
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
|
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
|
||||||
"int64, complex64, quint8, qint8, qint32, string, bool, " \
|
"int64, complex64, quint8, qint8, qint32, string, bool, " \
|
||||||
"complex128}") \
|
"complex128}") \
|
||||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
|
.Attr("incompatible_shape_error: bool = true") \
|
||||||
|
.SetShapeFn([](InferenceContext* c) { \
|
||||||
|
ShapeHandle x = c->input(0); \
|
||||||
|
ShapeHandle y = c->input(1); \
|
||||||
|
ShapeHandle output; \
|
||||||
|
bool incompatible_shape_error; \
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error", \
|
||||||
|
&incompatible_shape_error)); \
|
||||||
|
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( \
|
||||||
|
c, x, y, incompatible_shape_error, &output)); \
|
||||||
|
c->set_output(0, output); \
|
||||||
|
return Status::OK(); \
|
||||||
|
})
|
||||||
|
|
||||||
REGISTER_OP("Equal").EQUALITY_COMPARISON();
|
REGISTER_OP("Equal").EQUALITY_COMPARISON();
|
||||||
|
|
||||||
@ -877,10 +889,10 @@ REGISTER_OP("SelectV2")
|
|||||||
ShapeHandle else_ = c->input(2);
|
ShapeHandle else_ = c->input(2);
|
||||||
ShapeHandle other;
|
ShapeHandle other;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other));
|
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, true, &other));
|
||||||
ShapeHandle output;
|
ShapeHandle output;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output));
|
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, true, &output));
|
||||||
c->set_output(0, output);
|
c->set_output(0, output);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
@ -110,28 +110,17 @@ TEST(MathOpsTest, Segment_ShapeFn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
|
TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
|
||||||
for (const auto* op_name : {"Add", "Complex",
|
auto test_shapes = [&](ShapeInferenceTestOp& op,
|
||||||
"Div", "Equal",
|
bool incompatible_shape_error) {
|
||||||
"Greater", "GreaterEqual",
|
|
||||||
"Igamma", "Igammac",
|
|
||||||
"Zeta", "Polygamma",
|
|
||||||
"Less", "LessEqual",
|
|
||||||
"LogicalAnd", "LogicalOr",
|
|
||||||
"Maximum", "Minimum",
|
|
||||||
"Mod", "Mul",
|
|
||||||
"NotEqual", "Pow",
|
|
||||||
"Sub", "SquaredDifference",
|
|
||||||
"DivNoNan"}) {
|
|
||||||
ShapeInferenceTestOp op(op_name);
|
|
||||||
INFER_OK(op, "?;?", "?");
|
INFER_OK(op, "?;?", "?");
|
||||||
INFER_OK(op, "[1,2];?", "?");
|
INFER_OK(op, "[1,2];?", "?");
|
||||||
INFER_OK(op, "?;[1,2]", "?");
|
INFER_OK(op, "?;[1,2]", "?");
|
||||||
|
|
||||||
INFER_OK(op, "[?];[1]", "[d0_0]");
|
INFER_OK(op, "[?];[1]", "[d0_0]");
|
||||||
INFER_OK(op, "[1];[?]", "[d1_0]");
|
INFER_OK(op, "[1];[?]", "[d1_0]");
|
||||||
INFER_OK(op, "[?];[2]", "[d1_0]");
|
INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?");
|
||||||
INFER_OK(op, "[2];[?]", "[d0_0]");
|
INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
|
||||||
INFER_OK(op, "[?];[?]", "[?]");
|
INFER_OK(op, "[?];[?]", incompatible_shape_error ? "[?]" : "?");
|
||||||
INFER_OK(op, "[];[?]", "[d1_0]");
|
INFER_OK(op, "[];[?]", "[d1_0]");
|
||||||
INFER_OK(op, "[?];[]", "[d0_0]");
|
INFER_OK(op, "[?];[]", "[d0_0]");
|
||||||
|
|
||||||
@ -144,7 +133,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
|
|||||||
INFER_OK(op, "[1];[2]", "[d1_0]");
|
INFER_OK(op, "[1];[2]", "[d1_0]");
|
||||||
INFER_OK(op, "[2];[1]", "[d0_0]");
|
INFER_OK(op, "[2];[1]", "[d0_0]");
|
||||||
INFER_OK(op, "[2];[]", "[d0_0]");
|
INFER_OK(op, "[2];[]", "[d0_0]");
|
||||||
INFER_OK(op, "[2];[?]", "[d0_0]");
|
INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
|
||||||
|
|
||||||
INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
|
INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
|
||||||
INFER_OK(op, "[];[0]", "[d1_0]");
|
INFER_OK(op, "[];[0]", "[d1_0]");
|
||||||
@ -152,14 +141,46 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
|
|||||||
INFER_OK(op, "[0];[1]", "[d0_0]");
|
INFER_OK(op, "[0];[1]", "[d0_0]");
|
||||||
INFER_OK(op, "[0];[]", "[d0_0]");
|
INFER_OK(op, "[0];[]", "[d0_0]");
|
||||||
|
|
||||||
INFER_OK(op, "[2];[?,?]", "[d1_0,d0_0]");
|
INFER_OK(op, "[2];[?,?]", incompatible_shape_error ? "[d1_0,d0_0]" : "?");
|
||||||
INFER_OK(op, "[2,2];[?,?,?]", "[d1_0,d0_0,d0_1]");
|
INFER_OK(op, "[2,2];[?,?,?]",
|
||||||
|
incompatible_shape_error ? "[d1_0,d0_0,d0_1]" : "?");
|
||||||
|
|
||||||
// Multiple dimension cases (same test cases, switching x and y).
|
// Multiple dimension cases (same test cases, switching x and y).
|
||||||
INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
|
INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
|
||||||
"[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
|
incompatible_shape_error ? "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"
|
||||||
|
: "?");
|
||||||
INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
|
INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
|
||||||
"[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
|
incompatible_shape_error ? "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"
|
||||||
|
: "?");
|
||||||
|
|
||||||
|
if (incompatible_shape_error) {
|
||||||
|
INFER_ERROR("Dimensions must be equal", op, "[2];[3]");
|
||||||
|
} else {
|
||||||
|
INFER_OK(op, "[2];[3]", "[]");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (string op_name : {"Add", "Complex",
|
||||||
|
"Div", "Equal",
|
||||||
|
"Greater", "GreaterEqual",
|
||||||
|
"Igamma", "Igammac",
|
||||||
|
"Zeta", "Polygamma",
|
||||||
|
"Less", "LessEqual",
|
||||||
|
"LogicalAnd", "LogicalOr",
|
||||||
|
"Maximum", "Minimum",
|
||||||
|
"Mod", "Mul",
|
||||||
|
"NotEqual", "Pow",
|
||||||
|
"Sub", "SquaredDifference",
|
||||||
|
"DivNoNan"}) {
|
||||||
|
ShapeInferenceTestOp op(op_name);
|
||||||
|
AddNodeAttr("incompatible_shape_error", true, &op.node_def);
|
||||||
|
test_shapes(op, true);
|
||||||
|
|
||||||
|
if ((op_name == "Equal") || (op_name == "NotEqual")) {
|
||||||
|
ShapeInferenceTestOp op(op_name);
|
||||||
|
AddNodeAttr("incompatible_shape_error", false, &op.node_def);
|
||||||
|
test_shapes(op, false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
from tensorflow.python.compat import compat
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import core
|
from tensorflow.python.eager import core
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
@ -302,7 +303,11 @@ class TFETest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
bool(tf_a == tf_d)
|
bool(tf_a == tf_d)
|
||||||
self.assertAllEqual(tf_a == tf_d, [[True, False], [True, False]])
|
self.assertAllEqual(tf_a == tf_d, [[True, False], [True, False]])
|
||||||
# TODO(b/120678848): If shapes do not match we should instead return False
|
if compat.forward_compatible(2019, 9, 25):
|
||||||
|
self.assertFalse(bool(tf_a == tf_e))
|
||||||
|
self.assertTrue(bool(tf_a != tf_e))
|
||||||
|
self.assertNotAllEqual(tf_a, tf_e)
|
||||||
|
else:
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
bool(tf_a != tf_e)
|
bool(tf_a != tf_e)
|
||||||
|
|
||||||
@ -313,7 +318,9 @@ class TFETest(test_util.TensorFlowTestCase):
|
|||||||
bool(np_a == np_c)
|
bool(np_a == np_c)
|
||||||
self.assertAllEqual(np_a == np_c, [[True, True], [True, True]])
|
self.assertAllEqual(np_a == np_c, [[True, True], [True, True]])
|
||||||
self.assertAllEqual(np_a == np_d, [[True, False], [True, False]])
|
self.assertAllEqual(np_a == np_d, [[True, False], [True, False]])
|
||||||
bool(np_a != np_e)
|
self.assertFalse(bool(np_a == np_e))
|
||||||
|
self.assertTrue(bool(np_a != np_e))
|
||||||
|
self.assertNotAllEqual(np_a, np_e)
|
||||||
finally:
|
finally:
|
||||||
if default:
|
if default:
|
||||||
ops.enable_tensor_equality()
|
ops.enable_tensor_equality()
|
||||||
|
@ -22,6 +22,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -255,7 +256,6 @@ class BinaryOpTest(test.TestCase):
|
|||||||
var_x = variables.Variable(x)
|
var_x = variables.Variable(x)
|
||||||
var_y = variables.Variable(y)
|
var_y = variables.Variable(y)
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
|
||||||
self.evaluate([var_x.initializer, var_y.initializer])
|
self.evaluate([var_x.initializer, var_y.initializer])
|
||||||
left_result = self.evaluate(var_x * y)
|
left_result = self.evaluate(var_x * y)
|
||||||
right_result = self.evaluate(x * var_y)
|
right_result = self.evaluate(x * var_y)
|
||||||
@ -933,7 +933,6 @@ class ComparisonOpTest(test.TestCase):
|
|||||||
self._testBCastByFunc(
|
self._testBCastByFunc(
|
||||||
np.not_equal, math_ops.not_equal, include_complex=True)
|
np.not_equal, math_ops.not_equal, include_complex=True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testShapeMismatch(self):
|
def testShapeMismatch(self):
|
||||||
dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
|
dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
|
||||||
funcs = [
|
funcs = [
|
||||||
@ -944,8 +943,9 @@ class ComparisonOpTest(test.TestCase):
|
|||||||
y = np.arange(0, 10).reshape([5, 2])
|
y = np.arange(0, 10).reshape([5, 2])
|
||||||
for t in dtypes:
|
for t in dtypes:
|
||||||
for f in funcs:
|
for f in funcs:
|
||||||
with self.assertRaisesWithPredicateMatch(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, lambda e: "Dimensions must" in str(e)):
|
(ValueError, errors.InvalidArgumentError),
|
||||||
|
"Incompatible shapes|Dimensions must be equal"):
|
||||||
f(x.astype(t), y.astype(t))
|
f(x.astype(t), y.astype(t))
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ import numpy as np
|
|||||||
from tensorflow.python.compat import compat
|
from tensorflow.python.compat import compat
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -201,7 +202,6 @@ class ComparisonOpTest(test.TestCase):
|
|||||||
self._testBCastByFunc(
|
self._testBCastByFunc(
|
||||||
np.not_equal, math_ops.not_equal, include_complex=True)
|
np.not_equal, math_ops.not_equal, include_complex=True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testShapeMismatch(self):
|
def testShapeMismatch(self):
|
||||||
dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
|
dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
|
||||||
funcs = [
|
funcs = [
|
||||||
@ -212,8 +212,9 @@ class ComparisonOpTest(test.TestCase):
|
|||||||
y = np.arange(0, 10).reshape([5, 2])
|
y = np.arange(0, 10).reshape([5, 2])
|
||||||
for t in dtypes:
|
for t in dtypes:
|
||||||
for f in funcs:
|
for f in funcs:
|
||||||
with self.assertRaisesWithPredicateMatch(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, lambda e: "Dimensions must" in str(e)):
|
(ValueError, errors.InvalidArgumentError),
|
||||||
|
"Incompatible shapes|Dimensions must be equal"):
|
||||||
f(x.astype(t), y.astype(t))
|
f(x.astype(t), y.astype(t))
|
||||||
|
|
||||||
|
|
||||||
|
@ -310,7 +310,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
|
|||||||
static_func: Function that, if passed numpy ndarray versions of the two
|
static_func: Function that, if passed numpy ndarray versions of the two
|
||||||
inputs to the assertion, will return a Boolean ndarray with containing
|
inputs to the assertion, will return a Boolean ndarray with containing
|
||||||
True in all positions where the assertion PASSES.
|
True in all positions where the assertion PASSES.
|
||||||
i.e. lambda x,y: (x == y) for assert_equal()
|
i.e. np.equal for assert_equal()
|
||||||
x: Numeric `Tensor`.
|
x: Numeric `Tensor`.
|
||||||
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
|
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
|
||||||
data: The tensors to print out if the condition is False. Defaults to
|
data: The tensors to print out if the condition is False. Defaults to
|
||||||
@ -366,7 +366,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
|
|||||||
x_static = tensor_util.constant_value(x)
|
x_static = tensor_util.constant_value(x)
|
||||||
y_static = tensor_util.constant_value(y)
|
y_static = tensor_util.constant_value(y)
|
||||||
if x_static is not None and y_static is not None:
|
if x_static is not None and y_static is not None:
|
||||||
condition_static = static_func(x_static, y_static).all()
|
condition_static = np.all(static_func(x_static, y_static))
|
||||||
_assert_static(condition_static, data)
|
_assert_static(condition_static, data)
|
||||||
return control_flow_ops.Assert(condition, data, summarize=summarize)
|
return control_flow_ops.Assert(condition, data, summarize=summarize)
|
||||||
|
|
||||||
@ -654,9 +654,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # p
|
|||||||
# Short-circuit if x and y are the same tensor.
|
# Short-circuit if x and y are the same tensor.
|
||||||
if x is y:
|
if x is y:
|
||||||
return None if context.executing_eagerly() else control_flow_ops.no_op()
|
return None if context.executing_eagerly() else control_flow_ops.no_op()
|
||||||
return _binary_assert('==', 'assert_equal', math_ops.equal,
|
return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y,
|
||||||
lambda x, y: (x == y),
|
data, summarize, message, name)
|
||||||
x, y, data, summarize, message, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export('debugging.assert_none_equal', v1=[])
|
@tf_export('debugging.assert_none_equal', v1=[])
|
||||||
@ -703,8 +702,7 @@ def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
|
|||||||
def assert_none_equal(
|
def assert_none_equal(
|
||||||
x, y, data=None, summarize=None, message=None, name=None):
|
x, y, data=None, summarize=None, message=None, name=None):
|
||||||
return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal,
|
return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal,
|
||||||
lambda x, y: (x != y), x, y, data, summarize, message,
|
np.not_equal, x, y, data, summarize, message, name)
|
||||||
name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export('debugging.assert_near', v1=[])
|
@tf_export('debugging.assert_near', v1=[])
|
||||||
@ -877,8 +875,8 @@ def assert_less_v2(x, y, message=None, summarize=None, name=None):
|
|||||||
@tf_export(v1=['debugging.assert_less', 'assert_less'])
|
@tf_export(v1=['debugging.assert_less', 'assert_less'])
|
||||||
@_binary_assert_doc('<')
|
@_binary_assert_doc('<')
|
||||||
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
|
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
|
||||||
return _binary_assert('<', 'assert_less', math_ops.less, lambda x, y: (x < y),
|
return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data,
|
||||||
x, y, data, summarize, message, name)
|
summarize, message, name)
|
||||||
|
|
||||||
|
|
||||||
@tf_export('debugging.assert_less_equal', v1=[])
|
@tf_export('debugging.assert_less_equal', v1=[])
|
||||||
@ -922,8 +920,7 @@ def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
|
|||||||
@_binary_assert_doc('<=')
|
@_binary_assert_doc('<=')
|
||||||
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
|
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
|
||||||
return _binary_assert('<=', 'assert_less_equal', math_ops.less_equal,
|
return _binary_assert('<=', 'assert_less_equal', math_ops.less_equal,
|
||||||
lambda x, y: (x <= y), x, y, data, summarize, message,
|
np.less_equal, x, y, data, summarize, message, name)
|
||||||
name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
|
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
|
||||||
@ -965,9 +962,8 @@ def assert_greater_v2(x, y, message=None, summarize=None, name=None):
|
|||||||
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
|
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
|
||||||
@_binary_assert_doc('>')
|
@_binary_assert_doc('>')
|
||||||
def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
|
def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
|
||||||
return _binary_assert('>', 'assert_greater', math_ops.greater,
|
return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x,
|
||||||
lambda x, y: (x > y),
|
y, data, summarize, message, name)
|
||||||
x, y, data, summarize, message, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export('debugging.assert_greater_equal', v1=[])
|
@tf_export('debugging.assert_greater_equal', v1=[])
|
||||||
@ -1013,8 +1009,7 @@ def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
|
|||||||
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
|
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
|
||||||
name=None):
|
name=None):
|
||||||
return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal,
|
return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal,
|
||||||
lambda x, y: (x >= y), x, y, data, summarize, message,
|
np.greater_equal, x, y, data, summarize, message, name)
|
||||||
name)
|
|
||||||
|
|
||||||
|
|
||||||
def _assert_rank_condition(
|
def _assert_rank_condition(
|
||||||
@ -1026,8 +1021,8 @@ def _assert_rank_condition(
|
|||||||
rank: Scalar `Tensor`.
|
rank: Scalar `Tensor`.
|
||||||
static_condition: A python function that takes `[actual_rank, given_rank]`
|
static_condition: A python function that takes `[actual_rank, given_rank]`
|
||||||
and returns `True` if the condition is satisfied, `False` otherwise.
|
and returns `True` if the condition is satisfied, `False` otherwise.
|
||||||
dynamic_condition: An `op` that takes [actual_rank, given_rank]
|
dynamic_condition: An `op` that takes [actual_rank, given_rank] and return
|
||||||
and return `True` if the condition is satisfied, `False` otherwise.
|
`True` if the condition is satisfied, `False` otherwise.
|
||||||
data: The tensors to print out if the condition is false. Defaults to
|
data: The tensors to print out if the condition is false. Defaults to
|
||||||
error message and first few entries of `x`.
|
error message and first few entries of `x`.
|
||||||
summarize: Print this many entries of each tensor.
|
summarize: Print this many entries of each tensor.
|
||||||
@ -2159,4 +2154,3 @@ def ensure_shape(x, shape, name=None):
|
|||||||
def _ensure_shape_grad(op, grad):
|
def _ensure_shape_grad(op, grad):
|
||||||
del op # Unused.
|
del op # Unused.
|
||||||
return grad
|
return grad
|
||||||
|
|
||||||
|
@ -1270,11 +1270,64 @@ ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
|
|||||||
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
|
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("math.equal", "equal")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
|
def equal(x, y, name=None):
|
||||||
|
"""Returns the truth value of (x == y) element-wise.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
x = tf.constant([2, 4])
|
||||||
|
y = tf.constant(2)
|
||||||
|
tf.math.equal(x, y) ==> array([True, False])
|
||||||
|
|
||||||
|
x = tf.constant([2, 4])
|
||||||
|
y = tf.constant([2, 4])
|
||||||
|
tf.math.equal(x, y) ==> array([True, True])
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE**: `Equal` supports broadcasting. More about broadcasting [here](
|
||||||
|
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
|
||||||
|
y: A `Tensor` or `SparseTensor` or `IndexedSlices`.
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` of type bool with the same size as that of x or y.
|
||||||
|
"""
|
||||||
|
return gen_math_ops.equal(x, y, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("math.not_equal", "not_equal")
|
||||||
|
@dispatch.add_dispatch_support
|
||||||
|
def not_equal(x, y, name=None):
|
||||||
|
"""Returns the truth value of (x != y) element-wise.
|
||||||
|
|
||||||
|
**NOTE**: `NotEqual` supports broadcasting. More about broadcasting [here](
|
||||||
|
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
|
||||||
|
y: A `Tensor` or `SparseTensor` or `IndexedSlices`.
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` of type bool with the same size as that of x or y.
|
||||||
|
"""
|
||||||
|
return gen_math_ops.not_equal(x, y, name=name)
|
||||||
|
|
||||||
|
|
||||||
def tensor_equals(self, other):
|
def tensor_equals(self, other):
|
||||||
"""Compares two tensors element-wise for equality."""
|
"""Compares two tensors element-wise for equality."""
|
||||||
g = getattr(self, "graph", None)
|
g = getattr(self, "graph", None)
|
||||||
if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
|
if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
|
||||||
(g is None or g._building_function)): # pylint: disable=protected-access
|
(g is None or g._building_function)): # pylint: disable=protected-access
|
||||||
|
if fwd_compat.forward_compatible(2019, 9, 25):
|
||||||
|
return gen_math_ops.equal(self, other, incompatible_shape_error=False)
|
||||||
|
else:
|
||||||
return gen_math_ops.equal(self, other)
|
return gen_math_ops.equal(self, other)
|
||||||
else:
|
else:
|
||||||
# In legacy graph mode, tensor equality is object equality
|
# In legacy graph mode, tensor equality is object equality
|
||||||
@ -1284,6 +1337,9 @@ def tensor_equals(self, other):
|
|||||||
def tensor_not_equals(self, other):
|
def tensor_not_equals(self, other):
|
||||||
"""Compares two tensors element-wise for equality."""
|
"""Compares two tensors element-wise for equality."""
|
||||||
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
|
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
|
||||||
|
if fwd_compat.forward_compatible(2019, 9, 25):
|
||||||
|
return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
|
||||||
|
else:
|
||||||
return gen_math_ops.not_equal(self, other)
|
return gen_math_ops.not_equal(self, other)
|
||||||
else:
|
else:
|
||||||
# In legacy graph mode, tensor equality is object equality
|
# In legacy graph mode, tensor equality is object equality
|
||||||
|
@ -2397,7 +2397,6 @@ def _convert_cast(pfor_input):
|
|||||||
@RegisterPForWithArgs("Div", math_ops.div)
|
@RegisterPForWithArgs("Div", math_ops.div)
|
||||||
@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
|
@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
|
||||||
@RegisterPForWithArgs("Elu", nn_ops.elu)
|
@RegisterPForWithArgs("Elu", nn_ops.elu)
|
||||||
@RegisterPForWithArgs("Equal", math_ops.equal)
|
|
||||||
@RegisterPForWithArgs("Erf", math_ops.erf)
|
@RegisterPForWithArgs("Erf", math_ops.erf)
|
||||||
@RegisterPForWithArgs("Erfc", math_ops.erfc)
|
@RegisterPForWithArgs("Erfc", math_ops.erfc)
|
||||||
@RegisterPForWithArgs("Exp", math_ops.exp)
|
@RegisterPForWithArgs("Exp", math_ops.exp)
|
||||||
@ -2432,7 +2431,6 @@ def _convert_cast(pfor_input):
|
|||||||
@RegisterPForWithArgs("Mul", math_ops.multiply)
|
@RegisterPForWithArgs("Mul", math_ops.multiply)
|
||||||
@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan)
|
@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan)
|
||||||
@RegisterPForWithArgs("Neg", math_ops.negative)
|
@RegisterPForWithArgs("Neg", math_ops.negative)
|
||||||
@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
|
|
||||||
@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
|
@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
|
||||||
@RegisterPForWithArgs("Pow", math_ops.pow)
|
@RegisterPForWithArgs("Pow", math_ops.pow)
|
||||||
@RegisterPForWithArgs("Real", math_ops.real)
|
@RegisterPForWithArgs("Real", math_ops.real)
|
||||||
@ -2471,6 +2469,26 @@ def _convert_cwise(pfor_input, op_type, op_func):
|
|||||||
return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
|
return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
|
||||||
|
|
||||||
|
|
||||||
|
@RegisterPFor("Equal")
|
||||||
|
def _convert_equal(pfor_input):
|
||||||
|
pfor_input.expanddim_inputs_for_broadcast()
|
||||||
|
x = pfor_input.input(0)[0]
|
||||||
|
y = pfor_input.input(1)[0]
|
||||||
|
incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
|
||||||
|
assert incompatible_shape_error
|
||||||
|
return wrap(math_ops.equal(x, y), True)
|
||||||
|
|
||||||
|
|
||||||
|
@RegisterPFor("NotEqual")
|
||||||
|
def _convert_not_equal(pfor_input):
|
||||||
|
pfor_input.expanddim_inputs_for_broadcast()
|
||||||
|
x = pfor_input.input(0)[0]
|
||||||
|
y = pfor_input.input(1)[0]
|
||||||
|
incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
|
||||||
|
assert incompatible_shape_error
|
||||||
|
return wrap(math_ops.not_equal(x, y), True)
|
||||||
|
|
||||||
|
|
||||||
@RegisterPFor("ApproximateEqual")
|
@RegisterPFor("ApproximateEqual")
|
||||||
def _convert_approximate_equal(pfor_input):
|
def _convert_approximate_equal(pfor_input):
|
||||||
pfor_input.expanddim_inputs_for_broadcast()
|
pfor_input.expanddim_inputs_for_broadcast()
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.core.framework import attr_value_pb2
|
|||||||
from tensorflow.core.framework import variable_pb2
|
from tensorflow.core.framework import variable_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||||
from tensorflow.python import _pywrap_utils
|
from tensorflow.python import _pywrap_utils
|
||||||
|
from tensorflow.python.compat import compat as fwd_compat
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -35,8 +36,8 @@ from tensorflow.python.framework import tensor_shape
|
|||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gen_array_ops
|
from tensorflow.python.ops import gen_array_ops
|
||||||
from tensorflow.python.ops import gen_math_ops
|
|
||||||
from tensorflow.python.ops import gen_state_ops
|
from tensorflow.python.ops import gen_state_ops
|
||||||
|
from tensorflow.python.ops import gen_math_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
@ -1092,6 +1093,9 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
|||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
"""Compares two variables element-wise for equality."""
|
"""Compares two variables element-wise for equality."""
|
||||||
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
|
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
|
||||||
|
if fwd_compat.forward_compatible(2019, 9, 25):
|
||||||
|
return gen_math_ops.equal(self, other, incompatible_shape_error=False)
|
||||||
|
else:
|
||||||
return gen_math_ops.equal(self, other)
|
return gen_math_ops.equal(self, other)
|
||||||
else:
|
else:
|
||||||
# In legacy graph mode, tensor equality is object equality
|
# In legacy graph mode, tensor equality is object equality
|
||||||
@ -1101,6 +1105,10 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
|||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
"""Compares two variables element-wise for equality."""
|
"""Compares two variables element-wise for equality."""
|
||||||
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
|
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
|
||||||
|
if fwd_compat.forward_compatible(2019, 9, 25):
|
||||||
|
return gen_math_ops.not_equal(
|
||||||
|
self, other, incompatible_shape_error=False)
|
||||||
|
else:
|
||||||
return gen_math_ops.not_equal(self, other)
|
return gen_math_ops.not_equal(self, other)
|
||||||
else:
|
else:
|
||||||
# In legacy graph mode, tensor equality is object equality
|
# In legacy graph mode, tensor equality is object equality
|
||||||
|
@ -1174,7 +1174,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Equal"
|
name: "Equal"
|
||||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Erf"
|
name: "Erf"
|
||||||
@ -2398,7 +2398,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "NotEqual"
|
name: "NotEqual"
|
||||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "NthElement"
|
name: "NthElement"
|
||||||
|
@ -1174,7 +1174,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Equal"
|
name: "Equal"
|
||||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Erf"
|
name: "Erf"
|
||||||
@ -2398,7 +2398,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "NotEqual"
|
name: "NotEqual"
|
||||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "NthElement"
|
name: "NthElement"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user