Add incompatible_shape_error attribute to equal op
When tensor equality is enabled, if there is an incompatible shape we currently throw and exception. Ideally we'd like to return False when calling __eq__ and True when calling __ne__. We thus modify the Equal and NotEqual ops to return a boolean upon a shape incompatibility. Due to this change the shape inference logic needs to be changed to either return a scalar bool if the shapes are incompatible, or else return an unknown shape to allow for either a boolean Tensor or scalar to be returned. Note the behavior of tf.math.equal & tf.math.not_equal is unchanged as they both use optimistic shape inference logic when dealing with unknown dimensions which allows for more efficient graphs rather than inserting Rank operations. This distinction between __eq__ & tf.math.equal is also found in numpy and as a result the tf.debugging.assert_equal and tf.debugging.assert_none_equal APIs needed to be change to utilize the numpy operations. PiperOrigin-RevId: 267466043
This commit is contained in:
parent
792abd2eaf
commit
e0e1efbe08
tensorflow
core
api_def/python_api
framework
kernels
ops
python
eager
kernel_tests
ops
tools/api/golden
@ -1,9 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "Equal"
|
||||
endpoint {
|
||||
name: "math.equal"
|
||||
}
|
||||
endpoint {
|
||||
name: "equal"
|
||||
}
|
||||
visibility: HIDDEN
|
||||
}
|
||||
|
@ -1,9 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "NotEqual"
|
||||
endpoint {
|
||||
name: "math.not_equal"
|
||||
}
|
||||
endpoint {
|
||||
name: "not_equal"
|
||||
}
|
||||
visibility: HIDDEN
|
||||
}
|
||||
|
@ -366,7 +366,8 @@ Status EinsumShape(shape_inference::InferenceContext* c) {
|
||||
output_bcast_shape = input_bcast_shapes[0];
|
||||
} else if (input_bcast_shapes.size() == 2) {
|
||||
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
|
||||
c, input_bcast_shapes[0], input_bcast_shapes[1], &output_bcast_shape));
|
||||
c, input_bcast_shapes[0], input_bcast_shapes[1], true,
|
||||
&output_bcast_shape));
|
||||
}
|
||||
|
||||
bool output_has_ellipsis = false;
|
||||
@ -441,7 +442,7 @@ Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
|
||||
|
||||
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
|
||||
c, a_batch_shape, b_batch_shape, &output_batch_shape));
|
||||
c, a_batch_shape, b_batch_shape, true, &output_batch_shape));
|
||||
|
||||
ShapeHandle output_shape;
|
||||
TF_RETURN_IF_ERROR(c->Concatenate(
|
||||
@ -1633,6 +1634,7 @@ Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
|
||||
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||
ShapeHandle shape_x,
|
||||
ShapeHandle shape_y,
|
||||
bool incompatible_shape_error,
|
||||
ShapeHandle* out) {
|
||||
CHECK_NOTNULL(out);
|
||||
if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
|
||||
@ -1666,8 +1668,16 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||
// or the same as the known dim.
|
||||
// - If either dimension is 1, the other dimension is the output.
|
||||
if (c->Value(dim_x) > 1) {
|
||||
if (!incompatible_shape_error) {
|
||||
*out = c->UnknownShape();
|
||||
return Status::OK();
|
||||
}
|
||||
dims.push_back(dim_x);
|
||||
} else if (c->Value(dim_y) > 1) {
|
||||
if (!incompatible_shape_error) {
|
||||
*out = c->UnknownShape();
|
||||
return Status::OK();
|
||||
}
|
||||
dims.push_back(dim_y);
|
||||
} else if (c->Value(dim_x) == 1) {
|
||||
dims.push_back(dim_y);
|
||||
@ -1676,6 +1686,10 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||
} else if (dim_y.SameHandle(dim_x)) {
|
||||
dims.push_back(dim_x);
|
||||
} else {
|
||||
if (!incompatible_shape_error) {
|
||||
*out = c->UnknownShape();
|
||||
return Status::OK();
|
||||
}
|
||||
dims.push_back(c->UnknownDim());
|
||||
}
|
||||
} else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
|
||||
@ -1689,7 +1703,14 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||
}
|
||||
} else {
|
||||
DimensionHandle dim;
|
||||
TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
|
||||
Status s = c->Merge(dim_x, dim_y, &dim);
|
||||
if (!s.ok()) {
|
||||
if (!incompatible_shape_error) {
|
||||
*out = c->MakeShape({});
|
||||
return Status::OK();
|
||||
}
|
||||
return s;
|
||||
}
|
||||
dims.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
@ -306,6 +306,7 @@ Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
|
||||
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||
ShapeHandle shape_x,
|
||||
ShapeHandle shape_y,
|
||||
bool incompatible_shape_error,
|
||||
ShapeHandle* out);
|
||||
|
||||
// Shape function for binary operators that broadcast their inputs
|
||||
@ -313,8 +314,8 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||
inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
|
||||
int output_index) {
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
|
||||
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
|
||||
c, c->input(0), c->input(1), true, &out));
|
||||
c->set_output(output_index, out);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -57,11 +57,23 @@ BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
|
||||
in1(ctx->input(1)),
|
||||
bcast(BCast::FromShape(in0.shape()), BCast::FromShape(in1.shape())) {
|
||||
if (!bcast.IsValid()) {
|
||||
bool incompatible_shape_error;
|
||||
bool has_attr =
|
||||
TryGetNodeAttr(ctx->op_kernel().def(), "incompatible_shape_error",
|
||||
&(incompatible_shape_error));
|
||||
if (has_attr && !incompatible_shape_error) {
|
||||
const string& op = ctx->op_kernel().type_string();
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
|
||||
result = (op == "NotEqual");
|
||||
return;
|
||||
}
|
||||
|
||||
ctx->SetStatus(errors::InvalidArgument(
|
||||
"Incompatible shapes: ", in0.shape().DebugString(), " vs. ",
|
||||
in1.shape().DebugString()));
|
||||
return;
|
||||
}
|
||||
|
||||
const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
|
||||
out_num_elements = output_shape.num_elements();
|
||||
in0_num_elements = in0.NumElements();
|
||||
|
@ -26,13 +26,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
|
||||
#endif
|
||||
|
||||
#include "tensorflow/core/kernels/cwise_ops.h"
|
||||
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||
#include "tensorflow/core/kernels/cwise_ops.h"
|
||||
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/bcast.h"
|
||||
|
||||
@ -56,7 +56,7 @@ class BinaryOpShared : public OpKernel {
|
||||
// in-place computation.
|
||||
// Caller must check ctx->status() upon return for non-ok status.
|
||||
// If ctx->status().ok() is true, then out is guaranteed to be allocated.
|
||||
BinaryOpState(OpKernelContext* ctx);
|
||||
explicit BinaryOpState(OpKernelContext* ctx);
|
||||
|
||||
const Tensor& in0;
|
||||
const Tensor& in1;
|
||||
@ -69,6 +69,7 @@ class BinaryOpShared : public OpKernel {
|
||||
int64 in1_num_elements;
|
||||
|
||||
int ndims;
|
||||
bool result;
|
||||
};
|
||||
|
||||
void SetUnimplementedError(OpKernelContext* ctx);
|
||||
@ -91,16 +92,29 @@ class BinaryOp : public BinaryOpShared {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// 'state': Shared helper not dependent on T to reduce code size
|
||||
BinaryOpState state(ctx);
|
||||
if (!ctx->status().ok()) return;
|
||||
auto& bcast = state.bcast;
|
||||
const Device& eigen_device = ctx->eigen_device<Device>();
|
||||
Tensor* out = state.out;
|
||||
BCast* bcast = &state.bcast;
|
||||
if (!bcast.IsValid()) {
|
||||
if (ctx->status().ok()) {
|
||||
if (state.result) {
|
||||
functor::SetOneFunctor<Device, bool>()(eigen_device,
|
||||
out->flat<bool>());
|
||||
} else {
|
||||
functor::SetZeroFunctor<Device, bool>()(eigen_device,
|
||||
out->flat<bool>());
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in0 = state.in0;
|
||||
auto& in1 = state.in1;
|
||||
if (state.out_num_elements == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ndims = state.ndims;
|
||||
const Device& eigen_device = ctx->eigen_device<Device>();
|
||||
bool error = false;
|
||||
bool* const error_ptr = Functor::has_errors ? &error : nullptr;
|
||||
if (ndims <= 1) {
|
||||
@ -122,32 +136,32 @@ class BinaryOp : public BinaryOpShared {
|
||||
}
|
||||
} else if (ndims == 2) {
|
||||
functor::BinaryFunctor<Device, Functor, 2>().BCast(
|
||||
eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
|
||||
in0.template shaped<Tin, 2>(bcast->x_reshape()),
|
||||
BCast::ToIndexArray<2>(bcast->x_bcast()),
|
||||
in1.template shaped<Tin, 2>(bcast->y_reshape()),
|
||||
BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
|
||||
eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
|
||||
in0.template shaped<Tin, 2>(bcast.x_reshape()),
|
||||
BCast::ToIndexArray<2>(bcast.x_bcast()),
|
||||
in1.template shaped<Tin, 2>(bcast.y_reshape()),
|
||||
BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
|
||||
} else if (ndims == 3) {
|
||||
functor::BinaryFunctor<Device, Functor, 3>().BCast(
|
||||
eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
|
||||
in0.template shaped<Tin, 3>(bcast->x_reshape()),
|
||||
BCast::ToIndexArray<3>(bcast->x_bcast()),
|
||||
in1.template shaped<Tin, 3>(bcast->y_reshape()),
|
||||
BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
|
||||
eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
|
||||
in0.template shaped<Tin, 3>(bcast.x_reshape()),
|
||||
BCast::ToIndexArray<3>(bcast.x_bcast()),
|
||||
in1.template shaped<Tin, 3>(bcast.y_reshape()),
|
||||
BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
|
||||
} else if (ndims == 4) {
|
||||
functor::BinaryFunctor<Device, Functor, 4>().BCast(
|
||||
eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
|
||||
in0.template shaped<Tin, 4>(bcast->x_reshape()),
|
||||
BCast::ToIndexArray<4>(bcast->x_bcast()),
|
||||
in1.template shaped<Tin, 4>(bcast->y_reshape()),
|
||||
BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
|
||||
eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
|
||||
in0.template shaped<Tin, 4>(bcast.x_reshape()),
|
||||
BCast::ToIndexArray<4>(bcast.x_bcast()),
|
||||
in1.template shaped<Tin, 4>(bcast.y_reshape()),
|
||||
BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
|
||||
} else if (ndims == 5) {
|
||||
functor::BinaryFunctor<Device, Functor, 5>().BCast(
|
||||
eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
|
||||
in0.template shaped<Tin, 5>(bcast->x_reshape()),
|
||||
BCast::ToIndexArray<5>(bcast->x_bcast()),
|
||||
in1.template shaped<Tin, 5>(bcast->y_reshape()),
|
||||
BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
|
||||
eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
|
||||
in0.template shaped<Tin, 5>(bcast.x_reshape()),
|
||||
BCast::ToIndexArray<5>(bcast.x_bcast()),
|
||||
in1.template shaped<Tin, 5>(bcast.y_reshape()),
|
||||
BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
|
||||
} else {
|
||||
SetUnimplementedError(ctx);
|
||||
}
|
||||
|
@ -700,7 +700,19 @@ REGISTER_OP("GreaterEqual").COMPARISON();
|
||||
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
|
||||
"int64, complex64, quint8, qint8, qint32, string, bool, " \
|
||||
"complex128}") \
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
|
||||
.Attr("incompatible_shape_error: bool = true") \
|
||||
.SetShapeFn([](InferenceContext* c) { \
|
||||
ShapeHandle x = c->input(0); \
|
||||
ShapeHandle y = c->input(1); \
|
||||
ShapeHandle output; \
|
||||
bool incompatible_shape_error; \
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error", \
|
||||
&incompatible_shape_error)); \
|
||||
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( \
|
||||
c, x, y, incompatible_shape_error, &output)); \
|
||||
c->set_output(0, output); \
|
||||
return Status::OK(); \
|
||||
})
|
||||
|
||||
REGISTER_OP("Equal").EQUALITY_COMPARISON();
|
||||
|
||||
@ -877,10 +889,10 @@ REGISTER_OP("SelectV2")
|
||||
ShapeHandle else_ = c->input(2);
|
||||
ShapeHandle other;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other));
|
||||
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, true, &other));
|
||||
ShapeHandle output;
|
||||
TF_RETURN_IF_ERROR(
|
||||
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output));
|
||||
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, true, &output));
|
||||
c->set_output(0, output);
|
||||
return Status::OK();
|
||||
});
|
||||
|
@ -110,28 +110,17 @@ TEST(MathOpsTest, Segment_ShapeFn) {
|
||||
}
|
||||
|
||||
TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
|
||||
for (const auto* op_name : {"Add", "Complex",
|
||||
"Div", "Equal",
|
||||
"Greater", "GreaterEqual",
|
||||
"Igamma", "Igammac",
|
||||
"Zeta", "Polygamma",
|
||||
"Less", "LessEqual",
|
||||
"LogicalAnd", "LogicalOr",
|
||||
"Maximum", "Minimum",
|
||||
"Mod", "Mul",
|
||||
"NotEqual", "Pow",
|
||||
"Sub", "SquaredDifference",
|
||||
"DivNoNan"}) {
|
||||
ShapeInferenceTestOp op(op_name);
|
||||
auto test_shapes = [&](ShapeInferenceTestOp& op,
|
||||
bool incompatible_shape_error) {
|
||||
INFER_OK(op, "?;?", "?");
|
||||
INFER_OK(op, "[1,2];?", "?");
|
||||
INFER_OK(op, "?;[1,2]", "?");
|
||||
|
||||
INFER_OK(op, "[?];[1]", "[d0_0]");
|
||||
INFER_OK(op, "[1];[?]", "[d1_0]");
|
||||
INFER_OK(op, "[?];[2]", "[d1_0]");
|
||||
INFER_OK(op, "[2];[?]", "[d0_0]");
|
||||
INFER_OK(op, "[?];[?]", "[?]");
|
||||
INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?");
|
||||
INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
|
||||
INFER_OK(op, "[?];[?]", incompatible_shape_error ? "[?]" : "?");
|
||||
INFER_OK(op, "[];[?]", "[d1_0]");
|
||||
INFER_OK(op, "[?];[]", "[d0_0]");
|
||||
|
||||
@ -144,7 +133,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
|
||||
INFER_OK(op, "[1];[2]", "[d1_0]");
|
||||
INFER_OK(op, "[2];[1]", "[d0_0]");
|
||||
INFER_OK(op, "[2];[]", "[d0_0]");
|
||||
INFER_OK(op, "[2];[?]", "[d0_0]");
|
||||
INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
|
||||
|
||||
INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
|
||||
INFER_OK(op, "[];[0]", "[d1_0]");
|
||||
@ -152,14 +141,46 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
|
||||
INFER_OK(op, "[0];[1]", "[d0_0]");
|
||||
INFER_OK(op, "[0];[]", "[d0_0]");
|
||||
|
||||
INFER_OK(op, "[2];[?,?]", "[d1_0,d0_0]");
|
||||
INFER_OK(op, "[2,2];[?,?,?]", "[d1_0,d0_0,d0_1]");
|
||||
INFER_OK(op, "[2];[?,?]", incompatible_shape_error ? "[d1_0,d0_0]" : "?");
|
||||
INFER_OK(op, "[2,2];[?,?,?]",
|
||||
incompatible_shape_error ? "[d1_0,d0_0,d0_1]" : "?");
|
||||
|
||||
// Multiple dimension cases (same test cases, switching x and y).
|
||||
INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
|
||||
"[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
|
||||
incompatible_shape_error ? "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"
|
||||
: "?");
|
||||
INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
|
||||
"[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
|
||||
incompatible_shape_error ? "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"
|
||||
: "?");
|
||||
|
||||
if (incompatible_shape_error) {
|
||||
INFER_ERROR("Dimensions must be equal", op, "[2];[3]");
|
||||
} else {
|
||||
INFER_OK(op, "[2];[3]", "[]");
|
||||
}
|
||||
};
|
||||
|
||||
for (string op_name : {"Add", "Complex",
|
||||
"Div", "Equal",
|
||||
"Greater", "GreaterEqual",
|
||||
"Igamma", "Igammac",
|
||||
"Zeta", "Polygamma",
|
||||
"Less", "LessEqual",
|
||||
"LogicalAnd", "LogicalOr",
|
||||
"Maximum", "Minimum",
|
||||
"Mod", "Mul",
|
||||
"NotEqual", "Pow",
|
||||
"Sub", "SquaredDifference",
|
||||
"DivNoNan"}) {
|
||||
ShapeInferenceTestOp op(op_name);
|
||||
AddNodeAttr("incompatible_shape_error", true, &op.node_def);
|
||||
test_shapes(op, true);
|
||||
|
||||
if ((op_name == "Equal") || (op_name == "NotEqual")) {
|
||||
ShapeInferenceTestOp op(op_name);
|
||||
AddNodeAttr("incompatible_shape_error", false, &op.node_def);
|
||||
test_shapes(op, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import core
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -302,9 +303,13 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
bool(tf_a == tf_d)
|
||||
self.assertAllEqual(tf_a == tf_d, [[True, False], [True, False]])
|
||||
# TODO(b/120678848): If shapes do not match we should instead return False
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
bool(tf_a != tf_e)
|
||||
if compat.forward_compatible(2019, 9, 25):
|
||||
self.assertFalse(bool(tf_a == tf_e))
|
||||
self.assertTrue(bool(tf_a != tf_e))
|
||||
self.assertNotAllEqual(tf_a, tf_e)
|
||||
else:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
bool(tf_a != tf_e)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
bool(np_a == np_b)
|
||||
@ -313,7 +318,9 @@ class TFETest(test_util.TensorFlowTestCase):
|
||||
bool(np_a == np_c)
|
||||
self.assertAllEqual(np_a == np_c, [[True, True], [True, True]])
|
||||
self.assertAllEqual(np_a == np_d, [[True, False], [True, False]])
|
||||
bool(np_a != np_e)
|
||||
self.assertFalse(bool(np_a == np_e))
|
||||
self.assertTrue(bool(np_a != np_e))
|
||||
self.assertNotAllEqual(np_a, np_e)
|
||||
finally:
|
||||
if default:
|
||||
ops.enable_tensor_equality()
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -255,10 +256,9 @@ class BinaryOpTest(test.TestCase):
|
||||
var_x = variables.Variable(x)
|
||||
var_y = variables.Variable(y)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.evaluate([var_x.initializer, var_y.initializer])
|
||||
left_result = self.evaluate(var_x * y)
|
||||
right_result = self.evaluate(x * var_y)
|
||||
self.evaluate([var_x.initializer, var_y.initializer])
|
||||
left_result = self.evaluate(var_x * y)
|
||||
right_result = self.evaluate(x * var_y)
|
||||
|
||||
np_result = x * y
|
||||
self.assertAllEqual(np_result, left_result)
|
||||
@ -933,7 +933,6 @@ class ComparisonOpTest(test.TestCase):
|
||||
self._testBCastByFunc(
|
||||
np.not_equal, math_ops.not_equal, include_complex=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShapeMismatch(self):
|
||||
dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
|
||||
funcs = [
|
||||
@ -944,8 +943,9 @@ class ComparisonOpTest(test.TestCase):
|
||||
y = np.arange(0, 10).reshape([5, 2])
|
||||
for t in dtypes:
|
||||
for f in funcs:
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
ValueError, lambda e: "Dimensions must" in str(e)):
|
||||
with self.assertRaisesRegexp(
|
||||
(ValueError, errors.InvalidArgumentError),
|
||||
"Incompatible shapes|Dimensions must be equal"):
|
||||
f(x.astype(t), y.astype(t))
|
||||
|
||||
|
||||
|
@ -23,6 +23,7 @@ import numpy as np
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -201,7 +202,6 @@ class ComparisonOpTest(test.TestCase):
|
||||
self._testBCastByFunc(
|
||||
np.not_equal, math_ops.not_equal, include_complex=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShapeMismatch(self):
|
||||
dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
|
||||
funcs = [
|
||||
@ -212,8 +212,9 @@ class ComparisonOpTest(test.TestCase):
|
||||
y = np.arange(0, 10).reshape([5, 2])
|
||||
for t in dtypes:
|
||||
for f in funcs:
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
ValueError, lambda e: "Dimensions must" in str(e)):
|
||||
with self.assertRaisesRegexp(
|
||||
(ValueError, errors.InvalidArgumentError),
|
||||
"Incompatible shapes|Dimensions must be equal"):
|
||||
f(x.astype(t), y.astype(t))
|
||||
|
||||
|
||||
|
@ -310,7 +310,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
|
||||
static_func: Function that, if passed numpy ndarray versions of the two
|
||||
inputs to the assertion, will return a Boolean ndarray with containing
|
||||
True in all positions where the assertion PASSES.
|
||||
i.e. lambda x,y: (x == y) for assert_equal()
|
||||
i.e. np.equal for assert_equal()
|
||||
x: Numeric `Tensor`.
|
||||
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
|
||||
data: The tensors to print out if the condition is False. Defaults to
|
||||
@ -366,7 +366,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
|
||||
x_static = tensor_util.constant_value(x)
|
||||
y_static = tensor_util.constant_value(y)
|
||||
if x_static is not None and y_static is not None:
|
||||
condition_static = static_func(x_static, y_static).all()
|
||||
condition_static = np.all(static_func(x_static, y_static))
|
||||
_assert_static(condition_static, data)
|
||||
return control_flow_ops.Assert(condition, data, summarize=summarize)
|
||||
|
||||
@ -654,9 +654,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # p
|
||||
# Short-circuit if x and y are the same tensor.
|
||||
if x is y:
|
||||
return None if context.executing_eagerly() else control_flow_ops.no_op()
|
||||
return _binary_assert('==', 'assert_equal', math_ops.equal,
|
||||
lambda x, y: (x == y),
|
||||
x, y, data, summarize, message, name)
|
||||
return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y,
|
||||
data, summarize, message, name)
|
||||
|
||||
|
||||
@tf_export('debugging.assert_none_equal', v1=[])
|
||||
@ -703,8 +702,7 @@ def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
|
||||
def assert_none_equal(
|
||||
x, y, data=None, summarize=None, message=None, name=None):
|
||||
return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal,
|
||||
lambda x, y: (x != y), x, y, data, summarize, message,
|
||||
name)
|
||||
np.not_equal, x, y, data, summarize, message, name)
|
||||
|
||||
|
||||
@tf_export('debugging.assert_near', v1=[])
|
||||
@ -877,8 +875,8 @@ def assert_less_v2(x, y, message=None, summarize=None, name=None):
|
||||
@tf_export(v1=['debugging.assert_less', 'assert_less'])
|
||||
@_binary_assert_doc('<')
|
||||
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
|
||||
return _binary_assert('<', 'assert_less', math_ops.less, lambda x, y: (x < y),
|
||||
x, y, data, summarize, message, name)
|
||||
return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data,
|
||||
summarize, message, name)
|
||||
|
||||
|
||||
@tf_export('debugging.assert_less_equal', v1=[])
|
||||
@ -922,8 +920,7 @@ def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
|
||||
@_binary_assert_doc('<=')
|
||||
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
|
||||
return _binary_assert('<=', 'assert_less_equal', math_ops.less_equal,
|
||||
lambda x, y: (x <= y), x, y, data, summarize, message,
|
||||
name)
|
||||
np.less_equal, x, y, data, summarize, message, name)
|
||||
|
||||
|
||||
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
|
||||
@ -965,9 +962,8 @@ def assert_greater_v2(x, y, message=None, summarize=None, name=None):
|
||||
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
|
||||
@_binary_assert_doc('>')
|
||||
def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
|
||||
return _binary_assert('>', 'assert_greater', math_ops.greater,
|
||||
lambda x, y: (x > y),
|
||||
x, y, data, summarize, message, name)
|
||||
return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x,
|
||||
y, data, summarize, message, name)
|
||||
|
||||
|
||||
@tf_export('debugging.assert_greater_equal', v1=[])
|
||||
@ -1013,8 +1009,7 @@ def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
|
||||
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
|
||||
name=None):
|
||||
return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal,
|
||||
lambda x, y: (x >= y), x, y, data, summarize, message,
|
||||
name)
|
||||
np.greater_equal, x, y, data, summarize, message, name)
|
||||
|
||||
|
||||
def _assert_rank_condition(
|
||||
@ -1026,8 +1021,8 @@ def _assert_rank_condition(
|
||||
rank: Scalar `Tensor`.
|
||||
static_condition: A python function that takes `[actual_rank, given_rank]`
|
||||
and returns `True` if the condition is satisfied, `False` otherwise.
|
||||
dynamic_condition: An `op` that takes [actual_rank, given_rank]
|
||||
and return `True` if the condition is satisfied, `False` otherwise.
|
||||
dynamic_condition: An `op` that takes [actual_rank, given_rank] and return
|
||||
`True` if the condition is satisfied, `False` otherwise.
|
||||
data: The tensors to print out if the condition is false. Defaults to
|
||||
error message and first few entries of `x`.
|
||||
summarize: Print this many entries of each tensor.
|
||||
@ -2159,4 +2154,3 @@ def ensure_shape(x, shape, name=None):
|
||||
def _ensure_shape_grad(op, grad):
|
||||
del op # Unused.
|
||||
return grad
|
||||
|
||||
|
@ -1270,12 +1270,65 @@ ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
|
||||
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
|
||||
|
||||
|
||||
@tf_export("math.equal", "equal")
|
||||
@dispatch.add_dispatch_support
|
||||
def equal(x, y, name=None):
|
||||
"""Returns the truth value of (x == y) element-wise.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
x = tf.constant([2, 4])
|
||||
y = tf.constant(2)
|
||||
tf.math.equal(x, y) ==> array([True, False])
|
||||
|
||||
x = tf.constant([2, 4])
|
||||
y = tf.constant([2, 4])
|
||||
tf.math.equal(x, y) ==> array([True, True])
|
||||
```
|
||||
|
||||
**NOTE**: `Equal` supports broadcasting. More about broadcasting [here](
|
||||
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)
|
||||
|
||||
Args:
|
||||
x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
|
||||
y: A `Tensor` or `SparseTensor` or `IndexedSlices`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` of type bool with the same size as that of x or y.
|
||||
"""
|
||||
return gen_math_ops.equal(x, y, name=name)
|
||||
|
||||
|
||||
@tf_export("math.not_equal", "not_equal")
|
||||
@dispatch.add_dispatch_support
|
||||
def not_equal(x, y, name=None):
|
||||
"""Returns the truth value of (x != y) element-wise.
|
||||
|
||||
**NOTE**: `NotEqual` supports broadcasting. More about broadcasting [here](
|
||||
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)
|
||||
|
||||
Args:
|
||||
x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
|
||||
y: A `Tensor` or `SparseTensor` or `IndexedSlices`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` of type bool with the same size as that of x or y.
|
||||
"""
|
||||
return gen_math_ops.not_equal(x, y, name=name)
|
||||
|
||||
|
||||
def tensor_equals(self, other):
|
||||
"""Compares two tensors element-wise for equality."""
|
||||
g = getattr(self, "graph", None)
|
||||
if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
|
||||
(g is None or g._building_function)): # pylint: disable=protected-access
|
||||
return gen_math_ops.equal(self, other)
|
||||
if fwd_compat.forward_compatible(2019, 9, 25):
|
||||
return gen_math_ops.equal(self, other, incompatible_shape_error=False)
|
||||
else:
|
||||
return gen_math_ops.equal(self, other)
|
||||
else:
|
||||
# In legacy graph mode, tensor equality is object equality
|
||||
return self is other
|
||||
@ -1284,7 +1337,10 @@ def tensor_equals(self, other):
|
||||
def tensor_not_equals(self, other):
|
||||
"""Compares two tensors element-wise for equality."""
|
||||
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
|
||||
return gen_math_ops.not_equal(self, other)
|
||||
if fwd_compat.forward_compatible(2019, 9, 25):
|
||||
return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
|
||||
else:
|
||||
return gen_math_ops.not_equal(self, other)
|
||||
else:
|
||||
# In legacy graph mode, tensor equality is object equality
|
||||
return self is not other
|
||||
|
@ -2397,7 +2397,6 @@ def _convert_cast(pfor_input):
|
||||
@RegisterPForWithArgs("Div", math_ops.div)
|
||||
@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
|
||||
@RegisterPForWithArgs("Elu", nn_ops.elu)
|
||||
@RegisterPForWithArgs("Equal", math_ops.equal)
|
||||
@RegisterPForWithArgs("Erf", math_ops.erf)
|
||||
@RegisterPForWithArgs("Erfc", math_ops.erfc)
|
||||
@RegisterPForWithArgs("Exp", math_ops.exp)
|
||||
@ -2432,7 +2431,6 @@ def _convert_cast(pfor_input):
|
||||
@RegisterPForWithArgs("Mul", math_ops.multiply)
|
||||
@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan)
|
||||
@RegisterPForWithArgs("Neg", math_ops.negative)
|
||||
@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
|
||||
@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
|
||||
@RegisterPForWithArgs("Pow", math_ops.pow)
|
||||
@RegisterPForWithArgs("Real", math_ops.real)
|
||||
@ -2471,6 +2469,26 @@ def _convert_cwise(pfor_input, op_type, op_func):
|
||||
return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
|
||||
|
||||
|
||||
@RegisterPFor("Equal")
|
||||
def _convert_equal(pfor_input):
|
||||
pfor_input.expanddim_inputs_for_broadcast()
|
||||
x = pfor_input.input(0)[0]
|
||||
y = pfor_input.input(1)[0]
|
||||
incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
|
||||
assert incompatible_shape_error
|
||||
return wrap(math_ops.equal(x, y), True)
|
||||
|
||||
|
||||
@RegisterPFor("NotEqual")
|
||||
def _convert_not_equal(pfor_input):
|
||||
pfor_input.expanddim_inputs_for_broadcast()
|
||||
x = pfor_input.input(0)[0]
|
||||
y = pfor_input.input(1)[0]
|
||||
incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
|
||||
assert incompatible_shape_error
|
||||
return wrap(math_ops.not_equal(x, y), True)
|
||||
|
||||
|
||||
@RegisterPFor("ApproximateEqual")
|
||||
def _convert_approximate_equal(pfor_input):
|
||||
pfor_input.expanddim_inputs_for_broadcast()
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import variable_pb2
|
||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python.compat import compat as fwd_compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -35,8 +36,8 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_state_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -1092,7 +1093,10 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
||||
def __eq__(self, other):
|
||||
"""Compares two variables element-wise for equality."""
|
||||
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
|
||||
return gen_math_ops.equal(self, other)
|
||||
if fwd_compat.forward_compatible(2019, 9, 25):
|
||||
return gen_math_ops.equal(self, other, incompatible_shape_error=False)
|
||||
else:
|
||||
return gen_math_ops.equal(self, other)
|
||||
else:
|
||||
# In legacy graph mode, tensor equality is object equality
|
||||
return self is other
|
||||
@ -1101,7 +1105,11 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
||||
def __ne__(self, other):
|
||||
"""Compares two variables element-wise for equality."""
|
||||
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
|
||||
return gen_math_ops.not_equal(self, other)
|
||||
if fwd_compat.forward_compatible(2019, 9, 25):
|
||||
return gen_math_ops.not_equal(
|
||||
self, other, incompatible_shape_error=False)
|
||||
else:
|
||||
return gen_math_ops.not_equal(self, other)
|
||||
else:
|
||||
# In legacy graph mode, tensor equality is object equality
|
||||
return self is not other
|
||||
|
@ -1174,7 +1174,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Equal"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Erf"
|
||||
@ -2398,7 +2398,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "NotEqual"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "NthElement"
|
||||
|
@ -1174,7 +1174,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Equal"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Erf"
|
||||
@ -2398,7 +2398,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "NotEqual"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "NthElement"
|
||||
|
Loading…
Reference in New Issue
Block a user