Merge pull request #44994 from bhack:patch-11
PiperOrigin-RevId: 343825425 Change-Id: I03dde42bb6473d3b6c7ea49116c3d2e00f2af703
This commit is contained in:
commit
719b255cbf
@ -545,7 +545,31 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
|
||||
`tf.compat.v1.py_func()` and you must pin the created operation to a device
|
||||
in that
|
||||
server (e.g. using `with tf.device():`).
|
||||
|
||||
Note: It produces tensors of unknown shape and rank as shape inference
|
||||
does not work on arbitrary Python code.
|
||||
If you need the shape, you need to set it based on statically
|
||||
available information.
|
||||
|
||||
E.g.
|
||||
```python
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
def make_synthetic_data(i):
|
||||
return np.cast[np.uint8](i) * np.ones([20,256,256,3],
|
||||
dtype=np.float32) / 10.
|
||||
|
||||
def preprocess_fn(i):
|
||||
ones = tf.py_function(make_synthetic_data,[i],tf.float32)
|
||||
ones.set_shape(tf.TensorShape([None, None, None, None]))
|
||||
ones = tf.image.resize(ones, [224,224])
|
||||
return ones
|
||||
|
||||
ds = tf.data.Dataset.range(10)
|
||||
ds = ds.map(preprocess_fn)
|
||||
```
|
||||
|
||||
Args:
|
||||
func: A Python function, which accepts `ndarray` objects as arguments and
|
||||
returns a list of `ndarray` objects (or a single `ndarray`). This function
|
||||
|
Loading…
x
Reference in New Issue
Block a user