Export and add dispatch to overloaded tensor operators (that were not already exported under different api symbols) under the hidden namespace tf.__operators__
.
This will make these built-in operators more amenable to dispatching for library developers. This includes: tf.__operators__.add tf.__operators__.ne tf.__operators__.eq tf.__operators__.getitem PiperOrigin-RevId: 315998480 Change-Id: Icf61e24a2c037eaf2c4d170967eb2b8ac18f5961
This commit is contained in:
parent
e3aa45345b
commit
7933b0e5e3
@ -212,6 +212,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
func() # Warmup.
|
||||
self._run(func, 30000)
|
||||
|
||||
def _benchmark_add_operator_overload(self, a, b):
|
||||
def func():
|
||||
return memoryview(a + b)
|
||||
|
||||
with ops.device("GPU:0" if context.num_gpus() else "CPU:0"):
|
||||
for _ in range(1000):
|
||||
func() # Warmup.
|
||||
self._run(func, 30000)
|
||||
|
||||
def benchmark_add_float_scalars(self):
|
||||
self._benchmark_add(42.0, 24.0)
|
||||
|
||||
@ -223,6 +232,11 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
tensor_b = constant_op.constant(24.0)
|
||||
self._benchmark_add(tensor_a, tensor_b)
|
||||
|
||||
def benchmark_add_float_scalar_tensor_overloaded_operator(self):
|
||||
tensor_a = constant_op.constant(42.0)
|
||||
tensor_b = constant_op.constant(24.0)
|
||||
self._benchmark_add_operator_overload(tensor_a, tensor_b)
|
||||
|
||||
def benchmark_add_int32_scalar_tensor(self):
|
||||
tensor_a = constant_op.constant(42)
|
||||
tensor_b = constant_op.constant(24)
|
||||
|
@ -869,6 +869,8 @@ def _is_undefined_dimension(d):
|
||||
return isinstance(d, tensor_shape.Dimension) and d.value is None
|
||||
|
||||
|
||||
@tf_export("__operators__.getitem", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def _slice_helper(tensor, slice_spec, var=None):
|
||||
"""Overload for Tensor.__getitem__.
|
||||
|
||||
@ -914,6 +916,15 @@ def _slice_helper(tensor, slice_spec, var=None):
|
||||
- An implicit ellipsis is placed at the end of the `slice_spec`
|
||||
- NumPy advanced indexing is currently not supported.
|
||||
|
||||
Purpose in the API:
|
||||
|
||||
This method is exposed in TensorFlow's API so that library developers
|
||||
can register dispatching for `Tensor.__getitem__` to allow it to handle
|
||||
custom composite tensors & other custom objects.
|
||||
|
||||
The API symbol is not intended to be called by users directly and does
|
||||
appear in TensorFlow's generated documentation.
|
||||
|
||||
Args:
|
||||
tensor: An ops.Tensor object.
|
||||
slice_spec: The arguments to Tensor.__getitem__.
|
||||
|
@ -1416,8 +1416,28 @@ truncatemod = gen_math_ops.truncate_mod
|
||||
floormod = gen_math_ops.floor_mod
|
||||
|
||||
|
||||
@tf_export("__operators__.add", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def _add_dispatch(x, y, name=None):
|
||||
"""Dispatches to add for strings and add_v2 for all other types."""
|
||||
"""The operation invoked by the `Tensor.__add__` operator.
|
||||
|
||||
Purpose in the API:
|
||||
|
||||
This method is exposed in TensorFlow's API so that library developers
|
||||
can register dispatching for `Tensor.__add__` to allow it to handle
|
||||
custom composite tensors & other custom objects.
|
||||
|
||||
The API symbol is not intended to be called by users directly and does
|
||||
appear in TensorFlow's generated documentation.
|
||||
|
||||
Args:
|
||||
x: The left-hand side of the `+` operator.
|
||||
y: The right-hand side of the `+` operator.
|
||||
name: an optional name for the operation.
|
||||
|
||||
Returns:
|
||||
The result of the elementwise `+` operation.
|
||||
"""
|
||||
if not isinstance(y, ops.Tensor) and not isinstance(
|
||||
y, sparse_tensor.SparseTensor):
|
||||
y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y")
|
||||
@ -1630,8 +1650,33 @@ def not_equal(x, y, name=None):
|
||||
return gen_math_ops.not_equal(x, y, name=name)
|
||||
|
||||
|
||||
@tf_export("__operators__.eq", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def tensor_equals(self, other):
|
||||
"""Compares two tensors element-wise for equality."""
|
||||
"""The operation invoked by the `Tensor.__eq__` operator.
|
||||
|
||||
Compares two tensors element-wise for equality if they are
|
||||
broadcast-compatible; or returns False if they are not broadcast-compatible.
|
||||
(Note that this behavior differs from `tf.math.equal`, which raises an
|
||||
exception if the two tensors are not broadcast-compatible.)
|
||||
|
||||
Purpose in the API:
|
||||
|
||||
This method is exposed in TensorFlow's API so that library developers
|
||||
can register dispatching for `Tensor.__eq__` to allow it to handle
|
||||
custom composite tensors & other custom objects.
|
||||
|
||||
The API symbol is not intended to be called by users directly and does
|
||||
appear in TensorFlow's generated documentation.
|
||||
|
||||
Args:
|
||||
self: The left-hand side of the `==` operator.
|
||||
other: The right-hand side of the `==` operator.
|
||||
|
||||
Returns:
|
||||
The result of the elementwise `==` operation, or `False` if the arguments
|
||||
are not broadcast-compatible.
|
||||
"""
|
||||
if other is None:
|
||||
return False
|
||||
g = getattr(self, "graph", None)
|
||||
@ -1643,8 +1688,33 @@ def tensor_equals(self, other):
|
||||
return self is other
|
||||
|
||||
|
||||
@tf_export("__operators__.ne", v1=[])
|
||||
@dispatch.add_dispatch_support
|
||||
def tensor_not_equals(self, other):
|
||||
"""Compares two tensors element-wise for equality."""
|
||||
"""The operation invoked by the `Tensor.__ne__` operator.
|
||||
|
||||
Compares two tensors element-wise for inequality if they are
|
||||
broadcast-compatible; or returns True if they are not broadcast-compatible.
|
||||
(Note that this behavior differs from `tf.math.not_equal`, which raises an
|
||||
exception if the two tensors are not broadcast-compatible.)
|
||||
|
||||
Purpose in the API:
|
||||
|
||||
This method is exposed in TensorFlow's API so that library developers
|
||||
can register dispatching for `Tensor.__ne__` to allow it to handle
|
||||
custom composite tensors & other custom objects.
|
||||
|
||||
The API symbol is not intended to be called by users directly and does
|
||||
appear in TensorFlow's generated documentation.
|
||||
|
||||
Args:
|
||||
self: The left-hand side of the `!=` operator.
|
||||
other: The right-hand side of the `!=` operator.
|
||||
|
||||
Returns:
|
||||
The result of the elementwise `!=` operation, or `True` if the arguments
|
||||
are not broadcast-compatible.
|
||||
"""
|
||||
if other is None:
|
||||
return True
|
||||
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
|
||||
|
@ -4,6 +4,7 @@
|
||||
TENSORFLOW_API_INIT_FILES = [
|
||||
# BEGIN GENERATED FILES
|
||||
"__init__.py",
|
||||
"__operators__/__init__.py",
|
||||
"audio/__init__.py",
|
||||
"autograph/__init__.py",
|
||||
"autograph/experimental/__init__.py",
|
||||
|
@ -0,0 +1,19 @@
|
||||
path: "tensorflow.__operators__"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "add"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "eq"
|
||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "getitem"
|
||||
argspec: "args=[\'tensor\', \'slice_spec\', \'var\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ne"
|
||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -191,6 +191,11 @@ def build_docs(output_dir, code_url_prefix, search_hints=True):
|
||||
|
||||
_hide_layer_and_module_methods()
|
||||
|
||||
try:
|
||||
doc_controls.do_not_generate_docs(tf.__operators__)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
doc_controls.do_not_generate_docs(tf.tools)
|
||||
except AttributeError:
|
||||
|
Loading…
Reference in New Issue
Block a user