Add incompatible_shape_error attribute to equal op

When tensor equality is enabled and the shapes are incompatible, we
currently throw an exception. Ideally we'd like to return False when
calling __eq__ and True when calling __ne__. We thus modify the Equal
and NotEqual ops to return a boolean upon a shape incompatibility. Due
to this change, the shape inference logic needs to either return a
scalar bool if the shapes are known to be incompatible, or else return
an unknown shape to allow for either a boolean Tensor or a scalar to be
returned.
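
As an illustration, a minimal sketch of the intended eager-mode
behavior after this change (it assumes a TF 2.x build with tensor
equality enabled and the new attribute active; the printed reprs are an
assumption):

  import tensorflow as tf

  a = tf.constant([1.0, 2.0])
  b = tf.constant([1.0, 2.0, 3.0])  # shape [3] does not broadcast with [2]

  # With incompatible_shape_error=False threaded through __eq__/__ne__,
  # the overloaded operators return scalar booleans instead of raising:
  print(a == b)  # expected: tf.Tensor(False, shape=(), dtype=bool)
  print(a != b)  # expected: tf.Tensor(True, shape=(), dtype=bool)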

Note that the behavior of tf.math.equal & tf.math.not_equal is
unchanged: they keep the optimistic shape inference logic when dealing
with unknown dimensions, which allows for more efficient graphs rather
than inserting Rank operations.
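
By contrast, a short sketch of the unchanged strict path (assuming an
eager TF 2.x session): tf.math.equal leaves incompatible_shape_error at
its default of True and still reports the shape error.

  import tensorflow as tf

  a = tf.constant([1.0, 2.0])
  b = tf.constant([1.0, 2.0, 3.0])

  try:
    tf.math.equal(a, b)  # incompatible_shape_error defaults to True
  except tf.errors.InvalidArgumentError as e:
    print(e)  # e.g. "Incompatible shapes: [2] vs. [3]"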

This distinction between __eq__ & tf.math.equal is also found in numpy,
and as a result the tf.debugging.assert_equal and
tf.debugging.assert_none_equal APIs needed to be changed to use the
numpy operations.
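
The asserts only need the element-wise comparison for their static,
constant-folded check, so swapping the overloaded operators for the
numpy ufuncs keeps that path strict. Roughly, per the _binary_assert
change below:

  import numpy as np

  x_static = np.array([1, 2])
  y_static = np.array([1, 2])

  # static_func is now np.equal for assert_equal; np.all reduces the
  # element-wise result to the scalar the assertion needs.
  condition_static = np.all(np.equal(x_static, y_static))  # True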

PiperOrigin-RevId: 267466043
Author: Gaurav Jain, 2019-09-05 15:15:06 -07:00; committed by TensorFlower Gardener
parent 792abd2eaf
commit e0e1efbe08
17 changed files with 267 additions and 112 deletions

View File

@@ -1,9 +1,4 @@
 op {
   graph_op_name: "Equal"
-  endpoint {
-    name: "math.equal"
-  }
-  endpoint {
-    name: "equal"
-  }
+  visibility: HIDDEN
 }

View File

@@ -1,9 +1,4 @@
 op {
   graph_op_name: "NotEqual"
-  endpoint {
-    name: "math.not_equal"
-  }
-  endpoint {
-    name: "not_equal"
-  }
+  visibility: HIDDEN
 }

View File

@@ -366,7 +366,8 @@ Status EinsumShape(shape_inference::InferenceContext* c) {
     output_bcast_shape = input_bcast_shapes[0];
   } else if (input_bcast_shapes.size() == 2) {
     TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
-        c, input_bcast_shapes[0], input_bcast_shapes[1], &output_bcast_shape));
+        c, input_bcast_shapes[0], input_bcast_shapes[1], true,
+        &output_bcast_shape));
   }
   bool output_has_ellipsis = false;
@@ -441,7 +442,7 @@ Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
   TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
   TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
-      c, a_batch_shape, b_batch_shape, &output_batch_shape));
+      c, a_batch_shape, b_batch_shape, true, &output_batch_shape));
   ShapeHandle output_shape;
   TF_RETURN_IF_ERROR(c->Concatenate(
@@ -1633,6 +1634,7 @@ Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                             ShapeHandle shape_x,
                                             ShapeHandle shape_y,
+                                            bool incompatible_shape_error,
                                             ShapeHandle* out) {
   CHECK_NOTNULL(out);
   if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
@@ -1666,8 +1668,16 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
       // or the same as the known dim.
       // - If either dimension is 1, the other dimension is the output.
       if (c->Value(dim_x) > 1) {
+        if (!incompatible_shape_error) {
+          *out = c->UnknownShape();
+          return Status::OK();
+        }
         dims.push_back(dim_x);
       } else if (c->Value(dim_y) > 1) {
+        if (!incompatible_shape_error) {
+          *out = c->UnknownShape();
+          return Status::OK();
+        }
         dims.push_back(dim_y);
       } else if (c->Value(dim_x) == 1) {
         dims.push_back(dim_y);
@@ -1676,6 +1686,10 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
       } else if (dim_y.SameHandle(dim_x)) {
         dims.push_back(dim_x);
       } else {
+        if (!incompatible_shape_error) {
+          *out = c->UnknownShape();
+          return Status::OK();
+        }
         dims.push_back(c->UnknownDim());
       }
     } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
@@ -1689,7 +1703,14 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
       }
     } else {
       DimensionHandle dim;
-      TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
+      Status s = c->Merge(dim_x, dim_y, &dim);
+      if (!s.ok()) {
+        if (!incompatible_shape_error) {
+          *out = c->MakeShape({});
+          return Status::OK();
+        }
+        return s;
+      }
       dims.push_back(dim);
     }
   }

View File

@@ -306,6 +306,7 @@ Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                             ShapeHandle shape_x,
                                             ShapeHandle shape_y,
+                                            bool incompatible_shape_error,
                                             ShapeHandle* out);
 
 // Shape function for binary operators that broadcast their inputs
@@ -313,8 +314,8 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
 inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
                                              int output_index) {
   ShapeHandle out;
-  TF_RETURN_IF_ERROR(
-      BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
+  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
+      c, c->input(0), c->input(1), true, &out));
   c->set_output(output_index, out);
   return Status::OK();
 }

View File

@@ -57,11 +57,23 @@ BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
       in1(ctx->input(1)),
       bcast(BCast::FromShape(in0.shape()), BCast::FromShape(in1.shape())) {
   if (!bcast.IsValid()) {
+    bool incompatible_shape_error;
+    bool has_attr =
+        TryGetNodeAttr(ctx->op_kernel().def(), "incompatible_shape_error",
+                       &(incompatible_shape_error));
+    if (has_attr && !incompatible_shape_error) {
+      const string& op = ctx->op_kernel().type_string();
+      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
+      result = (op == "NotEqual");
+      return;
+    }
+
     ctx->SetStatus(errors::InvalidArgument(
         "Incompatible shapes: ", in0.shape().DebugString(), " vs. ",
         in1.shape().DebugString()));
     return;
   }
+
   const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
   out_num_elements = output_shape.num_elements();
   in0_num_elements = in0.NumElements();

View File

@@ -26,13 +26,13 @@ limitations under the License.
 #include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
 #endif
 
-#include "tensorflow/core/kernels/cwise_ops.h"
-#include "tensorflow/core/kernels/cwise_ops_gradients.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
+#include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/bcast.h"
@@ -56,7 +56,7 @@ class BinaryOpShared : public OpKernel {
     // in-place computation.
    // Caller must check ctx->status() upon return for non-ok status.
     // If ctx->status().ok() is true, then out is guaranteed to be allocated.
-    BinaryOpState(OpKernelContext* ctx);
+    explicit BinaryOpState(OpKernelContext* ctx);
 
     const Tensor& in0;
     const Tensor& in1;
@@ -69,6 +69,7 @@ class BinaryOpShared : public OpKernel {
     int64 in1_num_elements;
     int ndims;
+    bool result;
   };
 
   void SetUnimplementedError(OpKernelContext* ctx);
@@ -91,16 +92,29 @@ class BinaryOp : public BinaryOpShared {
   void Compute(OpKernelContext* ctx) override {
     // 'state': Shared helper not dependent on T to reduce code size
     BinaryOpState state(ctx);
-    if (!ctx->status().ok()) return;
+    auto& bcast = state.bcast;
+    const Device& eigen_device = ctx->eigen_device<Device>();
     Tensor* out = state.out;
-    BCast* bcast = &state.bcast;
+    if (!bcast.IsValid()) {
+      if (ctx->status().ok()) {
+        if (state.result) {
+          functor::SetOneFunctor<Device, bool>()(eigen_device,
+                                                 out->flat<bool>());
+        } else {
+          functor::SetZeroFunctor<Device, bool>()(eigen_device,
+                                                  out->flat<bool>());
+        }
+      }
+      return;
+    }
+
     auto& in0 = state.in0;
     auto& in1 = state.in1;
     if (state.out_num_elements == 0) {
       return;
     }
+
     const int ndims = state.ndims;
-    const Device& eigen_device = ctx->eigen_device<Device>();
     bool error = false;
     bool* const error_ptr = Functor::has_errors ? &error : nullptr;
     if (ndims <= 1) {
@@ -122,32 +136,32 @@ class BinaryOp : public BinaryOpShared {
       }
     } else if (ndims == 2) {
       functor::BinaryFunctor<Device, Functor, 2>().BCast(
-          eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
-          in0.template shaped<Tin, 2>(bcast->x_reshape()),
-          BCast::ToIndexArray<2>(bcast->x_bcast()),
-          in1.template shaped<Tin, 2>(bcast->y_reshape()),
-          BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
+          eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
+          in0.template shaped<Tin, 2>(bcast.x_reshape()),
+          BCast::ToIndexArray<2>(bcast.x_bcast()),
+          in1.template shaped<Tin, 2>(bcast.y_reshape()),
+          BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
     } else if (ndims == 3) {
       functor::BinaryFunctor<Device, Functor, 3>().BCast(
-          eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
-          in0.template shaped<Tin, 3>(bcast->x_reshape()),
-          BCast::ToIndexArray<3>(bcast->x_bcast()),
-          in1.template shaped<Tin, 3>(bcast->y_reshape()),
-          BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
+          eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
+          in0.template shaped<Tin, 3>(bcast.x_reshape()),
+          BCast::ToIndexArray<3>(bcast.x_bcast()),
+          in1.template shaped<Tin, 3>(bcast.y_reshape()),
+          BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
     } else if (ndims == 4) {
       functor::BinaryFunctor<Device, Functor, 4>().BCast(
-          eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
-          in0.template shaped<Tin, 4>(bcast->x_reshape()),
-          BCast::ToIndexArray<4>(bcast->x_bcast()),
-          in1.template shaped<Tin, 4>(bcast->y_reshape()),
-          BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
+          eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
+          in0.template shaped<Tin, 4>(bcast.x_reshape()),
+          BCast::ToIndexArray<4>(bcast.x_bcast()),
+          in1.template shaped<Tin, 4>(bcast.y_reshape()),
+          BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
     } else if (ndims == 5) {
       functor::BinaryFunctor<Device, Functor, 5>().BCast(
-          eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
-          in0.template shaped<Tin, 5>(bcast->x_reshape()),
-          BCast::ToIndexArray<5>(bcast->x_bcast()),
-          in1.template shaped<Tin, 5>(bcast->y_reshape()),
-          BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
+          eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
+          in0.template shaped<Tin, 5>(bcast.x_reshape()),
+          BCast::ToIndexArray<5>(bcast.x_bcast()),
+          in1.template shaped<Tin, 5>(bcast.y_reshape()),
+          BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
     } else {
       SetUnimplementedError(ctx);
     }

View File

@ -700,7 +700,19 @@ REGISTER_OP("GreaterEqual").COMPARISON();
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \ "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
"int64, complex64, quint8, qint8, qint32, string, bool, " \ "int64, complex64, quint8, qint8, qint32, string, bool, " \
"complex128}") \ "complex128}") \
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Attr("incompatible_shape_error: bool = true") \
.SetShapeFn([](InferenceContext* c) { \
ShapeHandle x = c->input(0); \
ShapeHandle y = c->input(1); \
ShapeHandle output; \
bool incompatible_shape_error; \
TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error", \
&incompatible_shape_error)); \
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( \
c, x, y, incompatible_shape_error, &output)); \
c->set_output(0, output); \
return Status::OK(); \
})
REGISTER_OP("Equal").EQUALITY_COMPARISON(); REGISTER_OP("Equal").EQUALITY_COMPARISON();
@ -877,10 +889,10 @@ REGISTER_OP("SelectV2")
ShapeHandle else_ = c->input(2); ShapeHandle else_ = c->input(2);
ShapeHandle other; ShapeHandle other;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other)); BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, true, &other));
ShapeHandle output; ShapeHandle output;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output)); BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, true, &output));
c->set_output(0, output); c->set_output(0, output);
return Status::OK(); return Status::OK();
}); });

View File

@@ -110,28 +110,17 @@ TEST(MathOpsTest, Segment_ShapeFn) {
 }
 
 TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
-  for (const auto* op_name : {"Add", "Complex",
-                              "Div", "Equal",
-                              "Greater", "GreaterEqual",
-                              "Igamma", "Igammac",
-                              "Zeta", "Polygamma",
-                              "Less", "LessEqual",
-                              "LogicalAnd", "LogicalOr",
-                              "Maximum", "Minimum",
-                              "Mod", "Mul",
-                              "NotEqual", "Pow",
-                              "Sub", "SquaredDifference",
-                              "DivNoNan"}) {
-    ShapeInferenceTestOp op(op_name);
+  auto test_shapes = [&](ShapeInferenceTestOp& op,
+                         bool incompatible_shape_error) {
     INFER_OK(op, "?;?", "?");
     INFER_OK(op, "[1,2];?", "?");
     INFER_OK(op, "?;[1,2]", "?");
 
     INFER_OK(op, "[?];[1]", "[d0_0]");
     INFER_OK(op, "[1];[?]", "[d1_0]");
-    INFER_OK(op, "[?];[2]", "[d1_0]");
-    INFER_OK(op, "[2];[?]", "[d0_0]");
-    INFER_OK(op, "[?];[?]", "[?]");
+    INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?");
+    INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
+    INFER_OK(op, "[?];[?]", incompatible_shape_error ? "[?]" : "?");
     INFER_OK(op, "[];[?]", "[d1_0]");
     INFER_OK(op, "[?];[]", "[d0_0]");
@@ -144,7 +133,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
     INFER_OK(op, "[1];[2]", "[d1_0]");
     INFER_OK(op, "[2];[1]", "[d0_0]");
     INFER_OK(op, "[2];[]", "[d0_0]");
-    INFER_OK(op, "[2];[?]", "[d0_0]");
+    INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
 
     INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
     INFER_OK(op, "[];[0]", "[d1_0]");
@@ -152,14 +141,46 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
     INFER_OK(op, "[0];[1]", "[d0_0]");
     INFER_OK(op, "[0];[]", "[d0_0]");
 
-    INFER_OK(op, "[2];[?,?]", "[d1_0,d0_0]");
-    INFER_OK(op, "[2,2];[?,?,?]", "[d1_0,d0_0,d0_1]");
+    INFER_OK(op, "[2];[?,?]", incompatible_shape_error ? "[d1_0,d0_0]" : "?");
+    INFER_OK(op, "[2,2];[?,?,?]",
+             incompatible_shape_error ? "[d1_0,d0_0,d0_1]" : "?");
 
     // Multiple dimension cases (same test cases, switching x and y).
     INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
-             "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
+             incompatible_shape_error ? "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"
+                                      : "?");
     INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
-             "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
+             incompatible_shape_error ? "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"
+                                      : "?");
+
+    if (incompatible_shape_error) {
+      INFER_ERROR("Dimensions must be equal", op, "[2];[3]");
+    } else {
+      INFER_OK(op, "[2];[3]", "[]");
+    }
+  };
+
+  for (string op_name : {"Add", "Complex",
+                         "Div", "Equal",
+                         "Greater", "GreaterEqual",
+                         "Igamma", "Igammac",
+                         "Zeta", "Polygamma",
+                         "Less", "LessEqual",
+                         "LogicalAnd", "LogicalOr",
+                         "Maximum", "Minimum",
+                         "Mod", "Mul",
+                         "NotEqual", "Pow",
+                         "Sub", "SquaredDifference",
+                         "DivNoNan"}) {
+    ShapeInferenceTestOp op(op_name);
+    AddNodeAttr("incompatible_shape_error", true, &op.node_def);
+    test_shapes(op, true);
+
+    if ((op_name == "Equal") || (op_name == "NotEqual")) {
+      ShapeInferenceTestOp op(op_name);
+      AddNodeAttr("incompatible_shape_error", false, &op.node_def);
+      test_shapes(op, false);
+    }
   }
 }

View File

@@ -27,6 +27,7 @@ import numpy as np
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.eager import core
 from tensorflow.python.eager import def_function
@@ -302,7 +303,11 @@ class TFETest(test_util.TensorFlowTestCase):
       with self.assertRaises(ValueError):
         bool(tf_a == tf_d)
       self.assertAllEqual(tf_a == tf_d, [[True, False], [True, False]])
-      # TODO(b/120678848): If shapes do not match we should instead return False
-      with self.assertRaises(errors.InvalidArgumentError):
-        bool(tf_a != tf_e)
+      if compat.forward_compatible(2019, 9, 25):
+        self.assertFalse(bool(tf_a == tf_e))
+        self.assertTrue(bool(tf_a != tf_e))
+        self.assertNotAllEqual(tf_a, tf_e)
+      else:
+        with self.assertRaises(errors.InvalidArgumentError):
+          bool(tf_a != tf_e)
@@ -313,7 +318,9 @@ class TFETest(test_util.TensorFlowTestCase):
         bool(np_a == np_c)
       self.assertAllEqual(np_a == np_c, [[True, True], [True, True]])
       self.assertAllEqual(np_a == np_d, [[True, False], [True, False]])
-      bool(np_a != np_e)
+      self.assertFalse(bool(np_a == np_e))
+      self.assertTrue(bool(np_a != np_e))
+      self.assertNotAllEqual(np_a, np_e)
     finally:
       if default:
         ops.enable_tensor_equality()

View File

@@ -22,6 +22,7 @@ import numpy as np
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
@@ -255,7 +256,6 @@ class BinaryOpTest(test.TestCase):
     var_x = variables.Variable(x)
     var_y = variables.Variable(y)
-    with self.cached_session() as sess:
-      self.evaluate([var_x.initializer, var_y.initializer])
-      left_result = self.evaluate(var_x * y)
-      right_result = self.evaluate(x * var_y)
+    self.evaluate([var_x.initializer, var_y.initializer])
+    left_result = self.evaluate(var_x * y)
+    right_result = self.evaluate(x * var_y)
@@ -933,7 +933,6 @@ class ComparisonOpTest(test.TestCase):
     self._testBCastByFunc(
         np.not_equal, math_ops.not_equal, include_complex=True)
 
-  @test_util.run_deprecated_v1
   def testShapeMismatch(self):
     dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
     funcs = [
@@ -944,8 +943,9 @@
     y = np.arange(0, 10).reshape([5, 2])
     for t in dtypes:
       for f in funcs:
-        with self.assertRaisesWithPredicateMatch(
-            ValueError, lambda e: "Dimensions must" in str(e)):
+        with self.assertRaisesRegexp(
+            (ValueError, errors.InvalidArgumentError),
+            "Incompatible shapes|Dimensions must be equal"):
           f(x.astype(t), y.astype(t))

View File

@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
@@ -201,7 +202,6 @@ class ComparisonOpTest(test.TestCase):
     self._testBCastByFunc(
         np.not_equal, math_ops.not_equal, include_complex=True)
 
-  @test_util.run_deprecated_v1
   def testShapeMismatch(self):
     dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
     funcs = [
@@ -212,8 +212,9 @@
     y = np.arange(0, 10).reshape([5, 2])
     for t in dtypes:
       for f in funcs:
-        with self.assertRaisesWithPredicateMatch(
-            ValueError, lambda e: "Dimensions must" in str(e)):
+        with self.assertRaisesRegexp(
+            (ValueError, errors.InvalidArgumentError),
+            "Incompatible shapes|Dimensions must be equal"):
           f(x.astype(t), y.astype(t))

View File

@@ -310,7 +310,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
     static_func: Function that, if passed numpy ndarray versions of the two
       inputs to the assertion, will return a Boolean ndarray with containing
       True in all positions where the assertion PASSES.
-      i.e. lambda x,y: (x == y) for assert_equal()
+      i.e. np.equal for assert_equal()
     x: Numeric `Tensor`.
     y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
     data: The tensors to print out if the condition is False. Defaults to
@@ -366,7 +366,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
       x_static = tensor_util.constant_value(x)
       y_static = tensor_util.constant_value(y)
       if x_static is not None and y_static is not None:
-        condition_static = static_func(x_static, y_static).all()
+        condition_static = np.all(static_func(x_static, y_static))
         _assert_static(condition_static, data)
     return control_flow_ops.Assert(condition, data, summarize=summarize)
@@ -654,9 +654,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # p
   # Short-circuit if x and y are the same tensor.
   if x is y:
     return None if context.executing_eagerly() else control_flow_ops.no_op()
-  return _binary_assert('==', 'assert_equal', math_ops.equal,
-                        lambda x, y: (x == y),
-                        x, y, data, summarize, message, name)
+  return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y,
+                        data, summarize, message, name)
 
 
 @tf_export('debugging.assert_none_equal', v1=[])
@@ -703,8 +702,7 @@ def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
 def assert_none_equal(
     x, y, data=None, summarize=None, message=None, name=None):
   return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal,
-                        lambda x, y: (x != y), x, y, data, summarize, message,
-                        name)
+                        np.not_equal, x, y, data, summarize, message, name)
 
 
 @tf_export('debugging.assert_near', v1=[])
@@ -877,8 +875,8 @@ def assert_less_v2(x, y, message=None, summarize=None, name=None):
 @tf_export(v1=['debugging.assert_less', 'assert_less'])
 @_binary_assert_doc('<')
 def assert_less(x, y, data=None, summarize=None, message=None, name=None):
-  return _binary_assert('<', 'assert_less', math_ops.less, lambda x, y: (x < y),
-                        x, y, data, summarize, message, name)
+  return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data,
+                        summarize, message, name)
 
 
 @tf_export('debugging.assert_less_equal', v1=[])
@@ -922,8 +920,7 @@ def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
 @_binary_assert_doc('<=')
 def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
   return _binary_assert('<=', 'assert_less_equal', math_ops.less_equal,
-                        lambda x, y: (x <= y), x, y, data, summarize, message,
-                        name)
+                        np.less_equal, x, y, data, summarize, message, name)
 
 
 @tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@@ -965,9 +962,8 @@ def assert_greater_v2(x, y, message=None, summarize=None, name=None):
 @tf_export(v1=['debugging.assert_greater', 'assert_greater'])
 @_binary_assert_doc('>')
 def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
-  return _binary_assert('>', 'assert_greater', math_ops.greater,
-                        lambda x, y: (x > y),
-                        x, y, data, summarize, message, name)
+  return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x,
+                        y, data, summarize, message, name)
 
 
 @tf_export('debugging.assert_greater_equal', v1=[])
@@ -1013,8 +1009,7 @@ def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
 def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                          name=None):
   return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal,
-                        lambda x, y: (x >= y), x, y, data, summarize, message,
-                        name)
+                        np.greater_equal, x, y, data, summarize, message, name)
 
 
 def _assert_rank_condition(
@@ -1026,8 +1021,8 @@ def _assert_rank_condition(
     rank: Scalar `Tensor`.
     static_condition: A python function that takes `[actual_rank, given_rank]`
      and returns `True` if the condition is satisfied, `False` otherwise.
-    dynamic_condition: An `op` that takes [actual_rank, given_rank]
-      and return `True` if the condition is satisfied, `False` otherwise.
+    dynamic_condition: An `op` that takes [actual_rank, given_rank] and return
+      `True` if the condition is satisfied, `False` otherwise.
    data: The tensors to print out if the condition is false. Defaults to
       error message and first few entries of `x`.
     summarize: Print this many entries of each tensor.
@@ -2159,4 +2154,3 @@ def ensure_shape(x, shape, name=None):
 def _ensure_shape_grad(op, grad):
   del op  # Unused.
   return grad
-

View File

@@ -1270,11 +1270,64 @@ ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
 ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
 
 
+@tf_export("math.equal", "equal")
+@dispatch.add_dispatch_support
+def equal(x, y, name=None):
+  """Returns the truth value of (x == y) element-wise.
+
+  Usage:
+
+  ```python
+  x = tf.constant([2, 4])
+  y = tf.constant(2)
+  tf.math.equal(x, y) ==> array([True, False])
+
+  x = tf.constant([2, 4])
+  y = tf.constant([2, 4])
+  tf.math.equal(x, y) ==> array([True, True])
+  ```
+
+  **NOTE**: `Equal` supports broadcasting. More about broadcasting [here](
+  https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)
+
+  Args:
+    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
+    y: A `Tensor` or `SparseTensor` or `IndexedSlices`.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor` of type bool with the same size as that of x or y.
+  """
+  return gen_math_ops.equal(x, y, name=name)
+
+
+@tf_export("math.not_equal", "not_equal")
+@dispatch.add_dispatch_support
+def not_equal(x, y, name=None):
+  """Returns the truth value of (x != y) element-wise.
+
+  **NOTE**: `NotEqual` supports broadcasting. More about broadcasting [here](
+  https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)
+
+  Args:
+    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
+    y: A `Tensor` or `SparseTensor` or `IndexedSlices`.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor` of type bool with the same size as that of x or y.
+  """
+  return gen_math_ops.not_equal(x, y, name=name)
+
+
 def tensor_equals(self, other):
   """Compares two tensors element-wise for equality."""
   g = getattr(self, "graph", None)
   if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
       (g is None or g._building_function)):  # pylint: disable=protected-access
-    return gen_math_ops.equal(self, other)
+    if fwd_compat.forward_compatible(2019, 9, 25):
+      return gen_math_ops.equal(self, other, incompatible_shape_error=False)
+    else:
+      return gen_math_ops.equal(self, other)
   else:
     # In legacy graph mode, tensor equality is object equality
@@ -1284,6 +1337,9 @@ def tensor_equals(self, other):
 def tensor_not_equals(self, other):
   """Compares two tensors element-wise for equality."""
   if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
-    return gen_math_ops.not_equal(self, other)
+    if fwd_compat.forward_compatible(2019, 9, 25):
+      return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
+    else:
+      return gen_math_ops.not_equal(self, other)
   else:
     # In legacy graph mode, tensor equality is object equality

View File

@@ -2397,7 +2397,6 @@ def _convert_cast(pfor_input):
 @RegisterPForWithArgs("Div", math_ops.div)
 @RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
 @RegisterPForWithArgs("Elu", nn_ops.elu)
-@RegisterPForWithArgs("Equal", math_ops.equal)
 @RegisterPForWithArgs("Erf", math_ops.erf)
 @RegisterPForWithArgs("Erfc", math_ops.erfc)
 @RegisterPForWithArgs("Exp", math_ops.exp)
@@ -2432,7 +2431,6 @@ def _convert_cast(pfor_input):
 @RegisterPForWithArgs("Mul", math_ops.multiply)
 @RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan)
 @RegisterPForWithArgs("Neg", math_ops.negative)
-@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
 @RegisterPForWithArgs("Polygamma", math_ops.polygamma)
 @RegisterPForWithArgs("Pow", math_ops.pow)
 @RegisterPForWithArgs("Real", math_ops.real)
@@ -2471,6 +2469,26 @@ def _convert_cwise(pfor_input, op_type, op_func):
   return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
 
 
+@RegisterPFor("Equal")
+def _convert_equal(pfor_input):
+  pfor_input.expanddim_inputs_for_broadcast()
+  x = pfor_input.input(0)[0]
+  y = pfor_input.input(1)[0]
+  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
+  assert incompatible_shape_error
+  return wrap(math_ops.equal(x, y), True)
+
+
+@RegisterPFor("NotEqual")
+def _convert_not_equal(pfor_input):
+  pfor_input.expanddim_inputs_for_broadcast()
+  x = pfor_input.input(0)[0]
+  y = pfor_input.input(1)[0]
+  incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
+  assert incompatible_shape_error
+  return wrap(math_ops.not_equal(x, y), True)
+
+
 @RegisterPFor("ApproximateEqual")
 def _convert_approximate_equal(pfor_input):
   pfor_input.expanddim_inputs_for_broadcast()

View File

@@ -28,6 +28,7 @@ from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import variable_pb2
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
 from tensorflow.python import _pywrap_utils
+from tensorflow.python.compat import compat as fwd_compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -35,8 +36,8 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import gen_state_ops
+from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import tf_logging as logging
@@ -1092,6 +1093,9 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
   def __eq__(self, other):
     """Compares two variables element-wise for equality."""
     if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
-      return gen_math_ops.equal(self, other)
+      if fwd_compat.forward_compatible(2019, 9, 25):
+        return gen_math_ops.equal(self, other, incompatible_shape_error=False)
+      else:
+        return gen_math_ops.equal(self, other)
     else:
       # In legacy graph mode, tensor equality is object equality
@@ -1101,6 +1105,10 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
   def __ne__(self, other):
     """Compares two variables element-wise for equality."""
     if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():  # pylint: disable=protected-access
-      return gen_math_ops.not_equal(self, other)
+      if fwd_compat.forward_compatible(2019, 9, 25):
+        return gen_math_ops.not_equal(
+            self, other, incompatible_shape_error=False)
+      else:
+        return gen_math_ops.not_equal(self, other)
     else:
       # In legacy graph mode, tensor equality is object equality

View File

@@ -1174,7 +1174,7 @@ tf_module {
   }
   member_method {
     name: "Equal"
-    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
     name: "Erf"
@@ -2398,7 +2398,7 @@ tf_module {
   }
   member_method {
     name: "NotEqual"
-    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
  }
  member_method {
    name: "NthElement"

View File

@@ -1174,7 +1174,7 @@ tf_module {
   }
   member_method {
     name: "Equal"
-    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
     name: "Erf"
@@ -2398,7 +2398,7 @@ tf_module {
   }
   member_method {
     name: "NotEqual"
-    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
  }
  member_method {
    name: "NthElement"