Export and add dispatch to overloaded tensor operators (that were not already exported under different api symbols) under the hidden namespace tf.__operators__.

This will make these built-in operators more amenable to dispatching for library developers.
This includes:
tf.__operators__.add
tf.__operators__.ne
tf.__operators__.eq
tf.__operators__.getitem

PiperOrigin-RevId: 315998480
Change-Id: Icf61e24a2c037eaf2c4d170967eb2b8ac18f5961
This commit is contained in:
Tomer Kaftan 2020-06-11 16:13:16 -07:00 committed by TensorFlower Gardener
parent e3aa45345b
commit 7933b0e5e3
6 changed files with 123 additions and 3 deletions

View File

@ -212,6 +212,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
func() # Warmup.
self._run(func, 30000)
def _benchmark_add_operator_overload(self, a, b):
def func():
return memoryview(a + b)
with ops.device("GPU:0" if context.num_gpus() else "CPU:0"):
for _ in range(1000):
func() # Warmup.
self._run(func, 30000)
def benchmark_add_float_scalars(self):
self._benchmark_add(42.0, 24.0)
@ -223,6 +232,11 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
tensor_b = constant_op.constant(24.0)
self._benchmark_add(tensor_a, tensor_b)
def benchmark_add_float_scalar_tensor_overloaded_operator(self):
tensor_a = constant_op.constant(42.0)
tensor_b = constant_op.constant(24.0)
self._benchmark_add_operator_overload(tensor_a, tensor_b)
def benchmark_add_int32_scalar_tensor(self):
tensor_a = constant_op.constant(42)
tensor_b = constant_op.constant(24)

View File

@ -869,6 +869,8 @@ def _is_undefined_dimension(d):
return isinstance(d, tensor_shape.Dimension) and d.value is None
@tf_export("__operators__.getitem", v1=[])
@dispatch.add_dispatch_support
def _slice_helper(tensor, slice_spec, var=None):
"""Overload for Tensor.__getitem__.
@ -914,6 +916,15 @@ def _slice_helper(tensor, slice_spec, var=None):
- An implicit ellipsis is placed at the end of the `slice_spec`
- NumPy advanced indexing is currently not supported.
Purpose in the API:
This method is exposed in TensorFlow's API so that library developers
can register dispatching for `Tensor.__getitem__` to allow it to handle
custom composite tensors & other custom objects.
The API symbol is not intended to be called by users directly and does
appear in TensorFlow's generated documentation.
Args:
tensor: An ops.Tensor object.
slice_spec: The arguments to Tensor.__getitem__.

View File

@ -1416,8 +1416,28 @@ truncatemod = gen_math_ops.truncate_mod
floormod = gen_math_ops.floor_mod
@tf_export("__operators__.add", v1=[])
@dispatch.add_dispatch_support
def _add_dispatch(x, y, name=None):
"""Dispatches to add for strings and add_v2 for all other types."""
"""The operation invoked by the `Tensor.__add__` operator.
Purpose in the API:
This method is exposed in TensorFlow's API so that library developers
can register dispatching for `Tensor.__add__` to allow it to handle
custom composite tensors & other custom objects.
The API symbol is not intended to be called by users directly and does
appear in TensorFlow's generated documentation.
Args:
x: The left-hand side of the `+` operator.
y: The right-hand side of the `+` operator.
name: an optional name for the operation.
Returns:
The result of the elementwise `+` operation.
"""
if not isinstance(y, ops.Tensor) and not isinstance(
y, sparse_tensor.SparseTensor):
y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y")
@ -1630,8 +1650,33 @@ def not_equal(x, y, name=None):
return gen_math_ops.not_equal(x, y, name=name)
@tf_export("__operators__.eq", v1=[])
@dispatch.add_dispatch_support
def tensor_equals(self, other):
"""Compares two tensors element-wise for equality."""
"""The operation invoked by the `Tensor.__eq__` operator.
Compares two tensors element-wise for equality if they are
broadcast-compatible; or returns False if they are not broadcast-compatible.
(Note that this behavior differs from `tf.math.equal`, which raises an
exception if the two tensors are not broadcast-compatible.)
Purpose in the API:
This method is exposed in TensorFlow's API so that library developers
can register dispatching for `Tensor.__eq__` to allow it to handle
custom composite tensors & other custom objects.
The API symbol is not intended to be called by users directly and does
appear in TensorFlow's generated documentation.
Args:
self: The left-hand side of the `==` operator.
other: The right-hand side of the `==` operator.
Returns:
The result of the elementwise `==` operation, or `False` if the arguments
are not broadcast-compatible.
"""
if other is None:
return False
g = getattr(self, "graph", None)
@ -1643,8 +1688,33 @@ def tensor_equals(self, other):
return self is other
@tf_export("__operators__.ne", v1=[])
@dispatch.add_dispatch_support
def tensor_not_equals(self, other):
"""Compares two tensors element-wise for equality."""
"""The operation invoked by the `Tensor.__ne__` operator.
Compares two tensors element-wise for inequality if they are
broadcast-compatible; or returns True if they are not broadcast-compatible.
(Note that this behavior differs from `tf.math.not_equal`, which raises an
exception if the two tensors are not broadcast-compatible.)
Purpose in the API:
This method is exposed in TensorFlow's API so that library developers
can register dispatching for `Tensor.__ne__` to allow it to handle
custom composite tensors & other custom objects.
The API symbol is not intended to be called by users directly and does
appear in TensorFlow's generated documentation.
Args:
self: The left-hand side of the `!=` operator.
other: The right-hand side of the `!=` operator.
Returns:
The result of the elementwise `!=` operation, or `True` if the arguments
are not broadcast-compatible.
"""
if other is None:
return True
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():

View File

@ -4,6 +4,7 @@
TENSORFLOW_API_INIT_FILES = [
# BEGIN GENERATED FILES
"__init__.py",
"__operators__/__init__.py",
"audio/__init__.py",
"autograph/__init__.py",
"autograph/experimental/__init__.py",

View File

@ -0,0 +1,19 @@
path: "tensorflow.__operators__"
tf_module {
member_method {
name: "add"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "eq"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "getitem"
argspec: "args=[\'tensor\', \'slice_spec\', \'var\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "ne"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -191,6 +191,11 @@ def build_docs(output_dir, code_url_prefix, search_hints=True):
_hide_layer_and_module_methods()
try:
doc_controls.do_not_generate_docs(tf.__operators__)
except AttributeError:
pass
try:
doc_controls.do_not_generate_docs(tf.tools)
except AttributeError: