[core API doc fixit] tf.get_static_value; updated the API doc with more explanations and examples.
PiperOrigin-RevId: 348495511 Change-Id: I28b6c2c4822fd2d62eabbb535282ac5087765fae
This commit is contained in:
parent
9d10fa31f0
commit
b8a634fd5a
@ -809,6 +809,41 @@ def constant_value(tensor, partial=False): # pylint: disable=invalid-name
|
||||
This function attempts to partially evaluate the given tensor, and
|
||||
returns its value as a numpy ndarray if this succeeds.
|
||||
|
||||
Example usage:
|
||||
|
||||
>>> a = tf.constant(10)
|
||||
>>> tf.get_static_value(a)
|
||||
10
|
||||
>>> b = tf.constant(20)
|
||||
>>> tf.get_static_value(tf.add(a, b))
|
||||
30
|
||||
|
||||
>>> # `tf.Variable` is not supported.
|
||||
>>> c = tf.Variable(30)
|
||||
>>> print(tf.get_static_value(c))
|
||||
None
|
||||
|
||||
Using `partial` option is most relevant when calling `get_static_value` inside
|
||||
a `tf.function`. Setting it to `True` will return the results but for the
|
||||
values that cannot be evaluated will be `None`. For example:
|
||||
|
||||
```python
|
||||
class Foo(object):
|
||||
def __init__(self):
|
||||
self.a = tf.Variable(1)
|
||||
self.b = tf.constant(2)
|
||||
|
||||
@tf.function
|
||||
def bar(self, partial):
|
||||
packed = tf.raw_ops.Pack(values=[self.a, self.b])
|
||||
static_val = tf.get_static_value(packed, partial=partial)
|
||||
tf.print(static_val)
|
||||
|
||||
f = Foo()
|
||||
f.bar(partial=True) # `array([None, array(2, dtype=int32)], dtype=object)`
|
||||
f.bar(partial=False) # `None`
|
||||
```
|
||||
|
||||
Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it
|
||||
will no longer be possible to feed a different value for `tensor`. This allows
|
||||
the result of this function to influence the graph that is constructed, and
|
||||
|
Loading…
Reference in New Issue
Block a user