Add a padding_spec
argument into tpu.replicate, which provides the functionality of auto bucketing inputs.
PiperOrigin-RevId: 295226850 Change-Id: I486a82739bd7d2629b218371f5b50f1e37cb4b2d
This commit is contained in:
parent
da910d3f2d
commit
c47458b7ba
tensorflow
python
distribute
central_storage_strategy.pycustom_training_loop_input_test.pydistribute_lib.pyone_device_strategy.pytpu_strategy.py
tpu
tools
api/golden
v1
tensorflow.distribute.-mirrored-strategy.pbtxttensorflow.distribute.-one-device-strategy.pbtxttensorflow.distribute.-run-options.pbtxttensorflow.distribute.-strategy.pbtxttensorflow.distribute.experimental.-central-storage-strategy.pbtxttensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxttensorflow.distribute.experimental.-parameter-server-strategy.pbtxttensorflow.distribute.experimental.-t-p-u-strategy.pbtxttensorflow.distribute.pbtxttensorflow.tpu.-padding-spec.pbtxttensorflow.tpu.pbtxt
v2
tensorflow.distribute.-mirrored-strategy.pbtxttensorflow.distribute.-one-device-strategy.pbtxttensorflow.distribute.-run-options.pbtxttensorflow.distribute.-strategy.pbtxttensorflow.distribute.experimental.-central-storage-strategy.pbtxttensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxttensorflow.distribute.experimental.-parameter-server-strategy.pbtxttensorflow.distribute.experimental.-t-p-u-strategy.pbtxttensorflow.distribute.pbtxt
compatibility
@ -161,7 +161,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
|
||||
"""
|
||||
return super(CentralStorageStrategy, self).experimental_local_results(value)
|
||||
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None): # pylint: disable=useless-super-delegation
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None): # pylint: disable=useless-super-delegation
|
||||
"""Run `fn` on each replica, with the given arguments.
|
||||
|
||||
In `CentralStorageStrategy`, `fn` is called on each of the compute
|
||||
@ -171,12 +171,14 @@ class CentralStorageStrategy(distribute_lib.Strategy):
|
||||
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
|
||||
args: (Optional) Positional arguments to `fn`.
|
||||
kwargs: (Optional) Keyword arguments to `fn`.
|
||||
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
|
||||
the options to run `fn`.
|
||||
|
||||
Returns:
|
||||
Return value from running `fn`.
|
||||
"""
|
||||
return super(CentralStorageStrategy, self).experimental_run_v2(fn, args,
|
||||
kwargs)
|
||||
return super(CentralStorageStrategy,
|
||||
self).experimental_run_v2(fn, args, kwargs, options)
|
||||
|
||||
def reduce(self, reduce_op, value, axis): # pylint: disable=useless-super-delegation
|
||||
"""Reduce `value` across replicas.
|
||||
|
@ -23,6 +23,7 @@ from absl.testing import parameterized
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -333,6 +334,31 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
|
||||
# This assumes that there are exactly 2 replicas
|
||||
self.assertAllEqual([5.5, 7.], run(input_iterator))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.multidevice_strategies,
|
||||
mode=["eager"]))
|
||||
def testDynamicShapesWithRunOptions(self, distribution):
|
||||
dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
|
||||
input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
|
||||
options = distribute_lib.RunOptions
|
||||
options.experimental_bucketizing_dynamic_shape = True
|
||||
|
||||
@def_function.function
|
||||
def run(iterator):
|
||||
|
||||
def computation(x):
|
||||
return math_ops.reduce_mean(x)
|
||||
|
||||
inputs = next(iterator)
|
||||
outputs = distribution.experimental_local_results(
|
||||
distribution.experimental_run_v2(
|
||||
computation, args=(inputs,), options=options))
|
||||
return outputs
|
||||
|
||||
# This assumes that there are exactly 2 replicas
|
||||
self.assertAllEqual([5.5, 7.], run(input_iterator))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.multidevice_strategies,
|
||||
|
@ -95,6 +95,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import enum # pylint: disable=g-bad-import-order
|
||||
import threading
|
||||
@ -411,6 +412,35 @@ class InputContext(object):
|
||||
self.input_pipeline_id, self.num_input_pipelines)
|
||||
|
||||
|
||||
@tf_export("distribute.RunOptions")
|
||||
class RunOptions(
|
||||
collections.namedtuple("RunOptions", [
|
||||
"experimental_enable_dynamic_batch_size",
|
||||
"experimental_bucketizing_dynamic_shape",
|
||||
])):
|
||||
"""Run options for `strategy.experimental_run_v2`.
|
||||
|
||||
This can be used to hold some strategy specific configs.
|
||||
|
||||
Attributes:
|
||||
experimental_enable_dynamic_batch_size: Boolean. Only applies to
|
||||
TPUStrategy. Default to False. If True, TPUStrategy will enable dynamic
|
||||
padder to support dynamic batch size for the inputs. Otherwise only static
|
||||
shape inputs are allowed.
|
||||
experimental_bucketizing_dynamic_shape: Boolean. Only applies to
|
||||
TPUStrategy. Default to False. If True, TPUStrategy will automatic
|
||||
bucketize inputs passed into `experimental_run_v2` if the input shape is
|
||||
dynamic. This is a performance optimization to reduce XLA recompilation,
|
||||
which should not have impact on correctness.
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
experimental_enable_dynamic_batch_size=True,
|
||||
experimental_bucketizing_dynamic_shape=False):
|
||||
return super(RunOptions,
|
||||
cls).__new__(cls, experimental_enable_dynamic_batch_size,
|
||||
experimental_bucketizing_dynamic_shape)
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Base classes for all distribution strategies.
|
||||
|
||||
@ -778,7 +808,7 @@ class StrategyBase(object):
|
||||
return self._extended._experimental_distribute_datasets_from_function( # pylint: disable=protected-access
|
||||
dataset_fn)
|
||||
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
|
||||
"""Run `fn` on each replica, with the given arguments.
|
||||
|
||||
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
|
||||
@ -800,6 +830,8 @@ class StrategyBase(object):
|
||||
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
|
||||
args: (Optional) Positional arguments to `fn`.
|
||||
kwargs: (Optional) Keyword arguments to `fn`.
|
||||
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
|
||||
the options to run `fn`.
|
||||
|
||||
Returns:
|
||||
Merged return value of `fn` across replicas. The structure of the return
|
||||
@ -807,6 +839,8 @@ class StrategyBase(object):
|
||||
structure can either be "per-replica" `Tensor` objects or `Tensor`s
|
||||
(for example, if running on a single replica).
|
||||
"""
|
||||
del options
|
||||
|
||||
if not isinstance(args, (list, tuple)):
|
||||
raise ValueError(
|
||||
"positional args must be a list or tuple, got {}".format(type(args)))
|
||||
|
@ -163,7 +163,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
|
||||
"""
|
||||
return super(OneDeviceStrategy, self).experimental_local_results(value)
|
||||
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None): # pylint: disable=useless-super-delegation
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None): # pylint: disable=useless-super-delegation
|
||||
"""Run `fn` on each replica, with the given arguments.
|
||||
|
||||
In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
|
||||
@ -173,11 +173,14 @@ class OneDeviceStrategy(distribute_lib.Strategy):
|
||||
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
|
||||
args: (Optional) Positional arguments to `fn`.
|
||||
kwargs: (Optional) Keyword arguments to `fn`.
|
||||
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
|
||||
the options to run `fn`.
|
||||
|
||||
Returns:
|
||||
Return value from running `fn`.
|
||||
"""
|
||||
return super(OneDeviceStrategy, self).experimental_run_v2(fn, args, kwargs)
|
||||
return super(OneDeviceStrategy,
|
||||
self).experimental_run_v2(fn, args, kwargs, options)
|
||||
|
||||
def reduce(self, reduce_op, value, axis): # pylint: disable=useless-super-delegation
|
||||
"""Reduce `value` across replicas.
|
||||
|
@ -158,14 +158,15 @@ class TPUStrategy(distribute_lib.Strategy):
|
||||
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
|
||||
# can use the default implementation.
|
||||
# This implementation runs a single step. It does not use infeed or outfeed.
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
|
||||
"""See base class."""
|
||||
validate_experimental_run_function(fn)
|
||||
|
||||
# Note: the target function is converted to graph even when in Eager mode,
|
||||
# so autograph is on by default here.
|
||||
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
|
||||
return self.extended.tpu_run(fn, args, kwargs)
|
||||
options = options or distribute_lib.RunOptions()
|
||||
return self.extended.tpu_run(fn, args, kwargs, options)
|
||||
|
||||
|
||||
@tf_export(v1=["distribute.experimental.TPUStrategy"])
|
||||
@ -206,12 +207,62 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
|
||||
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
|
||||
# can use the default implementation.
|
||||
# This implementation runs a single step. It does not use infeed or outfeed.
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None):
|
||||
"""See base class."""
|
||||
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
|
||||
"""Run `fn` on each replica, with the given arguments.
|
||||
|
||||
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
|
||||
"per-replica" values, such as those produced by a "distributed `Dataset`",
|
||||
when `fn` is executed on a particular replica, it will be executed with the
|
||||
component of those "per-replica" values that correspond to that replica.
|
||||
|
||||
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
||||
as `all_reduce`.
|
||||
|
||||
All arguments in `args` or `kwargs` should either be nest of tensors or
|
||||
per-replica objects containing tensors or composite tensors.
|
||||
|
||||
Users can pass strategy specific options to `options` argument. An example
|
||||
to enable bucketizing dynamic shapes in `TPUStrategy.experimental_run_v2`
|
||||
is:
|
||||
```python
|
||||
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
strategy = tf.distribute.experimental.TPUStrategy(tpu='')
|
||||
|
||||
options = tf.distribute.RunOptions()
|
||||
options.experimental_bucketizing_dynamic_shape = True
|
||||
|
||||
iterator = iter(inputs)
|
||||
|
||||
@tf.function()
|
||||
def step_fn(inputs):
|
||||
output = tf.reduce_sum(inputs)
|
||||
return output
|
||||
|
||||
strategy.experimental_run_v2(step_fn, args=(next(iterator),),
|
||||
options=options)
|
||||
```
|
||||
|
||||
Args:
|
||||
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
|
||||
args: (Optional) Positional arguments to `fn`.
|
||||
kwargs: (Optional) Keyword arguments to `fn`.
|
||||
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
|
||||
the options to run `fn`.
|
||||
|
||||
Returns:
|
||||
Merged return value of `fn` across replicas. The structure of the return
|
||||
value is the same as the return value from `fn`. Each element in the
|
||||
structure can either be "per-replica" `Tensor` objects or `Tensor`s
|
||||
(for example, if running on a single replica).
|
||||
"""
|
||||
validate_experimental_run_function(fn)
|
||||
|
||||
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
|
||||
return self.extended.tpu_run(fn, args, kwargs)
|
||||
options = options or distribute_lib.RunOptions()
|
||||
return self.extended.tpu_run(fn, args, kwargs, options)
|
||||
|
||||
|
||||
# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
|
||||
@ -288,7 +339,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
self._retrace_functions_for_each_device = False
|
||||
|
||||
self.experimental_enable_get_next_as_optional = True
|
||||
self.experimental_enable_dynamic_batch_size = True
|
||||
self._prefetch_on_host = False
|
||||
|
||||
self._logical_device_stack = [0]
|
||||
@ -801,11 +851,11 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
"""
|
||||
return True
|
||||
|
||||
def tpu_run(self, fn, args, kwargs):
|
||||
func = self._tpu_function_creator(fn)
|
||||
def tpu_run(self, fn, args, kwargs, options=None):
|
||||
func = self._tpu_function_creator(fn, options)
|
||||
return func(args, kwargs)
|
||||
|
||||
def _tpu_function_creator(self, fn):
|
||||
def _tpu_function_creator(self, fn, options):
|
||||
if fn in self._tpu_function_cache:
|
||||
return self._tpu_function_cache[fn]
|
||||
|
||||
@ -844,7 +894,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
|
||||
# Construct and pass `maximum_shapes` so that we could support dynamic
|
||||
# shapes using dynamic padder.
|
||||
if self.experimental_enable_dynamic_batch_size and replicate_inputs:
|
||||
if options.experimental_enable_dynamic_batch_size and replicate_inputs:
|
||||
maximum_shapes = []
|
||||
flattened_list = nest.flatten(replicate_inputs[0])
|
||||
for input_tensor in flattened_list:
|
||||
@ -859,12 +909,18 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
else:
|
||||
maximum_shapes = None
|
||||
|
||||
if options.experimental_bucketizing_dynamic_shape:
|
||||
padding_spec = tpu.PaddingSpec.POWER_OF_TWO
|
||||
else:
|
||||
padding_spec = None
|
||||
|
||||
with strategy.scope():
|
||||
replicate_outputs = tpu.replicate(
|
||||
replicated_fn,
|
||||
replicate_inputs,
|
||||
device_assignment=self._device_assignment,
|
||||
maximum_shapes=maximum_shapes)
|
||||
maximum_shapes=maximum_shapes,
|
||||
padding_spec=padding_spec)
|
||||
|
||||
# Remove all no ops that may have been added during 'tpu.replicate()'
|
||||
if isinstance(result[0], list):
|
||||
|
@ -20,6 +20,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import logging
|
||||
import enum
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
@ -784,15 +786,58 @@ def outside_compilation(computation, *args, **kwargs):
|
||||
return retval
|
||||
|
||||
|
||||
@tf_export(v1=["tpu.PaddingSpec"])
|
||||
class PaddingSpec(enum.IntEnum):
|
||||
"""Represents the type of padding policies for tpu.replicate."""
|
||||
# By default the policy is set to AUTO, the dynamic input shape dimension will
|
||||
# be pad to maximum of all the replicas.
|
||||
AUTO = 0
|
||||
# Bucketize the dynamic input shape dimension into a power of 2.
|
||||
POWER_OF_TWO = 1
|
||||
|
||||
|
||||
@tf_export(v1=["tpu.replicate"])
|
||||
def replicate(computation,
|
||||
inputs=None,
|
||||
infeed_queue=None,
|
||||
device_assignment=None,
|
||||
name=None,
|
||||
maximum_shapes=None):
|
||||
maximum_shapes=None,
|
||||
padding_spec=None):
|
||||
"""Builds a graph operator that runs a replicated TPU computation.
|
||||
|
||||
Example for the basic usage that `inputs` has static shape:
|
||||
|
||||
```python
|
||||
|
||||
def computation(x):
|
||||
x = x + 1
|
||||
return tf.math.reduce_mean(x)
|
||||
|
||||
x = tf.convert_to_tensor([1., 2., 3.])
|
||||
y = tf.convert_to_tensor([4., 5., 6.])
|
||||
tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
|
||||
```
|
||||
|
||||
If the `inputs` has dynamic shapes and you would like to automatically
|
||||
bucketize the inputs to avoid XLA recompilation. See the advanced example
|
||||
below:
|
||||
|
||||
```python
|
||||
|
||||
def computation(x):
|
||||
x = x + 1
|
||||
return tf.math.reduce_mean(x)
|
||||
|
||||
# Assume input tensors in two replicas `x` and `y` both have dynamic shape
|
||||
# ([None, 2]).
|
||||
tf.compat.v1.tpu.replicate(
|
||||
computation,
|
||||
inputs=[x, y],
|
||||
maximum_shapes=[tf.TensorShape([None, None])],
|
||||
padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO)
|
||||
```
|
||||
|
||||
Args:
|
||||
computation: A Python function that builds the computation to replicate.
|
||||
inputs: A list of lists of input tensors or `None` (equivalent to
|
||||
@ -818,6 +863,11 @@ def replicate(computation,
|
||||
object) will be padded to the maximum size of that dimension over all
|
||||
replicas. The structure of `maximum_shapes` needs to be the same as
|
||||
`inputs[0]`.
|
||||
padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
|
||||
padding policy when the `inputs` to `tpu.replicate` is dynamic.
|
||||
One usage is to enable automatic bucketizing on the inputs by setting the
|
||||
value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
|
||||
recompilation in the XLA side.
|
||||
Returns:
|
||||
A list of outputs, indexed by `[replica_num]` each output can be a nested
|
||||
structure same as what computation() returns with a few exceptions.
|
||||
@ -845,10 +895,21 @@ def replicate(computation,
|
||||
infeed_queue,
|
||||
device_assignment,
|
||||
name,
|
||||
maximum_shapes=maximum_shapes)[1]
|
||||
maximum_shapes=maximum_shapes,
|
||||
padding_spec=padding_spec)[1]
|
||||
|
||||
|
||||
def _pad_all_input(inputs, padded_shapes):
|
||||
def _ceil_to_pow_of_n(x, n):
|
||||
"""Ceil input `x` to power of `n`."""
|
||||
x = math_ops.cast(x, dtypes.float32)
|
||||
lognx = math_ops.log(x) / math_ops.log(n * 1.0)
|
||||
lognx = math_ops.ceil(lognx)
|
||||
result = math_ops.pow(n * 1.0, lognx)
|
||||
result = math_ops.cast(result, dtypes.int32)
|
||||
return result
|
||||
|
||||
|
||||
def _pad_all_input(inputs, padded_shapes, padding_spec):
|
||||
"""Pad all input tensors given padded_shapes.
|
||||
|
||||
The real shape tensors will be concatenated with the padded original inputs.
|
||||
@ -856,6 +917,11 @@ def _pad_all_input(inputs, padded_shapes):
|
||||
Args:
|
||||
inputs: The original inputs.
|
||||
padded_shapes: A list of padded shapes for each input.
|
||||
padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
|
||||
padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
|
||||
One usage is to enable automatic bucketizing on the inputs by setting the
|
||||
value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
|
||||
recompilation in the XLA side.
|
||||
|
||||
Returns:
|
||||
The padded inputs and a PaddingMap list which maps the padded input
|
||||
@ -933,6 +999,8 @@ def _pad_all_input(inputs, padded_shapes):
|
||||
# among all the cores.
|
||||
max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
|
||||
minimum_dynamic_dim_size)
|
||||
if padding_spec == PaddingSpec.POWER_OF_TWO:
|
||||
max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2)
|
||||
# Pad to the given maximum value.
|
||||
padding = [0, max_dim_size - input_shape_tensor[i]]
|
||||
else:
|
||||
@ -972,7 +1040,8 @@ def split_compile_and_replicate(computation,
|
||||
device_assignment=None,
|
||||
name=None,
|
||||
use_tpu=True,
|
||||
maximum_shapes=None):
|
||||
maximum_shapes=None,
|
||||
padding_spec=None):
|
||||
"""Builds graph operators that runs compilation and replicated computation.
|
||||
|
||||
This is a lower level interface than replicate that returns a separate compile
|
||||
@ -1009,6 +1078,11 @@ def split_compile_and_replicate(computation,
|
||||
object) will be padded to the maximum size of that dimension over all
|
||||
replicas. The structure of `maximum_shapes` needs to be the same as
|
||||
`inputs[0]`.
|
||||
padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the
|
||||
padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
|
||||
One usage is to enable automatic bucketizing on the inputs by setting the
|
||||
value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
|
||||
recompilation in the XLA side.
|
||||
|
||||
Returns:
|
||||
A list of lists with the first list corresponding to the compile op and the
|
||||
@ -1108,7 +1182,8 @@ def split_compile_and_replicate(computation,
|
||||
tensor_shape.TensorShape(s) for s in flat_maximum_shapes
|
||||
]
|
||||
|
||||
flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes)
|
||||
flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes,
|
||||
padding_spec)
|
||||
|
||||
serialized_padding_maps = []
|
||||
for padding_map in padding_maps:
|
||||
|
@ -46,7 +46,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
|
@ -46,7 +46,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
|
@ -0,0 +1,23 @@
|
||||
path: "tensorflow.distribute.RunOptions"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
|
||||
is_instance: "<type \'tuple\'>"
|
||||
member {
|
||||
name: "experimental_bucketizing_dynamic_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_enable_dynamic_batch_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "count"
|
||||
}
|
||||
member_method {
|
||||
name: "index"
|
||||
}
|
||||
}
|
@ -45,7 +45,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
|
@ -46,7 +46,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
|
@ -46,7 +46,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
|
@ -46,7 +46,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
|
@ -50,7 +50,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
|
@ -40,6 +40,10 @@ tf_module {
|
||||
name: "ReplicaContext"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "RunOptions"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Server"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,12 @@
|
||||
path: "tensorflow.tpu.PaddingSpec"
|
||||
tf_class {
|
||||
is_instance: "<enum \'PaddingSpec\'>"
|
||||
member {
|
||||
name: "AUTO"
|
||||
mtype: "<enum \'PaddingSpec\'>"
|
||||
}
|
||||
member {
|
||||
name: "POWER_OF_TWO"
|
||||
mtype: "<enum \'PaddingSpec\'>"
|
||||
}
|
||||
}
|
@ -4,6 +4,10 @@ tf_module {
|
||||
name: "CrossShardOptimizer"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "PaddingSpec"
|
||||
mtype: "<class \'enum.EnumMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
@ -34,7 +38,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "replicate"
|
||||
argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "rewrite"
|
||||
|
@ -54,7 +54,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
|
@ -54,7 +54,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
|
@ -0,0 +1,23 @@
|
||||
path: "tensorflow.distribute.RunOptions"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
|
||||
is_instance: "<type \'tuple\'>"
|
||||
member {
|
||||
name: "experimental_bucketizing_dynamic_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_enable_dynamic_batch_size"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "count"
|
||||
}
|
||||
member_method {
|
||||
name: "index"
|
||||
}
|
||||
}
|
@ -53,7 +53,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
|
@ -54,7 +54,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
|
@ -54,7 +54,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
|
@ -54,7 +54,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
|
@ -54,7 +54,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
|
@ -40,6 +40,10 @@ tf_module {
|
||||
name: "ReplicaContext"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "RunOptions"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Server"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -1328,6 +1328,8 @@ renames = {
|
||||
'tf.compat.v1.to_int64',
|
||||
'tf.tpu.CrossShardOptimizer':
|
||||
'tf.compat.v1.tpu.CrossShardOptimizer',
|
||||
'tf.tpu.PaddingSpec':
|
||||
'tf.compat.v1.tpu.PaddingSpec',
|
||||
'tf.tpu.batch_parallel':
|
||||
'tf.compat.v1.tpu.batch_parallel',
|
||||
'tf.tpu.bfloat16_scope':
|
||||
|
Loading…
Reference in New Issue
Block a user