Merge pull request #41981 from lithuak:iss35342-take-two

PiperOrigin-RevId: 329749527
Change-Id: I4e29264e4f7eb97f47856e8eb527a66336a78c1f
This commit is contained in:
TensorFlower Gardener 2020-09-02 11:19:25 -07:00
commit a2c542a0d8
21 changed files with 263 additions and 130 deletions

View File

@ -98,6 +98,9 @@
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.data.Options` were previously immutable and can now be overriden.
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
with a new `output_signature` argument, which allows `from_generator` to
produce any type describable by a `tf.TypeSpec`.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op
@ -183,13 +186,11 @@
checkpoint saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function
will automatically find the checkpoint in the SavedModel.
* `tf.nn`:
* `tf.nn.max_pool2d` now supports explicit padding.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
https://developers.google.com/style/word-list#blacklist for more context.
<ADD RELEASE NOTES HERE>
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors

View File

@ -61,6 +61,7 @@ class UniqueTest(test_base.DatasetTestBase, parameterized.TestCase):
([], []),
([1], [1]),
([1, 1, 1, 1, 1, 1, 1], [1]),
([1, 1, 1, 1, 0], [1, 0]),
([1, 2, 3, 4], [1, 2, 3, 4]),
([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),

View File

@ -28,7 +28,12 @@ from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
@ -259,7 +264,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
with self.assertRaisesOpError("The expected type was int64"):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(get_next())
self.assertAllEqual([7, 8, 9], self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
@ -279,7 +284,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(get_next())
self.assertAllEqual([11, 12, 13], self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
@ -300,11 +305,9 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual((1, 2), self.evaluate(get_next()))
self.assertEqual((3, 4), self.evaluate(get_next()))
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(get_next())
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(get_next())
self.assertEqual((9, 10), self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
@ -423,8 +426,12 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
stateful=True)
dummy = constant_op.constant(37)
dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x,
finalize_fn).take(2)
dataset = dataset_ops._GeneratorDataset(
dummy, lambda x: x, lambda x: x, finalize_fn,
tensor_spec.TensorSpec((), dtypes.int32))
dataset = dataset.take(2)
get_next = self.getNext(dataset)
self.assertAllEqual(37, self.evaluate(get_next()))
@ -446,6 +453,44 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual([20], self.evaluate(get_next()))
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorRaggedTensor(self):
def generator():
yield ragged_factory_ops.constant([[1, 2], [3]])
dataset = dataset_ops.Dataset.from_generator(
generator,
output_signature=ragged_tensor.RaggedTensorSpec(
shape=(2, None), dtype=dtypes.int32))
get_next = self.getNext(dataset)
ret = get_next()
self.assertIsInstance(ret, ragged_tensor.RaggedTensor)
self.assertAllEqual([[1, 2], [3]], ret)
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorSparseTensor(self):
def generator():
yield sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]],
values=constant_op.constant([1, 2], dtype=dtypes.int64),
dense_shape=[3, 4])
dataset = dataset_ops.Dataset.from_generator(
generator,
output_signature=sparse_tensor.SparseTensorSpec([3, 4], dtypes.int64))
get_next = self.getNext(dataset)
ret = get_next()
self.assertIsInstance(ret, sparse_tensor.SparseTensor)
self.assertAllEqual([[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]],
sparse_ops.sparse_tensor_to_dense(ret))
@combinations.generate(test_base.default_test_combinations())
def testTypeIsListError(self):

View File

@ -946,7 +946,9 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
@def_function.function
def fn():
dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn)
output_signature = tensor_spec.TensorSpec((), dtypes.int64)
dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn,
output_signature)
iterator = iter(dataset)
next(iterator)

View File

@ -725,27 +725,46 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
del self._iterators[iterator_id]
@staticmethod
def from_generator(generator, output_types, output_shapes=None, args=None):
@deprecation.deprecated_args(None, "Use output_signature instead",
"output_types", "output_shapes")
def from_generator(generator,
output_types=None,
output_shapes=None,
args=None,
output_signature=None):
"""Creates a `Dataset` whose elements are generated by `generator`.
The `generator` argument must be a callable object that returns
an object that supports the `iter()` protocol (e.g. a generator function).
The elements generated by `generator` must be compatible with the given
`output_types` and (optional) `output_shapes` arguments.
>>> import itertools
>>>
The elements generated by `generator` must be compatible with either the
given `output_signature` argument or with the given `output_types` and
(optionally) `output_shapes` arguments, whichiver was specified.
The recommended way to call `from_generator` is to use the
`output_signature` argument. In this case the output will be assumed to
consist of objects with the classes, shapes and types defined by
`tf.TypeSpec` objects from `output_signature` argument:
>>> def gen():
... for i in itertools.count(1):
... yield (i, [1] * i)
... ragged_tensor = tf.ragged.constant([[1, 2], [3]])
... yield 42, ragged_tensor
>>>
>>> dataset = tf.data.Dataset.from_generator(
... gen,
... (tf.int64, tf.int64),
... (tf.TensorShape([]), tf.TensorShape([None])))
... output_signature=(
... tf.TensorSpec(shape=(), dtype=tf.int32),
... tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
>>>
>>> list(dataset.take(3).as_numpy_iterator())
[(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))]
>>> list(dataset.take(1))
[(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
<tf.RaggedTensor [[1, 2], [3]]>)]
There is also a deprecated way to call `from_generator` by either with
`output_types` argument alone or together with `output_shapes` argument.
In this case the output of the function will be assumed to consist of
`tf.Tensor` objects with with the types defined by `output_types` and with
the shapes which are either unknown or defined by `output_shapes`.
Note: The current implementation of `Dataset.from_generator()` uses
`tf.numpy_function` and inherits the same constraints. In particular, it
@ -769,31 +788,56 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
`iter()` protocol. If `args` is not specified, `generator` must take no
arguments; otherwise it must take as many arguments as there are values
in `args`.
output_types: A nested structure of `tf.DType` objects corresponding to
each component of an element yielded by `generator`.
output_types: (Optional.) A nested structure of `tf.DType` objects
corresponding to each component of an element yielded by `generator`.
output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
corresponding to each component of an element yielded by `generator`.
args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
and passed to `generator` as NumPy-array arguments.
output_signature: (Optional.) A nested structure of `tf.TypeSpec` objects
corresponding to each component of an element yielded by `generator`.
Returns:
Dataset: A `Dataset`.
"""
if not callable(generator):
raise TypeError("`generator` must be callable.")
if output_shapes is None:
output_shapes = nest.map_structure(
lambda _: tensor_shape.TensorShape(None), output_types)
if output_signature is not None:
if output_types is not None:
raise TypeError("`output_types` can not be used together with "
"`output_signature`")
if output_shapes is not None:
raise TypeError("`output_shapes` can not be used together with "
"`output_signature`")
if not all(
isinstance(_, type_spec.TypeSpec)
for _ in nest.flatten(output_signature)):
raise TypeError("All the elements of `output_signature` must be "
"`tf.TypeSpec` objects.")
else:
output_shapes = nest.map_structure_up_to(
output_types, tensor_shape.as_shape, output_shapes)
if output_types is None:
raise TypeError("Either `output_signature` or `output_types` must "
"be specified")
if output_signature is None:
if output_shapes is None:
output_shapes = nest.map_structure(
lambda _: tensor_shape.TensorShape(None), output_types)
else:
output_shapes = nest.map_structure_up_to(output_types,
tensor_shape.as_shape,
output_shapes)
output_signature = nest.map_structure_up_to(output_types,
tensor_spec.TensorSpec,
output_shapes, output_types)
if args is None:
args = ()
else:
args = tuple(ops.convert_n_to_tensor(args, name="args"))
flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
flattened_shapes = nest.flatten(output_shapes)
flat_output_types = structure.get_flat_tensor_types(output_signature)
generator_state = DatasetV2._GeneratorState(generator)
@ -831,56 +875,33 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
"""A `py_func` that will be called to invoke the iterator."""
# `next()` raises `StopIteration` when there are no more
# elements remaining to be generated.
values = next(generator_state.get_iterator(iterator_id))
values = next(generator_state.get_iterator(iterator_id.numpy()))
# Use the same _convert function from the py_func() implementation to
# convert the returned values to arrays early, so that we can inspect
# their values.
try:
flattened_values = nest.flatten_up_to(output_types, values)
values = structure.normalize_element(values, output_signature)
except (TypeError, ValueError):
six.reraise(TypeError, TypeError(
"`generator` yielded an element that did not match the expected "
"structure. The expected structure was %s, but the yielded "
"element was %s." % (output_types, values)), sys.exc_info()[2])
ret_arrays = []
for ret, dtype in zip(flattened_values, flattened_types):
try:
ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access
ret, dtype=dtype.as_numpy_dtype))
except (TypeError, ValueError):
six.reraise(TypeError, TypeError(
"`generator` yielded an element that could not be converted to "
"the expected type. The expected type was %s, but the yielded "
"element was %s." % (dtype.name, ret)), sys.exc_info()[2])
six.reraise(
TypeError,
TypeError(
"`generator` yielded an element that did not match the "
"expected structure. The expected structure was %s, but the "
"yielded element was %s." % (output_signature, values)),
sys.exc_info()[2])
# Additional type and shape checking to ensure that the components
# of the generated element match the `output_types` and `output_shapes`
# arguments.
for (ret_array, expected_dtype, expected_shape) in zip(
ret_arrays, flattened_types, flattened_shapes):
if ret_array.dtype != expected_dtype.as_numpy_dtype:
raise TypeError(
"`generator` yielded an element of type %s where an element "
"of type %s was expected." % (ret_array.dtype,
expected_dtype.as_numpy_dtype))
if not expected_shape.is_compatible_with(ret_array.shape):
raise ValueError(
"`generator` yielded an element of shape %s where an element "
"of shape %s was expected." % (ret_array.shape, expected_shape))
values_spec = structure.type_spec_from_value(values)
return ret_arrays
if not structure.are_compatible(values_spec, output_signature):
raise TypeError(
"`generator` yielded an element of %s where an element "
"of %s was expected." % (values_spec, output_signature))
flat_values = script_ops.numpy_function(generator_py_func,
[iterator_id_t], flattened_types)
return structure.to_tensor_list(output_signature, values)
# The `py_func()` op drops the inferred shapes, so we add them back in
# here.
if output_shapes is not None:
for ret_t, shape in zip(flat_values, flattened_shapes):
ret_t.set_shape(shape)
return nest.pack_sequence_as(output_types, flat_values)
return script_ops._eager_py_func( # pylint: disable=protected-access
generator_py_func,
inp=[iterator_id_t],
Tout=flat_output_types,
use_tape_cache=False)
def finalize_fn(iterator_id_t):
"""Releases host-side state for the iterator with ID `iterator_id_t`."""
@ -906,7 +927,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
# given ID, and raises StopIteration when that iterator contains no
# more elements.
return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
finalize_fn)
finalize_fn, output_signature)
# A single-element dataset that, each time it is evaluated, contains a
# freshly-generated and unique (for the returned dataset) int64
@ -2148,21 +2169,8 @@ name=None))
`cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset
contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if
the analysis fails to determine the number of elements in the dataset.
`cardinality` only reports known cardinality (finite or infinite), if it can
be inferred statically. In particular, the implementation does not iterate
through the dataset or evaluate user-defined functions. As a consequence,
the statically inferred cardinality may often be unknown. For example, if
the dataset reads from file(s), the cardinality will be unknown. The
cardinality will also be unknown if the dataset contains user-defined
functions which could affect the cardinality (such as the functions in
`filter`, `flat_map`, `interleave`, or `from_generator`).
When constructing a dataset, you can apply the
`tf.data.experimental.assert_cardinality` transformation to inform the
dataset of its expected cardinality, so that `cardinality` can produce a
known cardinality.
the analysis fails to determine the number of elements in the dataset
(e.g. when the dataset source is a file).
>>> dataset = tf.data.Dataset.range(42)
>>> print(dataset.cardinality().numpy())
@ -2171,13 +2179,10 @@ name=None))
>>> cardinality = dataset.cardinality()
>>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
True
>>> dataset = dataset.filter(lambda x: False)
>>> dataset = dataset.filter(lambda x: True)
>>> cardinality = dataset.cardinality()
>>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
True
>>> dataset = dataset.apply(tf.data.experimental.assert_cardinality(0))
>>> print(dataset.cardinality().numpy())
0
Returns:
A scalar `tf.int64` `Tensor` representing the cardinality of the dataset.
@ -2458,9 +2463,14 @@ class DatasetV1(DatasetV2):
@staticmethod
@functools.wraps(DatasetV2.from_generator)
def from_generator(generator, output_types, output_shapes=None, args=None):
return DatasetV1Adapter(DatasetV2.from_generator(
generator, output_types, output_shapes, args))
def from_generator(generator,
output_types=None,
output_shapes=None,
args=None,
output_signature=None):
return DatasetV1Adapter(
DatasetV2.from_generator(generator, output_types, output_shapes, args,
output_signature))
@staticmethod
@functools.wraps(DatasetV2.range)
@ -3469,7 +3479,8 @@ class StructuredFunctionWrapper(object):
class _GeneratorDataset(DatasetSource):
"""A `Dataset` that generates elements by invoking a function."""
def __init__(self, init_args, init_func, next_func, finalize_func):
def __init__(self, init_args, init_func, next_func, finalize_func,
output_signature):
"""Constructs a `_GeneratorDataset`.
Args:
@ -3483,6 +3494,8 @@ class _GeneratorDataset(DatasetSource):
finalize_func: A TensorFlow function that will be called on the result of
`init_func` immediately before a C++ iterator over this dataset is
destroyed. The return value is ignored.
output_signature: A nested structure of `tf.TypeSpec` objects describing
the output of `next_func`.
"""
self._init_args = init_args
@ -3502,6 +3515,9 @@ class _GeneratorDataset(DatasetSource):
finalize_func,
self._transformation_name(),
input_structure=self._init_func.output_structure)
self._output_signature = output_signature
variant_tensor = gen_dataset_ops.generator_dataset(
structure.to_tensor_list(self._init_structure, self._init_args) +
self._init_func.function.captured_inputs,
@ -3515,7 +3531,7 @@ class _GeneratorDataset(DatasetSource):
@property
def element_spec(self):
return self._next_func.output_structure
return self._output_signature
def _transformation_name(self):
return "Dataset.from_generator()"

View File

@ -67,7 +67,7 @@ def _RaggedTensorStructure(dtype, shape, ragged_rank):
# TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once
# it is a subclass of `CompositeTensor`.
def normalize_element(element):
def normalize_element(element, element_signature=None):
"""Normalizes a nested structure of element components.
* Components matching `SparseTensorSpec` are converted to `SparseTensor`.
@ -78,19 +78,32 @@ def normalize_element(element):
Args:
element: A nested structure of individual components.
element_signature: (Optional.) A nested structure of `tf.DType` objects
corresponding to each component of `element`. If specified, it will be
used to set the exact type of output tensor when converting input
components which are not tensors themselves (e.g. numpy arrays, native
python types, etc.)
Returns:
A nested structure of `Tensor`, `Dataset`, `SparseTensor`, `RaggedTensor`,
or `TensorArray` objects.
"""
components = nest.flatten(element)
normalized_components = []
if element_signature is None:
components = nest.flatten(element)
flattened_signature = [None] * len(components)
pack_as = element
else:
flattened_signature = nest.flatten(element_signature)
components = nest.flatten_up_to(element_signature, element)
pack_as = element_signature
with ops.name_scope("normalize_element"):
# Imported here to avoid circular dependency.
from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top
for i, t in enumerate(components):
for i, (t, spec) in enumerate(zip(components, flattened_signature)):
try:
spec = type_spec_from_value(t, use_fallback=False)
if spec is None:
spec = type_spec_from_value(t, use_fallback=False)
except TypeError:
# TypeError indicates it was not possible to compute a `TypeSpec` for
# the value. As a fallback try converting the value to a tensor.
@ -111,9 +124,10 @@ def normalize_element(element):
elif isinstance(t, composite_tensor.CompositeTensor):
normalized_components.append(t)
else:
dtype = getattr(spec, "dtype", None)
normalized_components.append(
ops.convert_to_tensor(t, name="component_%d" % i))
return nest.pack_sequence_as(element, normalized_components)
ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype))
return nest.pack_sequence_as(pack_as, normalized_components)
def convert_legacy_structure(output_types, output_shapes, output_classes):

View File

@ -71,7 +71,7 @@ def _maybe_copy_to_context_device(tensor, device_name):
class EagerFunc(object):
"""A wrapper for a function owned by an EagerPyFunc."""
def __init__(self, func, Tout, is_grad_func):
def __init__(self, func, Tout, is_grad_func, use_tape_cache=True):
"""Constructs an EagerFunc.
Args:
@ -80,10 +80,14 @@ class EagerFunc(object):
None.
is_grad_func: Whether this EagerFunc is the gradient of another
EagerPyFunc.
use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`.
For additional information, see description of `_eager_py_func`.
This parameter should be removed once the #35084 issue is fixed.
"""
self._func = func
self._out_dtypes = Tout
self._is_grad_func = is_grad_func
self._use_tape_cache = use_tape_cache
def _convert(self, value, dtype):
"""Converts `value` to a tensor of type `dtype`, with error checking.
@ -147,7 +151,8 @@ class EagerFunc(object):
else:
outputs = _maybe_copy_to_context_device(
self._convert(ret, dtype=self._out_dtypes[0]), device_name)
tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
if self._use_tape_cache:
tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
return outputs
@ -277,7 +282,8 @@ def _internal_py_func(func,
stateful=None,
eager=False,
is_grad_func=False,
name=None):
name=None,
use_tape_cache=True):
"""See documentation for py_func and eager_py_func."""
if not callable(func):
raise ValueError("Expected func to be callable, got func of type {}".format(
@ -293,7 +299,7 @@ def _internal_py_func(func,
Tout = [Tout]
if eager:
func = EagerFunc(func, Tout, is_grad_func)
func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache)
# Tying the registered function's lifetime with the current default graph is
# not reliable. For example, Estimator-based binaries may switch graphs in
@ -370,6 +376,58 @@ def _EagerPyFuncGrad(op, *dy):
is_grad_func=True)
def _eager_py_func(func, inp, Tout, name=None, use_tape_cache=True):
"""Wraps a python function into a TensorFlow op that executes it eagerly.
This function is the internal implementation for `eager_py_func`, see the
`eager_py_func` docstring for the full description.
Note: this function as a layer of indirection was added with one
specific purpose: as a workaround for github issue #35084.
It does all the same as `eager_py_func` used to do with one difference:
it can be used to instruct underlying EagerFunc not to use `tape_cache`
to avoid memory leak. When the issue #35084 is fixed - this function should
be removed, its body should be moved back to become the body of
`eager_py_func` and all the call sites should be reverted to
using `eager_py_func` without `use_tape_cache` argument of any value.
Args:
func: A Python function which accepts a list of `Tensor` objects having
element types that match the corresponding `tf.Tensor` objects in `inp`
and returns a list of `Tensor` objects (or a single `Tensor`, or `None`)
having element types that match the corresponding values in `Tout`.
inp: A list of `Tensor` objects.
Tout: A list or tuple of tensorflow data types or a single tensorflow data
type if there is only one, indicating what `func` returns; an empty list
if no value is returned (i.e., if the return value is `None`).
name: A name for the operation (optional).
use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`.
For additional information, see description of `_eager_py_func`.
This parameter should be removed once the #35084 issue is fixed.
Returns:
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
if `func` returns None.
"""
if ops.executing_eagerly_outside_functions():
with ops.device(context.context().host_address_space()):
return _internal_py_func(
func=func,
inp=inp,
Tout=Tout,
eager=True,
name=name,
use_tape_cache=use_tape_cache)
return _internal_py_func(
func=func,
inp=inp,
Tout=Tout,
eager=True,
name=name,
use_tape_cache=use_tape_cache)
@tf_export("py_function")
@dispatch.add_dispatch_support
def eager_py_func(func, inp, Tout, name=None):
@ -451,12 +509,8 @@ def eager_py_func(func, inp, Tout, name=None):
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
if `func` returns None.
"""
if ops.executing_eagerly_outside_functions():
with ops.device(context.context().host_address_space()):
return _internal_py_func(
func=func, inp=inp, Tout=Tout, eager=True, name=name)
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
return _eager_py_func(
func=func, inp=inp, Tout=Tout, name=name, use_tape_cache=True)
def py_func_common(func, inp, Tout, stateful=True, name=None):

View File

@ -65,7 +65,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"

View File

@ -67,7 +67,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"

View File

@ -67,7 +67,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"

View File

@ -67,7 +67,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"

View File

@ -67,7 +67,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"

View File

@ -67,7 +67,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"

View File

@ -67,7 +67,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"

View File

@ -48,7 +48,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_tensor_slices"

View File

@ -50,7 +50,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_tensor_slices"

View File

@ -49,7 +49,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_tensor_slices"

View File

@ -50,7 +50,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_tensor_slices"

View File

@ -50,7 +50,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_tensor_slices"

View File

@ -50,7 +50,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_tensor_slices"

View File

@ -50,7 +50,7 @@ tf_class {
}
member_method {
name: "from_generator"
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\', \'output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_tensor_slices"