Merge pull request #44994 from bhack:patch-11

PiperOrigin-RevId: 343825425
Change-Id: I03dde42bb6473d3b6c7ea49116c3d2e00f2af703
This commit is contained in:
TensorFlower Gardener 2020-11-23 04:31:01 -08:00
commit 719b255cbf

View File

@ -545,7 +545,31 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
`tf.compat.v1.py_func()` and you must pin the created operation to a device
in that
server (e.g. using `with tf.device():`).
Note: It produces tensors of unknown shape and rank as shape inference
does not work on arbitrary Python code.
If you need the shape, you need to set it based on statically
available information.
E.g.
```python
import tensorflow as tf
import numpy as np
def make_synthetic_data(i):
return np.cast[np.uint8](i) * np.ones([20,256,256,3],
dtype=np.float32) / 10.
def preprocess_fn(i):
ones = tf.py_function(make_synthetic_data,[i],tf.float32)
ones.set_shape(tf.TensorShape([None, None, None, None]))
ones = tf.image.resize(ones, [224,224])
return ones
ds = tf.data.Dataset.range(10)
ds = ds.map(preprocess_fn)
```
Args:
func: A Python function, which accepts `ndarray` objects as arguments and
returns a list of `ndarray` objects (or a single `ndarray`). This function