diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 53f6a2b0492..8575cdf3da5 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -545,7 +545,31 @@ def py_func_common(func, inp, Tout, stateful=True, name=None): `tf.compat.v1.py_func()` and you must pin the created operation to a device in that server (e.g. using `with tf.device():`). + + Note: It produces tensors of unknown shape and rank as shape inference + does not work on arbitrary Python code. + If you need the shape, you need to set it based on statically + available information. + + E.g. + ```python + import tensorflow as tf + import numpy as np + def make_synthetic_data(i): + return np.cast[np.uint8](i) * np.ones([20,256,256,3], + dtype=np.float32) / 10. + + def preprocess_fn(i): + ones = tf.py_function(make_synthetic_data,[i],tf.float32) + ones.set_shape(tf.TensorShape([None, None, None, None])) + ones = tf.image.resize(ones, [224,224]) + return ones + + ds = tf.data.Dataset.range(10) + ds = ds.map(preprocess_fn) + ``` + Args: func: A Python function, which accepts `ndarray` objects as arguments and returns a list of `ndarray` objects (or a single `ndarray`). This function