Add a padding_spec argument to tpu.replicate, which provides automatic bucketing of inputs.

PiperOrigin-RevId: 295226850
Change-Id: I486a82739bd7d2629b218371f5b50f1e37cb4b2d
Author: Ruoxin Sang, 2020-02-14 14:25:12 -08:00
Committed by: TensorFlower Gardener
Parent: da910d3f2d
Commit: c47458b7ba
27 changed files with 305 additions and 37 deletions


@@ -161,7 +161,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
     """
     return super(CentralStorageStrategy, self).experimental_local_results(value)
 
-  def experimental_run_v2(self, fn, args=(), kwargs=None):  # pylint: disable=useless-super-delegation
+  def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
     """Run `fn` on each replica, with the given arguments.
 
     In `CentralStorageStrategy`, `fn` is called on each of the compute
@@ -171,12 +171,14 @@ class CentralStorageStrategy(distribute_lib.Strategy):
       fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
       args: (Optional) Positional arguments to `fn`.
       kwargs: (Optional) Keyword arguments to `fn`.
+      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
+        the options to run `fn`.
 
     Returns:
       Return value from running `fn`.
     """
-    return super(CentralStorageStrategy, self).experimental_run_v2(fn, args,
-                                                                   kwargs)
+    return super(CentralStorageStrategy,
+                 self).experimental_run_v2(fn, args, kwargs, options)
 
   def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
     """Reduce `value` across replicas.


@@ -23,6 +23,7 @@ from absl.testing import parameterized
 from tensorflow.python import tf2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import def_function
@@ -333,6 +334,31 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
     # This assumes that there are exactly 2 replicas
     self.assertAllEqual([5.5, 7.], run(input_iterator))
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.multidevice_strategies,
+          mode=["eager"]))
+  def testDynamicShapesWithRunOptions(self, distribution):
+    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
+    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
+    options = distribute_lib.RunOptions(
+        experimental_bucketizing_dynamic_shape=True)
+
+    @def_function.function
+    def run(iterator):
+
+      def computation(x):
+        return math_ops.reduce_mean(x)
+
+      inputs = next(iterator)
+      outputs = distribution.experimental_local_results(
+          distribution.experimental_run_v2(
+              computation, args=(inputs,), options=options))
+      return outputs
+
+    # This assumes that there are exactly 2 replicas
+    self.assertAllEqual([5.5, 7.], run(input_iterator))
+
   @combinations.generate(
       combinations.combine(
           distribution=strategy_combinations.multidevice_strategies,


@@ -95,6 +95,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import copy
 import enum  # pylint: disable=g-bad-import-order
 import threading
@@ -411,6 +412,35 @@ class InputContext(object):
         self.input_pipeline_id, self.num_input_pipelines)
 
 
+@tf_export("distribute.RunOptions")
+class RunOptions(
+    collections.namedtuple("RunOptions", [
+        "experimental_enable_dynamic_batch_size",
+        "experimental_bucketizing_dynamic_shape",
+    ])):
+  """Run options for `strategy.experimental_run_v2`.
+
+  This can be used to hold some strategy-specific configs.
+
+  Attributes:
+    experimental_enable_dynamic_batch_size: Boolean. Only applies to
+      TPUStrategy. Defaults to True. If True, TPUStrategy will enable the
+      dynamic padder to support dynamic batch sizes for the inputs. Otherwise
+      only static-shape inputs are allowed.
+    experimental_bucketizing_dynamic_shape: Boolean. Only applies to
+      TPUStrategy. Defaults to False. If True, TPUStrategy will automatically
+      bucketize inputs passed into `experimental_run_v2` if the input shape is
+      dynamic. This is a performance optimization to reduce XLA recompilation,
+      which should not have an impact on correctness.
+  """
+
+  def __new__(cls,
+              experimental_enable_dynamic_batch_size=True,
+              experimental_bucketizing_dynamic_shape=False):
+    return super(RunOptions,
+                 cls).__new__(cls, experimental_enable_dynamic_batch_size,
+                              experimental_bucketizing_dynamic_shape)
+
+
 # ------------------------------------------------------------------------------
 # Base classes for all distribution strategies.
@@ -778,7 +808,7 @@ class StrategyBase(object):
     return self._extended._experimental_distribute_datasets_from_function(  # pylint: disable=protected-access
         dataset_fn)
 
-  def experimental_run_v2(self, fn, args=(), kwargs=None):
+  def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
     """Run `fn` on each replica, with the given arguments.
 
     Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
@@ -800,6 +830,8 @@
       fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
       args: (Optional) Positional arguments to `fn`.
       kwargs: (Optional) Keyword arguments to `fn`.
+      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
+        the options to run `fn`.
 
     Returns:
       Merged return value of `fn` across replicas. The structure of the return
@@ -807,6 +839,8 @@
       structure can either be "per-replica" `Tensor` objects or `Tensor`s
       (for example, if running on a single replica).
     """
+    del options
+
     if not isinstance(args, (list, tuple)):
       raise ValueError(
           "positional args must be a list or tuple, got {}".format(type(args)))


@@ -163,7 +163,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
     """
     return super(OneDeviceStrategy, self).experimental_local_results(value)
 
-  def experimental_run_v2(self, fn, args=(), kwargs=None):  # pylint: disable=useless-super-delegation
+  def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
     """Run `fn` on each replica, with the given arguments.
 
     In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
@@ -173,11 +173,14 @@ class OneDeviceStrategy(distribute_lib.Strategy):
       fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
       args: (Optional) Positional arguments to `fn`.
       kwargs: (Optional) Keyword arguments to `fn`.
+      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
+        the options to run `fn`.
 
     Returns:
       Return value from running `fn`.
     """
-    return super(OneDeviceStrategy, self).experimental_run_v2(fn, args, kwargs)
+    return super(OneDeviceStrategy,
+                 self).experimental_run_v2(fn, args, kwargs, options)
 
   def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
     """Reduce `value` across replicas.


@@ -158,14 +158,15 @@ class TPUStrategy(distribute_lib.Strategy):
   # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
   # can use the default implementation.
   # This implementation runs a single step. It does not use infeed or outfeed.
-  def experimental_run_v2(self, fn, args=(), kwargs=None):
+  def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
     """See base class."""
     validate_experimental_run_function(fn)
 
     # Note: the target function is converted to graph even when in Eager mode,
     # so autograph is on by default here.
     fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
-    return self.extended.tpu_run(fn, args, kwargs)
+    options = options or distribute_lib.RunOptions()
+    return self.extended.tpu_run(fn, args, kwargs, options)
 
 
 @tf_export(v1=["distribute.experimental.TPUStrategy"])
@@ -206,12 +207,62 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
   # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
   # can use the default implementation.
   # This implementation runs a single step. It does not use infeed or outfeed.
-  def experimental_run_v2(self, fn, args=(), kwargs=None):
-    """See base class."""
+  def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
+    """Run `fn` on each replica, with the given arguments.
+
+    Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
+    "per-replica" values, such as those produced by a "distributed `Dataset`",
+    when `fn` is executed on a particular replica, it will be executed with the
+    component of those "per-replica" values that correspond to that replica.
+
+    `fn` may call `tf.distribute.get_replica_context()` to access members such
+    as `all_reduce`.
+
+    All arguments in `args` or `kwargs` should either be nests of tensors or
+    per-replica objects containing tensors or composite tensors.
+
+    Users can pass strategy-specific options via the `options` argument. An
+    example of enabling bucketizing of dynamic shapes in
+    `TPUStrategy.experimental_run_v2` is:
+
+    ```python
+    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+    tf.config.experimental_connect_to_cluster(resolver)
+    tf.tpu.experimental.initialize_tpu_system(resolver)
+    strategy = tf.distribute.experimental.TPUStrategy(resolver)
+
+    options = tf.distribute.RunOptions(
+        experimental_bucketizing_dynamic_shape=True)
+
+    iterator = iter(inputs)
+
+    @tf.function()
+    def step_fn(inputs):
+      output = tf.reduce_sum(inputs)
+      return output
+
+    strategy.experimental_run_v2(step_fn, args=(next(iterator),),
+                                 options=options)
+    ```
+
+    Args:
+      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
+      args: (Optional) Positional arguments to `fn`.
+      kwargs: (Optional) Keyword arguments to `fn`.
+      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
+        the options to run `fn`.
+
+    Returns:
+      Merged return value of `fn` across replicas. The structure of the return
+      value is the same as the return value from `fn`. Each element in the
+      structure can either be "per-replica" `Tensor` objects or `Tensor`s
+      (for example, if running on a single replica).
+    """
     validate_experimental_run_function(fn)
 
     fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
-    return self.extended.tpu_run(fn, args, kwargs)
+    options = options or distribute_lib.RunOptions()
+    return self.extended.tpu_run(fn, args, kwargs, options)
 
 
 # TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
@@ -288,7 +339,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     self._retrace_functions_for_each_device = False
 
     self.experimental_enable_get_next_as_optional = True
-    self.experimental_enable_dynamic_batch_size = True
     self._prefetch_on_host = False
 
     self._logical_device_stack = [0]
@@ -801,11 +851,11 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     """
     return True
 
-  def tpu_run(self, fn, args, kwargs):
-    func = self._tpu_function_creator(fn)
+  def tpu_run(self, fn, args, kwargs, options=None):
+    func = self._tpu_function_creator(fn, options)
     return func(args, kwargs)
 
-  def _tpu_function_creator(self, fn):
+  def _tpu_function_creator(self, fn, options):
     if fn in self._tpu_function_cache:
       return self._tpu_function_cache[fn]
 
@@ -844,7 +894,7 @@
       # Construct and pass `maximum_shapes` so that we could support dynamic
      # shapes using dynamic padder.
-      if self.experimental_enable_dynamic_batch_size and replicate_inputs:
+      if options.experimental_enable_dynamic_batch_size and replicate_inputs:
         maximum_shapes = []
         flattened_list = nest.flatten(replicate_inputs[0])
         for input_tensor in flattened_list:
@@ -859,12 +909,18 @@
       else:
         maximum_shapes = None
 
+      if options.experimental_bucketizing_dynamic_shape:
+        padding_spec = tpu.PaddingSpec.POWER_OF_TWO
+      else:
+        padding_spec = None
+
       with strategy.scope():
         replicate_outputs = tpu.replicate(
             replicated_fn,
             replicate_inputs,
             device_assignment=self._device_assignment,
-            maximum_shapes=maximum_shapes)
+            maximum_shapes=maximum_shapes,
+            padding_spec=padding_spec)
 
       # Remove all no ops that may have been added during 'tpu.replicate()'
       if isinstance(result[0], list):
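The scenario motivating these options is a TPU input pipeline whose last batch is smaller than the others. Below is a hedged end-to-end sketch (the TPU resolver setup and the tiny dataset are assumptions for illustration, following the docstring example above): with `experimental_enable_dynamic_batch_size` left at its default, the partial batch keeps a dynamic leading dimension, and with `experimental_bucketizing_dynamic_shape` set, `tpu_run` passes `PaddingSpec.POWER_OF_TWO` to `tpu.replicate` so that varying batch sizes collapse onto a few padded shapes.

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

options = tf.distribute.RunOptions(experimental_bucketizing_dynamic_shape=True)

# Ten examples with batch size 4 -> two full batches and a partial batch of 2.
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10.)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def step(inputs):
  return strategy.experimental_run_v2(
      tf.math.reduce_sum, args=(inputs,), options=options)


for batch in dist_dataset:
  # The partial batch is padded up to a power of two before compilation,
  # limiting the number of distinct XLA programs.
  print(strategy.experimental_local_results(step(batch)))
```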


@@ -20,6 +20,8 @@ from __future__ import division
 from __future__ import print_function
 
 from absl import logging
+import enum
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -784,15 +786,58 @@ def outside_compilation(computation, *args, **kwargs):
   return retval
 
 
+@tf_export(v1=["tpu.PaddingSpec"])
+class PaddingSpec(enum.IntEnum):
+  """Represents the type of padding policies for tpu.replicate."""
+  # By default the policy is set to AUTO: a dynamic input shape dimension will
+  # be padded to the maximum of that dimension over all the replicas.
+  AUTO = 0
+  # Bucketize the dynamic input shape dimension into a power of 2.
+  POWER_OF_TWO = 1
+
+
 @tf_export(v1=["tpu.replicate"])
 def replicate(computation,
               inputs=None,
               infeed_queue=None,
               device_assignment=None,
               name=None,
-              maximum_shapes=None):
+              maximum_shapes=None,
+              padding_spec=None):
   """Builds a graph operator that runs a replicated TPU computation.
 
+  Example of the basic usage, where `inputs` has a static shape:
+
+  ```python
+  def computation(x):
+    x = x + 1
+    return tf.math.reduce_mean(x)
+
+  x = tf.convert_to_tensor([1., 2., 3.])
+  y = tf.convert_to_tensor([4., 5., 6.])
+  tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
+  ```
+
+  If `inputs` has dynamic shapes and you would like to automatically bucketize
+  the inputs to avoid XLA recompilation, see the advanced example below:
+
+  ```python
+  def computation(x):
+    x = x + 1
+    return tf.math.reduce_mean(x)
+
+  # Assume input tensors in two replicas `x` and `y` both have dynamic shape
+  # ([None, 2]).
+  tf.compat.v1.tpu.replicate(
+      computation,
+      inputs=[[x], [y]],
+      maximum_shapes=[tf.TensorShape([None, None])],
+      padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO)
+  ```
+
   Args:
     computation: A Python function that builds the computation to replicate.
     inputs: A list of lists of input tensors or `None` (equivalent to
@@ -818,6 +863,11 @@ def replicate(computation,
       object) will be padded to the maximum size of that dimension over all
       replicas. The structure of `maximum_shapes` needs to be the same as
       `inputs[0]`.
+    padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
+      padding policy when the `inputs` to `tpu.replicate` is dynamic. One usage
+      is to enable automatic bucketizing on the inputs by setting the value to
+      `tpu.PaddingSpec.POWER_OF_TWO`, which can help reduce recompilation on
+      the XLA side.
 
   Returns:
     A list of outputs, indexed by `[replica_num]` each output can be a nested
     structure same as what computation() returns with a few exceptions.
@@ -845,10 +895,21 @@
       infeed_queue,
       device_assignment,
       name,
-      maximum_shapes=maximum_shapes)[1]
+      maximum_shapes=maximum_shapes,
+      padding_spec=padding_spec)[1]
 
 
-def _pad_all_input(inputs, padded_shapes):
+def _ceil_to_pow_of_n(x, n):
+  """Ceil input `x` to power of `n`."""
+  x = math_ops.cast(x, dtypes.float32)
+  lognx = math_ops.log(x) / math_ops.log(n * 1.0)
+  lognx = math_ops.ceil(lognx)
+  result = math_ops.pow(n * 1.0, lognx)
+  result = math_ops.cast(result, dtypes.int32)
+  return result
+
+
+def _pad_all_input(inputs, padded_shapes, padding_spec):
   """Pad all input tensors given padded_shapes.
 
   The real shape tensors will be concatenated with the padded original inputs.
@@ -856,6 +917,11 @@ def _pad_all_input(inputs, padded_shapes):
   Args:
     inputs: The original inputs.
     padded_shapes: A list of padded shapes for each input.
+    padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
+      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic. One
+      usage is to enable automatic bucketizing on the inputs by setting the
+      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help reduce
+      recompilation on the XLA side.
 
   Returns:
     The padded inputs and a PaddingMap list which maps the padded input
@@ -933,6 +999,8 @@
          # among all the cores.
          max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
                                          minimum_dynamic_dim_size)
+          if padding_spec == PaddingSpec.POWER_OF_TWO:
+            max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2)
          # Pad to the given maximum value.
          padding = [0, max_dim_size - input_shape_tensor[i]]
        else:
@@ -972,7 +1040,8 @@ def split_compile_and_replicate(computation,
                                 device_assignment=None,
                                 name=None,
                                 use_tpu=True,
-                                maximum_shapes=None):
+                                maximum_shapes=None,
+                                padding_spec=None):
   """Builds graph operators that runs compilation and replicated computation.
 
   This is a lower level interface than replicate that returns a separate compile
@@ -1009,6 +1078,11 @@
       object) will be padded to the maximum size of that dimension over all
       replicas. The structure of `maximum_shapes` needs to be the same as
       `inputs[0]`.
+    padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the
+      padding policy when the `inputs` to `tf.tpu.replicate` is dynamic. One
+      usage is to enable automatic bucketizing on the inputs by setting the
+      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help reduce
+      recompilation on the XLA side.
 
   Returns:
     A list of lists with the first list corresponding to the compile op and the
@@ -1108,7 +1182,8 @@
         tensor_shape.TensorShape(s) for s in flat_maximum_shapes
     ]
 
-    flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes)
+    flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes,
+                                               padding_spec)
 
     serialized_padding_maps = []
     for padding_map in padding_maps:
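To make the `POWER_OF_TWO` policy concrete, here is a small illustrative script. It mirrors the arithmetic of the private `_ceil_to_pow_of_n` helper in plain Python (the real helper works on tensors via `math_ops`), and the observed sizes are made-up examples; the point is that bucketizing maps many distinct dynamic dimension sizes onto a handful of padded sizes, which is what limits XLA recompilation.

```python
import math


def ceil_to_pow_of_n(x, n=2):
  """Pure-Python mirror of the tensor-based `_ceil_to_pow_of_n` above."""
  return int(n ** math.ceil(math.log(x, n)))


# Dynamic dimension sizes seen across steps (e.g. ragged final batches).
observed_sizes = [3, 5, 6, 7, 9, 12, 17, 31, 33]

# PaddingSpec.AUTO pads only to the per-step maximum across replicas, so each
# distinct size can still trigger its own compilation.
auto_padded = observed_sizes
# PaddingSpec.POWER_OF_TWO additionally rounds that maximum up to a power of 2.
bucketized = [ceil_to_pow_of_n(s) for s in observed_sizes]

print(bucketized)             # [4, 8, 8, 8, 16, 16, 32, 32, 64]
print(len(set(auto_padded)))  # 9 distinct padded shapes
print(len(set(bucketized)))   # 5 distinct padded shapes after bucketizing
```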


@@ -46,7 +46,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "group"


@@ -46,7 +46,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "group"


@@ -0,0 +1,23 @@
+path: "tensorflow.distribute.RunOptions"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
+  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "experimental_bucketizing_dynamic_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "experimental_enable_dynamic_batch_size"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}


@@ -45,7 +45,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "group"


@@ -46,7 +46,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "group"


@@ -46,7 +46,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "group"


@@ -46,7 +46,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "group"


@@ -50,7 +50,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "group"


@@ -40,6 +40,10 @@ tf_module {
     name: "ReplicaContext"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "RunOptions"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "Server"
     mtype: "<type \'type\'>"


@@ -0,0 +1,12 @@
+path: "tensorflow.tpu.PaddingSpec"
+tf_class {
+  is_instance: "<enum \'PaddingSpec\'>"
+  member {
+    name: "AUTO"
+    mtype: "<enum \'PaddingSpec\'>"
+  }
+  member {
+    name: "POWER_OF_TWO"
+    mtype: "<enum \'PaddingSpec\'>"
+  }
+}


@@ -4,6 +4,10 @@ tf_module {
     name: "CrossShardOptimizer"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "PaddingSpec"
+    mtype: "<class \'enum.EnumMeta\'>"
+  }
   member {
     name: "experimental"
     mtype: "<type \'module\'>"
@@ -34,7 +38,7 @@ tf_module {
   }
   member_method {
     name: "replicate"
-    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "rewrite"


@@ -54,7 +54,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "experimental_split_to_logical_devices"


@@ -54,7 +54,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "experimental_split_to_logical_devices"


@@ -0,0 +1,23 @@
+path: "tensorflow.distribute.RunOptions"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
+  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
+  is_instance: "<type \'tuple\'>"
+  member {
+    name: "experimental_bucketizing_dynamic_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "experimental_enable_dynamic_batch_size"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "count"
+  }
+  member_method {
+    name: "index"
+  }
+}


@@ -53,7 +53,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "experimental_split_to_logical_devices"


@@ -54,7 +54,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
  }
   member_method {
     name: "experimental_split_to_logical_devices"


@@ -54,7 +54,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "experimental_split_to_logical_devices"


@@ -54,7 +54,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "experimental_split_to_logical_devices"


@@ -54,7 +54,7 @@ tf_class {
   }
   member_method {
     name: "experimental_run_v2"
-    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
   }
   member_method {
     name: "experimental_split_to_logical_devices"


@@ -40,6 +40,10 @@ tf_module {
     name: "ReplicaContext"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "RunOptions"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "Server"
     mtype: "<type \'type\'>"


@@ -1328,6 +1328,8 @@ renames = {
         'tf.compat.v1.to_int64',
     'tf.tpu.CrossShardOptimizer':
         'tf.compat.v1.tpu.CrossShardOptimizer',
+    'tf.tpu.PaddingSpec':
+        'tf.compat.v1.tpu.PaddingSpec',
     'tf.tpu.batch_parallel':
         'tf.compat.v1.tpu.batch_parallel',
     'tf.tpu.bfloat16_scope':
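For completeness, the new `renames` entry means the TF 2 upgrade script rewrites the v1 symbol automatically. A small sanity check, assuming a TensorFlow build that contains this change, is that the compat spelling the script produces resolves to the enum added above:

```python
import tensorflow as tf

# tf_upgrade_v2 rewrites `tf.tpu.PaddingSpec` in v1 code to the compat path
# `tf.compat.v1.tpu.PaddingSpec`; in a 2.x install only the compat name is
# exported, and it carries the two policies added in this change.
spec = tf.compat.v1.tpu.PaddingSpec
print(list(spec))               # [<PaddingSpec.AUTO: 0>, <PaddingSpec.POWER_OF_TWO: 1>]
print(spec.POWER_OF_TWO.value)  # 1
```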