diff --git a/tensorflow/python/framework/type_spec.py b/tensorflow/python/framework/type_spec.py index e6e921e6184..4bf2ad791d7 100644 --- a/tensorflow/python/framework/type_spec.py +++ b/tensorflow/python/framework/type_spec.py @@ -503,11 +503,29 @@ class BatchableTypeSpec(TypeSpec): return tensor_list +@tf_export("type_spec_from_value") def type_spec_from_value(value): - """Returns a `TypeSpec` that represents the given `value`. + """Returns a `tf.TypeSpec` that represents the given `value`. + + Examples: + + >>> tf.type_spec_from_value(tf.constant([1, 2, 3])) + TensorSpec(shape=(3,), dtype=tf.int32, name=None) + >>> tf.type_spec_from_value(np.array([4.0, 5.0], np.float64)) + TensorSpec(shape=(2,), dtype=tf.float64, name=None) + >>> tf.type_spec_from_value(tf.ragged.constant([[1, 2], [3, 4, 5]])) + RaggedTensorSpec(TensorShape([2, None]), tf.int32, 1, tf.int64) + + >>> example_input = tf.ragged.constant([[1, 2], [3]]) + >>> @tf.function(input_signature=[tf.type_spec_from_value(example_input)]) + ... def f(x): + ... return tf.reduce_sum(x, axis=1) Args: value: A value that can be accepted or returned by TensorFlow APIs. + Accepted types for `value` include `tf.Tensor`, any value that can be + converted to `tf.Tensor` using `tf.convert_to_tensor`, and any subclass + of `CompositeTensor` (such as `tf.RaggedTensor`). Returns: A `TypeSpec` that is compatible with `value`. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 60ae8ef5be9..e6274357a49 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -2432,6 +2432,10 @@ tf_module { name: "tuple" argspec: "args=[\'tensors\', \'name\', \'control_inputs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } + member_method { + name: "type_spec_from_value" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "unique" argspec: "args=[\'x\', \'out_idx\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 8928a9d3b67..468b4e36238 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -1108,6 +1108,10 @@ tf_module { name: "tuple" argspec: "args=[\'tensors\', \'control_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } + member_method { + name: "type_spec_from_value" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "unique" argspec: "args=[\'x\', \'out_idx\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], "