Add incompatible_shape_error attribute to equal op

When tensor equality is enabled, if there is an incompatible shape we
currently throw and exception. Ideally we'd like to return False when
calling __eq__ and True when calling __ne__. We thus modify the Equal
and NotEqual ops to return a boolean upon a shape incompatibility. Due
to this change the shape inference logic needs to be changed to either
return a scalar bool if the shapes are incompatible, or else return an
unknown shape to allow for either a boolean Tensor or scalar to be
returned.

Note the behavior of tf.math.equal & tf.math.not_equal is unchanged as
they both use optimistic shape inference logic when dealing with unknown
dimensions which allows for more efficient graphs rather than inserting
Rank operations.

This distinction between __eq__ & tf.math.equal is also found in numpy
and as a result the tf.debugging.assert_equal and
tf.debugging.assert_none_equal APIs needed to be change to utilize the
numpy operations.

PiperOrigin-RevId: 267466043
This commit is contained in:
Gaurav Jain 2019-09-05 15:15:06 -07:00 committed by TensorFlower Gardener
parent 792abd2eaf
commit e0e1efbe08
17 changed files with 267 additions and 112 deletions

View File

@ -1,9 +1,4 @@
op {
graph_op_name: "Equal"
endpoint {
name: "math.equal"
}
endpoint {
name: "equal"
}
visibility: HIDDEN
}

View File

@ -1,9 +1,4 @@
op {
graph_op_name: "NotEqual"
endpoint {
name: "math.not_equal"
}
endpoint {
name: "not_equal"
}
visibility: HIDDEN
}

View File

@ -366,7 +366,8 @@ Status EinsumShape(shape_inference::InferenceContext* c) {
output_bcast_shape = input_bcast_shapes[0];
} else if (input_bcast_shapes.size() == 2) {
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, input_bcast_shapes[0], input_bcast_shapes[1], &output_bcast_shape));
c, input_bcast_shapes[0], input_bcast_shapes[1], true,
&output_bcast_shape));
}
bool output_has_ellipsis = false;
@ -441,7 +442,7 @@ Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, a_batch_shape, b_batch_shape, &output_batch_shape));
c, a_batch_shape, b_batch_shape, true, &output_batch_shape));
ShapeHandle output_shape;
TF_RETURN_IF_ERROR(c->Concatenate(
@ -1633,6 +1634,7 @@ Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
ShapeHandle shape_x,
ShapeHandle shape_y,
bool incompatible_shape_error,
ShapeHandle* out) {
CHECK_NOTNULL(out);
if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
@ -1666,8 +1668,16 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
// or the same as the known dim.
// - If either dimension is 1, the other dimension is the output.
if (c->Value(dim_x) > 1) {
if (!incompatible_shape_error) {
*out = c->UnknownShape();
return Status::OK();
}
dims.push_back(dim_x);
} else if (c->Value(dim_y) > 1) {
if (!incompatible_shape_error) {
*out = c->UnknownShape();
return Status::OK();
}
dims.push_back(dim_y);
} else if (c->Value(dim_x) == 1) {
dims.push_back(dim_y);
@ -1676,6 +1686,10 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
} else if (dim_y.SameHandle(dim_x)) {
dims.push_back(dim_x);
} else {
if (!incompatible_shape_error) {
*out = c->UnknownShape();
return Status::OK();
}
dims.push_back(c->UnknownDim());
}
} else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
@ -1689,7 +1703,14 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
}
} else {
DimensionHandle dim;
TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
Status s = c->Merge(dim_x, dim_y, &dim);
if (!s.ok()) {
if (!incompatible_shape_error) {
*out = c->MakeShape({});
return Status::OK();
}
return s;
}
dims.push_back(dim);
}
}

View File

@ -306,6 +306,7 @@ Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
ShapeHandle shape_x,
ShapeHandle shape_y,
bool incompatible_shape_error,
ShapeHandle* out);
// Shape function for binary operators that broadcast their inputs
@ -313,8 +314,8 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
int output_index) {
ShapeHandle out;
TF_RETURN_IF_ERROR(
BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, c->input(0), c->input(1), true, &out));
c->set_output(output_index, out);
return Status::OK();
}

View File

@ -57,11 +57,23 @@ BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
in1(ctx->input(1)),
bcast(BCast::FromShape(in0.shape()), BCast::FromShape(in1.shape())) {
if (!bcast.IsValid()) {
bool incompatible_shape_error;
bool has_attr =
TryGetNodeAttr(ctx->op_kernel().def(), "incompatible_shape_error",
&(incompatible_shape_error));
if (has_attr && !incompatible_shape_error) {
const string& op = ctx->op_kernel().type_string();
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
result = (op == "NotEqual");
return;
}
ctx->SetStatus(errors::InvalidArgument(
"Incompatible shapes: ", in0.shape().DebugString(), " vs. ",
in1.shape().DebugString()));
return;
}
const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
out_num_elements = output_shape.num_elements();
in0_num_elements = in0.NumElements();

View File

@ -26,13 +26,13 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
#endif
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"
@ -56,7 +56,7 @@ class BinaryOpShared : public OpKernel {
// in-place computation.
// Caller must check ctx->status() upon return for non-ok status.
// If ctx->status().ok() is true, then out is guaranteed to be allocated.
BinaryOpState(OpKernelContext* ctx);
explicit BinaryOpState(OpKernelContext* ctx);
const Tensor& in0;
const Tensor& in1;
@ -69,6 +69,7 @@ class BinaryOpShared : public OpKernel {
int64 in1_num_elements;
int ndims;
bool result;
};
void SetUnimplementedError(OpKernelContext* ctx);
@ -91,16 +92,29 @@ class BinaryOp : public BinaryOpShared {
void Compute(OpKernelContext* ctx) override {
// 'state': Shared helper not dependent on T to reduce code size
BinaryOpState state(ctx);
if (!ctx->status().ok()) return;
auto& bcast = state.bcast;
const Device& eigen_device = ctx->eigen_device<Device>();
Tensor* out = state.out;
BCast* bcast = &state.bcast;
if (!bcast.IsValid()) {
if (ctx->status().ok()) {
if (state.result) {
functor::SetOneFunctor<Device, bool>()(eigen_device,
out->flat<bool>());
} else {
functor::SetZeroFunctor<Device, bool>()(eigen_device,
out->flat<bool>());
}
}
return;
}
auto& in0 = state.in0;
auto& in1 = state.in1;
if (state.out_num_elements == 0) {
return;
}
const int ndims = state.ndims;
const Device& eigen_device = ctx->eigen_device<Device>();
bool error = false;
bool* const error_ptr = Functor::has_errors ? &error : nullptr;
if (ndims <= 1) {
@ -122,32 +136,32 @@ class BinaryOp : public BinaryOpShared {
}
} else if (ndims == 2) {
functor::BinaryFunctor<Device, Functor, 2>().BCast(
eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
in0.template shaped<Tin, 2>(bcast->x_reshape()),
BCast::ToIndexArray<2>(bcast->x_bcast()),
in1.template shaped<Tin, 2>(bcast->y_reshape()),
BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
eigen_device, out->shaped<Tout, 2>(bcast.result_shape()),
in0.template shaped<Tin, 2>(bcast.x_reshape()),
BCast::ToIndexArray<2>(bcast.x_bcast()),
in1.template shaped<Tin, 2>(bcast.y_reshape()),
BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr);
} else if (ndims == 3) {
functor::BinaryFunctor<Device, Functor, 3>().BCast(
eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
in0.template shaped<Tin, 3>(bcast->x_reshape()),
BCast::ToIndexArray<3>(bcast->x_bcast()),
in1.template shaped<Tin, 3>(bcast->y_reshape()),
BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
eigen_device, out->shaped<Tout, 3>(bcast.result_shape()),
in0.template shaped<Tin, 3>(bcast.x_reshape()),
BCast::ToIndexArray<3>(bcast.x_bcast()),
in1.template shaped<Tin, 3>(bcast.y_reshape()),
BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr);
} else if (ndims == 4) {
functor::BinaryFunctor<Device, Functor, 4>().BCast(
eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
in0.template shaped<Tin, 4>(bcast->x_reshape()),
BCast::ToIndexArray<4>(bcast->x_bcast()),
in1.template shaped<Tin, 4>(bcast->y_reshape()),
BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
eigen_device, out->shaped<Tout, 4>(bcast.result_shape()),
in0.template shaped<Tin, 4>(bcast.x_reshape()),
BCast::ToIndexArray<4>(bcast.x_bcast()),
in1.template shaped<Tin, 4>(bcast.y_reshape()),
BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr);
} else if (ndims == 5) {
functor::BinaryFunctor<Device, Functor, 5>().BCast(
eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
in0.template shaped<Tin, 5>(bcast->x_reshape()),
BCast::ToIndexArray<5>(bcast->x_bcast()),
in1.template shaped<Tin, 5>(bcast->y_reshape()),
BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
eigen_device, out->shaped<Tout, 5>(bcast.result_shape()),
in0.template shaped<Tin, 5>(bcast.x_reshape()),
BCast::ToIndexArray<5>(bcast.x_bcast()),
in1.template shaped<Tin, 5>(bcast.y_reshape()),
BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr);
} else {
SetUnimplementedError(ctx);
}

View File

@ -700,7 +700,19 @@ REGISTER_OP("GreaterEqual").COMPARISON();
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
"int64, complex64, quint8, qint8, qint32, string, bool, " \
"complex128}") \
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Attr("incompatible_shape_error: bool = true") \
.SetShapeFn([](InferenceContext* c) { \
ShapeHandle x = c->input(0); \
ShapeHandle y = c->input(1); \
ShapeHandle output; \
bool incompatible_shape_error; \
TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error", \
&incompatible_shape_error)); \
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( \
c, x, y, incompatible_shape_error, &output)); \
c->set_output(0, output); \
return Status::OK(); \
})
REGISTER_OP("Equal").EQUALITY_COMPARISON();
@ -877,10 +889,10 @@ REGISTER_OP("SelectV2")
ShapeHandle else_ = c->input(2);
ShapeHandle other;
TF_RETURN_IF_ERROR(
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other));
BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, true, &other));
ShapeHandle output;
TF_RETURN_IF_ERROR(
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output));
BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, true, &output));
c->set_output(0, output);
return Status::OK();
});

View File

@ -110,28 +110,17 @@ TEST(MathOpsTest, Segment_ShapeFn) {
}
TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
for (const auto* op_name : {"Add", "Complex",
"Div", "Equal",
"Greater", "GreaterEqual",
"Igamma", "Igammac",
"Zeta", "Polygamma",
"Less", "LessEqual",
"LogicalAnd", "LogicalOr",
"Maximum", "Minimum",
"Mod", "Mul",
"NotEqual", "Pow",
"Sub", "SquaredDifference",
"DivNoNan"}) {
ShapeInferenceTestOp op(op_name);
auto test_shapes = [&](ShapeInferenceTestOp& op,
bool incompatible_shape_error) {
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[1,2];?", "?");
INFER_OK(op, "?;[1,2]", "?");
INFER_OK(op, "[?];[1]", "[d0_0]");
INFER_OK(op, "[1];[?]", "[d1_0]");
INFER_OK(op, "[?];[2]", "[d1_0]");
INFER_OK(op, "[2];[?]", "[d0_0]");
INFER_OK(op, "[?];[?]", "[?]");
INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?");
INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
INFER_OK(op, "[?];[?]", incompatible_shape_error ? "[?]" : "?");
INFER_OK(op, "[];[?]", "[d1_0]");
INFER_OK(op, "[?];[]", "[d0_0]");
@ -144,7 +133,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
INFER_OK(op, "[1];[2]", "[d1_0]");
INFER_OK(op, "[2];[1]", "[d0_0]");
INFER_OK(op, "[2];[]", "[d0_0]");
INFER_OK(op, "[2];[?]", "[d0_0]");
INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
INFER_OK(op, "[];[0]", "[d1_0]");
@ -152,14 +141,46 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
INFER_OK(op, "[0];[1]", "[d0_0]");
INFER_OK(op, "[0];[]", "[d0_0]");
INFER_OK(op, "[2];[?,?]", "[d1_0,d0_0]");
INFER_OK(op, "[2,2];[?,?,?]", "[d1_0,d0_0,d0_1]");
INFER_OK(op, "[2];[?,?]", incompatible_shape_error ? "[d1_0,d0_0]" : "?");
INFER_OK(op, "[2,2];[?,?,?]",
incompatible_shape_error ? "[d1_0,d0_0,d0_1]" : "?");
// Multiple dimension cases (same test cases, switching x and y).
INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
"[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
incompatible_shape_error ? "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"
: "?");
INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
"[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
incompatible_shape_error ? "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"
: "?");
if (incompatible_shape_error) {
INFER_ERROR("Dimensions must be equal", op, "[2];[3]");
} else {
INFER_OK(op, "[2];[3]", "[]");
}
};
for (string op_name : {"Add", "Complex",
"Div", "Equal",
"Greater", "GreaterEqual",
"Igamma", "Igammac",
"Zeta", "Polygamma",
"Less", "LessEqual",
"LogicalAnd", "LogicalOr",
"Maximum", "Minimum",
"Mod", "Mul",
"NotEqual", "Pow",
"Sub", "SquaredDifference",
"DivNoNan"}) {
ShapeInferenceTestOp op(op_name);
AddNodeAttr("incompatible_shape_error", true, &op.node_def);
test_shapes(op, true);
if ((op_name == "Equal") || (op_name == "NotEqual")) {
ShapeInferenceTestOp op(op_name);
AddNodeAttr("incompatible_shape_error", false, &op.node_def);
test_shapes(op, false);
}
}
}

View File

@ -27,6 +27,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import def_function
@ -302,9 +303,13 @@ class TFETest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
bool(tf_a == tf_d)
self.assertAllEqual(tf_a == tf_d, [[True, False], [True, False]])
# TODO(b/120678848): If shapes do not match we should instead return False
with self.assertRaises(errors.InvalidArgumentError):
bool(tf_a != tf_e)
if compat.forward_compatible(2019, 9, 25):
self.assertFalse(bool(tf_a == tf_e))
self.assertTrue(bool(tf_a != tf_e))
self.assertNotAllEqual(tf_a, tf_e)
else:
with self.assertRaises(errors.InvalidArgumentError):
bool(tf_a != tf_e)
with self.assertRaises(ValueError):
bool(np_a == np_b)
@ -313,7 +318,9 @@ class TFETest(test_util.TensorFlowTestCase):
bool(np_a == np_c)
self.assertAllEqual(np_a == np_c, [[True, True], [True, True]])
self.assertAllEqual(np_a == np_d, [[True, False], [True, False]])
bool(np_a != np_e)
self.assertFalse(bool(np_a == np_e))
self.assertTrue(bool(np_a != np_e))
self.assertNotAllEqual(np_a, np_e)
finally:
if default:
ops.enable_tensor_equality()

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@ -255,10 +256,9 @@ class BinaryOpTest(test.TestCase):
var_x = variables.Variable(x)
var_y = variables.Variable(y)
with self.cached_session() as sess:
self.evaluate([var_x.initializer, var_y.initializer])
left_result = self.evaluate(var_x * y)
right_result = self.evaluate(x * var_y)
self.evaluate([var_x.initializer, var_y.initializer])
left_result = self.evaluate(var_x * y)
right_result = self.evaluate(x * var_y)
np_result = x * y
self.assertAllEqual(np_result, left_result)
@ -933,7 +933,6 @@ class ComparisonOpTest(test.TestCase):
self._testBCastByFunc(
np.not_equal, math_ops.not_equal, include_complex=True)
@test_util.run_deprecated_v1
def testShapeMismatch(self):
dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
funcs = [
@ -944,8 +943,9 @@ class ComparisonOpTest(test.TestCase):
y = np.arange(0, 10).reshape([5, 2])
for t in dtypes:
for f in funcs:
with self.assertRaisesWithPredicateMatch(
ValueError, lambda e: "Dimensions must" in str(e)):
with self.assertRaisesRegexp(
(ValueError, errors.InvalidArgumentError),
"Incompatible shapes|Dimensions must be equal"):
f(x.astype(t), y.astype(t))

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
@ -201,7 +202,6 @@ class ComparisonOpTest(test.TestCase):
self._testBCastByFunc(
np.not_equal, math_ops.not_equal, include_complex=True)
@test_util.run_deprecated_v1
def testShapeMismatch(self):
dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
funcs = [
@ -212,8 +212,9 @@ class ComparisonOpTest(test.TestCase):
y = np.arange(0, 10).reshape([5, 2])
for t in dtypes:
for f in funcs:
with self.assertRaisesWithPredicateMatch(
ValueError, lambda e: "Dimensions must" in str(e)):
with self.assertRaisesRegexp(
(ValueError, errors.InvalidArgumentError),
"Incompatible shapes|Dimensions must be equal"):
f(x.astype(t), y.astype(t))

View File

@ -310,7 +310,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
static_func: Function that, if passed numpy ndarray versions of the two
inputs to the assertion, will return a Boolean ndarray with containing
True in all positions where the assertion PASSES.
i.e. lambda x,y: (x == y) for assert_equal()
i.e. np.equal for assert_equal()
x: Numeric `Tensor`.
y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
data: The tensors to print out if the condition is False. Defaults to
@ -366,7 +366,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
x_static = tensor_util.constant_value(x)
y_static = tensor_util.constant_value(y)
if x_static is not None and y_static is not None:
condition_static = static_func(x_static, y_static).all()
condition_static = np.all(static_func(x_static, y_static))
_assert_static(condition_static, data)
return control_flow_ops.Assert(condition, data, summarize=summarize)
@ -654,9 +654,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # p
# Short-circuit if x and y are the same tensor.
if x is y:
return None if context.executing_eagerly() else control_flow_ops.no_op()
return _binary_assert('==', 'assert_equal', math_ops.equal,
lambda x, y: (x == y),
x, y, data, summarize, message, name)
return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y,
data, summarize, message, name)
@tf_export('debugging.assert_none_equal', v1=[])
@ -703,8 +702,7 @@ def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
def assert_none_equal(
x, y, data=None, summarize=None, message=None, name=None):
return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal,
lambda x, y: (x != y), x, y, data, summarize, message,
name)
np.not_equal, x, y, data, summarize, message, name)
@tf_export('debugging.assert_near', v1=[])
@ -877,8 +875,8 @@ def assert_less_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@_binary_assert_doc('<')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
return _binary_assert('<', 'assert_less', math_ops.less, lambda x, y: (x < y),
x, y, data, summarize, message, name)
return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data,
summarize, message, name)
@tf_export('debugging.assert_less_equal', v1=[])
@ -922,8 +920,7 @@ def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
@_binary_assert_doc('<=')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
return _binary_assert('<=', 'assert_less_equal', math_ops.less_equal,
lambda x, y: (x <= y), x, y, data, summarize, message,
name)
np.less_equal, x, y, data, summarize, message, name)
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@ -965,9 +962,8 @@ def assert_greater_v2(x, y, message=None, summarize=None, name=None):
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@_binary_assert_doc('>')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring
return _binary_assert('>', 'assert_greater', math_ops.greater,
lambda x, y: (x > y),
x, y, data, summarize, message, name)
return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x,
y, data, summarize, message, name)
@tf_export('debugging.assert_greater_equal', v1=[])
@ -1013,8 +1009,7 @@ def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
name=None):
return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal,
lambda x, y: (x >= y), x, y, data, summarize, message,
name)
np.greater_equal, x, y, data, summarize, message, name)
def _assert_rank_condition(
@ -1026,8 +1021,8 @@ def _assert_rank_condition(
rank: Scalar `Tensor`.
static_condition: A python function that takes `[actual_rank, given_rank]`
and returns `True` if the condition is satisfied, `False` otherwise.
dynamic_condition: An `op` that takes [actual_rank, given_rank]
and return `True` if the condition is satisfied, `False` otherwise.
dynamic_condition: An `op` that takes [actual_rank, given_rank] and return
`True` if the condition is satisfied, `False` otherwise.
data: The tensors to print out if the condition is false. Defaults to
error message and first few entries of `x`.
summarize: Print this many entries of each tensor.
@ -2159,4 +2154,3 @@ def ensure_shape(x, shape, name=None):
def _ensure_shape_grad(op, grad):
del op # Unused.
return grad

View File

@ -1270,12 +1270,65 @@ ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
@tf_export("math.equal", "equal")
@dispatch.add_dispatch_support
def equal(x, y, name=None):
"""Returns the truth value of (x == y) element-wise.
Usage:
```python
x = tf.constant([2, 4])
y = tf.constant(2)
tf.math.equal(x, y) ==> array([True, False])
x = tf.constant([2, 4])
y = tf.constant([2, 4])
tf.math.equal(x, y) ==> array([True, True])
```
**NOTE**: `Equal` supports broadcasting. More about broadcasting [here](
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)
Args:
x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
y: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
A `Tensor` of type bool with the same size as that of x or y.
"""
return gen_math_ops.equal(x, y, name=name)
@tf_export("math.not_equal", "not_equal")
@dispatch.add_dispatch_support
def not_equal(x, y, name=None):
"""Returns the truth value of (x != y) element-wise.
**NOTE**: `NotEqual` supports broadcasting. More about broadcasting [here](
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)
Args:
x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
y: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
A `Tensor` of type bool with the same size as that of x or y.
"""
return gen_math_ops.not_equal(x, y, name=name)
def tensor_equals(self, other):
"""Compares two tensors element-wise for equality."""
g = getattr(self, "graph", None)
if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
(g is None or g._building_function)): # pylint: disable=protected-access
return gen_math_ops.equal(self, other)
if fwd_compat.forward_compatible(2019, 9, 25):
return gen_math_ops.equal(self, other, incompatible_shape_error=False)
else:
return gen_math_ops.equal(self, other)
else:
# In legacy graph mode, tensor equality is object equality
return self is other
@ -1284,7 +1337,10 @@ def tensor_equals(self, other):
def tensor_not_equals(self, other):
"""Compares two tensors element-wise for equality."""
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
return gen_math_ops.not_equal(self, other)
if fwd_compat.forward_compatible(2019, 9, 25):
return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
else:
return gen_math_ops.not_equal(self, other)
else:
# In legacy graph mode, tensor equality is object equality
return self is not other

View File

@ -2397,7 +2397,6 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Div", math_ops.div)
@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
@RegisterPForWithArgs("Elu", nn_ops.elu)
@RegisterPForWithArgs("Equal", math_ops.equal)
@RegisterPForWithArgs("Erf", math_ops.erf)
@RegisterPForWithArgs("Erfc", math_ops.erfc)
@RegisterPForWithArgs("Exp", math_ops.exp)
@ -2432,7 +2431,6 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Mul", math_ops.multiply)
@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan)
@RegisterPForWithArgs("Neg", math_ops.negative)
@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
@RegisterPForWithArgs("Pow", math_ops.pow)
@RegisterPForWithArgs("Real", math_ops.real)
@ -2471,6 +2469,26 @@ def _convert_cwise(pfor_input, op_type, op_func):
return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
@RegisterPFor("Equal")
def _convert_equal(pfor_input):
pfor_input.expanddim_inputs_for_broadcast()
x = pfor_input.input(0)[0]
y = pfor_input.input(1)[0]
incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
assert incompatible_shape_error
return wrap(math_ops.equal(x, y), True)
@RegisterPFor("NotEqual")
def _convert_not_equal(pfor_input):
pfor_input.expanddim_inputs_for_broadcast()
x = pfor_input.input(0)[0]
y = pfor_input.input(1)[0]
incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
assert incompatible_shape_error
return wrap(math_ops.not_equal(x, y), True)
@RegisterPFor("ApproximateEqual")
def _convert_approximate_equal(pfor_input):
pfor_input.expanddim_inputs_for_broadcast()

View File

@ -28,6 +28,7 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python import _pywrap_utils
from tensorflow.python.compat import compat as fwd_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -35,8 +36,8 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
@ -1092,7 +1093,10 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
def __eq__(self, other):
"""Compares two variables element-wise for equality."""
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
return gen_math_ops.equal(self, other)
if fwd_compat.forward_compatible(2019, 9, 25):
return gen_math_ops.equal(self, other, incompatible_shape_error=False)
else:
return gen_math_ops.equal(self, other)
else:
# In legacy graph mode, tensor equality is object equality
return self is other
@ -1101,7 +1105,11 @@ class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
def __ne__(self, other):
"""Compares two variables element-wise for equality."""
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): # pylint: disable=protected-access
return gen_math_ops.not_equal(self, other)
if fwd_compat.forward_compatible(2019, 9, 25):
return gen_math_ops.not_equal(
self, other, incompatible_shape_error=False)
else:
return gen_math_ops.not_equal(self, other)
else:
# In legacy graph mode, tensor equality is object equality
return self is not other

View File

@ -1174,7 +1174,7 @@ tf_module {
}
member_method {
name: "Equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "Erf"
@ -2398,7 +2398,7 @@ tf_module {
}
member_method {
name: "NotEqual"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "NthElement"

View File

@ -1174,7 +1174,7 @@ tf_module {
}
member_method {
name: "Equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "Erf"
@ -2398,7 +2398,7 @@ tf_module {
}
member_method {
name: "NotEqual"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'x\', \'y\', \'incompatible_shape_error\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "NthElement"