Add a padding_spec argument into tpu.replicate, which provides the functionality of auto bucketing inputs.

PiperOrigin-RevId: 295226850
Change-Id: I486a82739bd7d2629b218371f5b50f1e37cb4b2d
This commit is contained in:
Ruoxin Sang 2020-02-14 14:25:12 -08:00 committed by TensorFlower Gardener
parent da910d3f2d
commit c47458b7ba
27 changed files with 305 additions and 37 deletions

View File

@ -161,7 +161,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
"""
return super(CentralStorageStrategy, self).experimental_local_results(value)
def experimental_run_v2(self, fn, args=(), kwargs=None): # pylint: disable=useless-super-delegation
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None): # pylint: disable=useless-super-delegation
"""Run `fn` on each replica, with the given arguments.
In `CentralStorageStrategy`, `fn` is called on each of the compute
@ -171,12 +171,14 @@ class CentralStorageStrategy(distribute_lib.Strategy):
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
kwargs: (Optional) Keyword arguments to `fn`.
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
the options to run `fn`.
Returns:
Return value from running `fn`.
"""
return super(CentralStorageStrategy, self).experimental_run_v2(fn, args,
kwargs)
return super(CentralStorageStrategy,
self).experimental_run_v2(fn, args, kwargs, options)
def reduce(self, reduce_op, value, axis): # pylint: disable=useless-super-delegation
"""Reduce `value` across replicas.

View File

@ -23,6 +23,7 @@ from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
@ -333,6 +334,31 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
# This assumes that there are exactly 2 replicas
self.assertAllEqual([5.5, 7.], run(input_iterator))
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.multidevice_strategies,
mode=["eager"]))
def testDynamicShapesWithRunOptions(self, distribution):
dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
options = distribute_lib.RunOptions
options.experimental_bucketizing_dynamic_shape = True
@def_function.function
def run(iterator):
def computation(x):
return math_ops.reduce_mean(x)
inputs = next(iterator)
outputs = distribution.experimental_local_results(
distribution.experimental_run_v2(
computation, args=(inputs,), options=options))
return outputs
# This assumes that there are exactly 2 replicas
self.assertAllEqual([5.5, 7.], run(input_iterator))
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.multidevice_strategies,

View File

@ -95,6 +95,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import enum # pylint: disable=g-bad-import-order
import threading
@ -411,6 +412,35 @@ class InputContext(object):
self.input_pipeline_id, self.num_input_pipelines)
@tf_export("distribute.RunOptions")
class RunOptions(
collections.namedtuple("RunOptions", [
"experimental_enable_dynamic_batch_size",
"experimental_bucketizing_dynamic_shape",
])):
"""Run options for `strategy.experimental_run_v2`.
This can be used to hold some strategy specific configs.
Attributes:
experimental_enable_dynamic_batch_size: Boolean. Only applies to
TPUStrategy. Default to False. If True, TPUStrategy will enable dynamic
padder to support dynamic batch size for the inputs. Otherwise only static
shape inputs are allowed.
experimental_bucketizing_dynamic_shape: Boolean. Only applies to
TPUStrategy. Default to False. If True, TPUStrategy will automatic
bucketize inputs passed into `experimental_run_v2` if the input shape is
dynamic. This is a performance optimization to reduce XLA recompilation,
which should not have impact on correctness.
"""
def __new__(cls,
experimental_enable_dynamic_batch_size=True,
experimental_bucketizing_dynamic_shape=False):
return super(RunOptions,
cls).__new__(cls, experimental_enable_dynamic_batch_size,
experimental_bucketizing_dynamic_shape)
# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.
@ -778,7 +808,7 @@ class StrategyBase(object):
return self._extended._experimental_distribute_datasets_from_function( # pylint: disable=protected-access
dataset_fn)
def experimental_run_v2(self, fn, args=(), kwargs=None):
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
"""Run `fn` on each replica, with the given arguments.
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
@ -800,6 +830,8 @@ class StrategyBase(object):
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
kwargs: (Optional) Keyword arguments to `fn`.
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
the options to run `fn`.
Returns:
Merged return value of `fn` across replicas. The structure of the return
@ -807,6 +839,8 @@ class StrategyBase(object):
structure can either be "per-replica" `Tensor` objects or `Tensor`s
(for example, if running on a single replica).
"""
del options
if not isinstance(args, (list, tuple)):
raise ValueError(
"positional args must be a list or tuple, got {}".format(type(args)))

View File

@ -163,7 +163,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
"""
return super(OneDeviceStrategy, self).experimental_local_results(value)
def experimental_run_v2(self, fn, args=(), kwargs=None): # pylint: disable=useless-super-delegation
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None): # pylint: disable=useless-super-delegation
"""Run `fn` on each replica, with the given arguments.
In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
@ -173,11 +173,14 @@ class OneDeviceStrategy(distribute_lib.Strategy):
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
kwargs: (Optional) Keyword arguments to `fn`.
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
the options to run `fn`.
Returns:
Return value from running `fn`.
"""
return super(OneDeviceStrategy, self).experimental_run_v2(fn, args, kwargs)
return super(OneDeviceStrategy,
self).experimental_run_v2(fn, args, kwargs, options)
def reduce(self, reduce_op, value, axis): # pylint: disable=useless-super-delegation
"""Reduce `value` across replicas.

View File

@ -158,14 +158,15 @@ class TPUStrategy(distribute_lib.Strategy):
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
# can use the default implementation.
# This implementation runs a single step. It does not use infeed or outfeed.
def experimental_run_v2(self, fn, args=(), kwargs=None):
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
"""See base class."""
validate_experimental_run_function(fn)
# Note: the target function is converted to graph even when in Eager mode,
# so autograph is on by default here.
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
return self.extended.tpu_run(fn, args, kwargs)
options = options or distribute_lib.RunOptions()
return self.extended.tpu_run(fn, args, kwargs, options)
@tf_export(v1=["distribute.experimental.TPUStrategy"])
@ -206,12 +207,62 @@ class TPUStrategyV1(distribute_lib.StrategyV1):
# TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
# can use the default implementation.
# This implementation runs a single step. It does not use infeed or outfeed.
def experimental_run_v2(self, fn, args=(), kwargs=None):
"""See base class."""
def experimental_run_v2(self, fn, args=(), kwargs=None, options=None):
"""Run `fn` on each replica, with the given arguments.
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
"per-replica" values, such as those produced by a "distributed `Dataset`",
when `fn` is executed on a particular replica, it will be executed with the
component of those "per-replica" values that correspond to that replica.
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `all_reduce`.
All arguments in `args` or `kwargs` should either be nest of tensors or
per-replica objects containing tensors or composite tensors.
Users can pass strategy specific options to `options` argument. An example
to enable bucketizing dynamic shapes in `TPUStrategy.experimental_run_v2`
is:
```python
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(tpu='')
options = tf.distribute.RunOptions()
options.experimental_bucketizing_dynamic_shape = True
iterator = iter(inputs)
@tf.function()
def step_fn(inputs):
output = tf.reduce_sum(inputs)
return output
strategy.experimental_run_v2(step_fn, args=(next(iterator),),
options=options)
```
Args:
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
kwargs: (Optional) Keyword arguments to `fn`.
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
the options to run `fn`.
Returns:
Merged return value of `fn` across replicas. The structure of the return
value is the same as the return value from `fn`. Each element in the
structure can either be "per-replica" `Tensor` objects or `Tensor`s
(for example, if running on a single replica).
"""
validate_experimental_run_function(fn)
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
return self.extended.tpu_run(fn, args, kwargs)
options = options or distribute_lib.RunOptions()
return self.extended.tpu_run(fn, args, kwargs, options)
# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
@ -288,7 +339,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
self._retrace_functions_for_each_device = False
self.experimental_enable_get_next_as_optional = True
self.experimental_enable_dynamic_batch_size = True
self._prefetch_on_host = False
self._logical_device_stack = [0]
@ -801,11 +851,11 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
"""
return True
def tpu_run(self, fn, args, kwargs):
func = self._tpu_function_creator(fn)
def tpu_run(self, fn, args, kwargs, options=None):
func = self._tpu_function_creator(fn, options)
return func(args, kwargs)
def _tpu_function_creator(self, fn):
def _tpu_function_creator(self, fn, options):
if fn in self._tpu_function_cache:
return self._tpu_function_cache[fn]
@ -844,7 +894,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
# Construct and pass `maximum_shapes` so that we could support dynamic
# shapes using dynamic padder.
if self.experimental_enable_dynamic_batch_size and replicate_inputs:
if options.experimental_enable_dynamic_batch_size and replicate_inputs:
maximum_shapes = []
flattened_list = nest.flatten(replicate_inputs[0])
for input_tensor in flattened_list:
@ -859,12 +909,18 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
else:
maximum_shapes = None
if options.experimental_bucketizing_dynamic_shape:
padding_spec = tpu.PaddingSpec.POWER_OF_TWO
else:
padding_spec = None
with strategy.scope():
replicate_outputs = tpu.replicate(
replicated_fn,
replicate_inputs,
device_assignment=self._device_assignment,
maximum_shapes=maximum_shapes)
maximum_shapes=maximum_shapes,
padding_spec=padding_spec)
# Remove all no ops that may have been added during 'tpu.replicate()'
if isinstance(result[0], list):

View File

@ -20,6 +20,8 @@ from __future__ import division
from __future__ import print_function
from absl import logging
import enum
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@ -784,15 +786,58 @@ def outside_compilation(computation, *args, **kwargs):
return retval
@tf_export(v1=["tpu.PaddingSpec"])
class PaddingSpec(enum.IntEnum):
"""Represents the type of padding policies for tpu.replicate."""
# By default the policy is set to AUTO, the dynamic input shape dimension will
# be pad to maximum of all the replicas.
AUTO = 0
# Bucketize the dynamic input shape dimension into a power of 2.
POWER_OF_TWO = 1
@tf_export(v1=["tpu.replicate"])
def replicate(computation,
inputs=None,
infeed_queue=None,
device_assignment=None,
name=None,
maximum_shapes=None):
maximum_shapes=None,
padding_spec=None):
"""Builds a graph operator that runs a replicated TPU computation.
Example for the basic usage that `inputs` has static shape:
```python
def computation(x):
x = x + 1
return tf.math.reduce_mean(x)
x = tf.convert_to_tensor([1., 2., 3.])
y = tf.convert_to_tensor([4., 5., 6.])
tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
```
If the `inputs` has dynamic shapes and you would like to automatically
bucketize the inputs to avoid XLA recompilation. See the advanced example
below:
```python
def computation(x):
x = x + 1
return tf.math.reduce_mean(x)
# Assume input tensors in two replicas `x` and `y` both have dynamic shape
# ([None, 2]).
tf.compat.v1.tpu.replicate(
computation,
inputs=[x, y],
maximum_shapes=[tf.TensorShape([None, None])],
padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO)
```
Args:
computation: A Python function that builds the computation to replicate.
inputs: A list of lists of input tensors or `None` (equivalent to
@ -818,6 +863,11 @@ def replicate(computation,
object) will be padded to the maximum size of that dimension over all
replicas. The structure of `maximum_shapes` needs to be the same as
`inputs[0]`.
padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
padding policy when the `inputs` to `tpu.replicate` is dynamic.
One usage is to enable automatic bucketizing on the inputs by setting the
value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
recompilation in the XLA side.
Returns:
A list of outputs, indexed by `[replica_num]` each output can be a nested
structure same as what computation() returns with a few exceptions.
@ -845,10 +895,21 @@ def replicate(computation,
infeed_queue,
device_assignment,
name,
maximum_shapes=maximum_shapes)[1]
maximum_shapes=maximum_shapes,
padding_spec=padding_spec)[1]
def _pad_all_input(inputs, padded_shapes):
def _ceil_to_pow_of_n(x, n):
"""Ceil input `x` to power of `n`."""
x = math_ops.cast(x, dtypes.float32)
lognx = math_ops.log(x) / math_ops.log(n * 1.0)
lognx = math_ops.ceil(lognx)
result = math_ops.pow(n * 1.0, lognx)
result = math_ops.cast(result, dtypes.int32)
return result
def _pad_all_input(inputs, padded_shapes, padding_spec):
"""Pad all input tensors given padded_shapes.
The real shape tensors will be concatenated with the padded original inputs.
@ -856,6 +917,11 @@ def _pad_all_input(inputs, padded_shapes):
Args:
inputs: The original inputs.
padded_shapes: A list of padded shapes for each input.
padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
One usage is to enable automatic bucketizing on the inputs by setting the
value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
recompilation in the XLA side.
Returns:
The padded inputs and a PaddingMap list which maps the padded input
@ -933,6 +999,8 @@ def _pad_all_input(inputs, padded_shapes):
# among all the cores.
max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
minimum_dynamic_dim_size)
if padding_spec == PaddingSpec.POWER_OF_TWO:
max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2)
# Pad to the given maximum value.
padding = [0, max_dim_size - input_shape_tensor[i]]
else:
@ -972,7 +1040,8 @@ def split_compile_and_replicate(computation,
device_assignment=None,
name=None,
use_tpu=True,
maximum_shapes=None):
maximum_shapes=None,
padding_spec=None):
"""Builds graph operators that runs compilation and replicated computation.
This is a lower level interface than replicate that returns a separate compile
@ -1009,6 +1078,11 @@ def split_compile_and_replicate(computation,
object) will be padded to the maximum size of that dimension over all
replicas. The structure of `maximum_shapes` needs to be the same as
`inputs[0]`.
padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the
padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
One usage is to enable automatic bucketizing on the inputs by setting the
value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help to reduce the
recompilation in the XLA side.
Returns:
A list of lists with the first list corresponding to the compile op and the
@ -1108,7 +1182,8 @@ def split_compile_and_replicate(computation,
tensor_shape.TensorShape(s) for s in flat_maximum_shapes
]
flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes)
flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes,
padding_spec)
serialized_padding_maps = []
for padding_map in padding_maps:

View File

@ -46,7 +46,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "group"

View File

@ -46,7 +46,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "group"

View File

@ -0,0 +1,23 @@
path: "tensorflow.distribute.RunOptions"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
is_instance: "<type \'tuple\'>"
member {
name: "experimental_bucketizing_dynamic_shape"
mtype: "<type \'property\'>"
}
member {
name: "experimental_enable_dynamic_batch_size"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -45,7 +45,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "group"

View File

@ -46,7 +46,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "group"

View File

@ -46,7 +46,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "group"

View File

@ -46,7 +46,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "group"

View File

@ -50,7 +50,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "group"

View File

@ -40,6 +40,10 @@ tf_module {
name: "ReplicaContext"
mtype: "<type \'type\'>"
}
member {
name: "RunOptions"
mtype: "<type \'type\'>"
}
member {
name: "Server"
mtype: "<type \'type\'>"

View File

@ -0,0 +1,12 @@
path: "tensorflow.tpu.PaddingSpec"
tf_class {
is_instance: "<enum \'PaddingSpec\'>"
member {
name: "AUTO"
mtype: "<enum \'PaddingSpec\'>"
}
member {
name: "POWER_OF_TWO"
mtype: "<enum \'PaddingSpec\'>"
}
}

View File

@ -4,6 +4,10 @@ tf_module {
name: "CrossShardOptimizer"
mtype: "<type \'type\'>"
}
member {
name: "PaddingSpec"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
@ -34,7 +38,7 @@ tf_module {
}
member_method {
name: "replicate"
argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'computation\', \'inputs\', \'infeed_queue\', \'device_assignment\', \'name\', \'maximum_shapes\', \'padding_spec\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "rewrite"

View File

@ -54,7 +54,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "experimental_split_to_logical_devices"

View File

@ -54,7 +54,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "experimental_split_to_logical_devices"

View File

@ -0,0 +1,23 @@
path: "tensorflow.distribute.RunOptions"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.RunOptions\'>"
is_instance: "<type \'tuple\'>"
member {
name: "experimental_bucketizing_dynamic_shape"
mtype: "<type \'property\'>"
}
member {
name: "experimental_enable_dynamic_batch_size"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -53,7 +53,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "experimental_split_to_logical_devices"

View File

@ -54,7 +54,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "experimental_split_to_logical_devices"

View File

@ -54,7 +54,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "experimental_split_to_logical_devices"

View File

@ -54,7 +54,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "experimental_split_to_logical_devices"

View File

@ -54,7 +54,7 @@ tf_class {
}
member_method {
name: "experimental_run_v2"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "experimental_split_to_logical_devices"

View File

@ -40,6 +40,10 @@ tf_module {
name: "ReplicaContext"
mtype: "<type \'type\'>"
}
member {
name: "RunOptions"
mtype: "<type \'type\'>"
}
member {
name: "Server"
mtype: "<type \'type\'>"

View File

@ -1328,6 +1328,8 @@ renames = {
'tf.compat.v1.to_int64',
'tf.tpu.CrossShardOptimizer':
'tf.compat.v1.tpu.CrossShardOptimizer',
'tf.tpu.PaddingSpec':
'tf.compat.v1.tpu.PaddingSpec',
'tf.tpu.batch_parallel':
'tf.compat.v1.tpu.batch_parallel',
'tf.tpu.bfloat16_scope':