Merge pull request #44994 from bhack:patch-11
PiperOrigin-RevId: 343825425 Change-Id: I03dde42bb6473d3b6c7ea49116c3d2e00f2af703
This commit is contained in:
commit
719b255cbf
@ -545,7 +545,31 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
|
|||||||
`tf.compat.v1.py_func()` and you must pin the created operation to a device
|
`tf.compat.v1.py_func()` and you must pin the created operation to a device
|
||||||
in that
|
in that
|
||||||
server (e.g. using `with tf.device():`).
|
server (e.g. using `with tf.device():`).
|
||||||
|
|
||||||
|
Note: It produces tensors of unknown shape and rank as shape inference
|
||||||
|
does not work on arbitrary Python code.
|
||||||
|
If you need the shape, you need to set it based on statically
|
||||||
|
available information.
|
||||||
|
|
||||||
|
E.g.
|
||||||
|
```python
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def make_synthetic_data(i):
|
||||||
|
return np.cast[np.uint8](i) * np.ones([20,256,256,3],
|
||||||
|
dtype=np.float32) / 10.
|
||||||
|
|
||||||
|
def preprocess_fn(i):
|
||||||
|
ones = tf.py_function(make_synthetic_data,[i],tf.float32)
|
||||||
|
ones.set_shape(tf.TensorShape([None, None, None, None]))
|
||||||
|
ones = tf.image.resize(ones, [224,224])
|
||||||
|
return ones
|
||||||
|
|
||||||
|
ds = tf.data.Dataset.range(10)
|
||||||
|
ds = ds.map(preprocess_fn)
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: A Python function, which accepts `ndarray` objects as arguments and
|
func: A Python function, which accepts `ndarray` objects as arguments and
|
||||||
returns a list of `ndarray` objects (or a single `ndarray`). This function
|
returns a list of `ndarray` objects (or a single `ndarray`). This function
|
||||||
|
Loading…
x
Reference in New Issue
Block a user