Add API to enable spatial partition in TPUStrategy.
PiperOrigin-RevId: 293454256 Change-Id: I70e23d94f2fc0ebb10acb7124fbb20e4f72d11fc
This commit is contained in:
parent
8a9c9b6af7
commit
f95a6caa8b
@ -532,6 +532,7 @@ py_library(
|
||||
":numpy_dataset",
|
||||
":reduce_util",
|
||||
":values",
|
||||
"//tensorflow/compiler/xla/experimental/xla_sharding",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
|
@ -415,9 +415,10 @@ class InputContext(object):
|
||||
# Base classes for all distribution strategies.
|
||||
|
||||
|
||||
# Base class for v1 Strategy and v2 Strategy classes. For API's specific to
|
||||
# v1/v2 Strategy, add to implementing classes of StrategyBase.
|
||||
# pylint: disable=line-too-long
|
||||
@tf_export("distribute.Strategy", v1=[])
|
||||
class Strategy(object):
|
||||
class StrategyBase(object):
|
||||
"""A state & compute distribution policy on a list of devices.
|
||||
|
||||
See [the guide](https://www.tensorflow.org/guide/distributed_training)
|
||||
@ -1008,9 +1009,176 @@ class Strategy(object):
|
||||
raise RuntimeError("Must only deepcopy DistributionStrategy.")
|
||||
|
||||
|
||||
@tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring
|
||||
class Strategy(StrategyBase):
|
||||
|
||||
__doc__ = StrategyBase.__doc__
|
||||
|
||||
def experimental_assign_to_logical_device(self, tensor, logical_device_id):
|
||||
"""Adds annotation that `tensor` will be assigned to a logical device.
|
||||
|
||||
NOTE: This API is only supported in TPUStrategy for now.
|
||||
This adds an annotation to `tensor` specifying that operations on
|
||||
`tensor` will be invoked on logical core device id `logical_device_id`.
|
||||
When model parallelism is used, the default behavior is that all ops
|
||||
are placed on zero-th logical device.
|
||||
|
||||
```python
|
||||
|
||||
# Initializing TPU system with 2 logical devices and 4 replicas.
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
||||
topology,
|
||||
computation_shape=[1, 1, 2],
|
||||
num_replicas=4)
|
||||
strategy = tf.distribute.experimental.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment)
|
||||
iterator = iter(inputs)
|
||||
|
||||
@tf.function()
|
||||
def step_fn(inputs):
|
||||
output = tf.add(inputs, inputs)
|
||||
|
||||
// Add operation will be executed on logical device 0.
|
||||
output = strategy.experimental_assign_to_logical_device(output, 0)
|
||||
return output
|
||||
|
||||
strategy.experimental_run_v2(step_fn, args=(next(iterator),))
|
||||
```
|
||||
|
||||
Args:
|
||||
tensor: Input tensor to annotate.
|
||||
logical_device_id: Id of the logical core to which the tensor will be
|
||||
assigned.
|
||||
|
||||
Raises:
|
||||
ValueError: The logical device id presented is not consistent with total
|
||||
number of partitions specified by the device assignment.
|
||||
|
||||
Returns:
|
||||
Annotated tensor with idential value as `tensor`.
|
||||
"""
|
||||
return self._extended._experimental_assign_to_logical_device( # pylint: disable=protected-access
|
||||
tensor, logical_device_id)
|
||||
|
||||
def experimental_split_to_logical_devices(self, tensor, partition_dimensions):
|
||||
"""Adds annotation that `tensor` will be split across logical devices.
|
||||
|
||||
NOTE: This API is only supported in TPUStrategy for now.
|
||||
This adds an annotation to tensor `tensor` specifying that operations on
|
||||
`tensor` will be be split among multiple logical devices. Tensor `tensor`
|
||||
will be split across dimensions specified by `partition_dimensions`.
|
||||
The dimensions of `tensor` must be divisible by corresponding value in
|
||||
`partition_dimensions`.
|
||||
|
||||
For example, for system with 8 logical devices, if `tensor` is an image
|
||||
tensor with shape (batch_size, width, height, channel) and
|
||||
`partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split
|
||||
2 in width dimension and 4 way in height dimension and the split
|
||||
tensor values will be fed into 8 logical devices.
|
||||
|
||||
```python
|
||||
# Initializing TPU system with 8 logical devices and 1 replica.
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
||||
topology,
|
||||
computation_shape=[2, 2, 2],
|
||||
num_replicas=1)
|
||||
strategy = tf.distribute.experimental.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment)
|
||||
|
||||
iterator = iter(inputs)
|
||||
|
||||
@tf.function()
|
||||
def step_fn(inputs):
|
||||
inputs = strategy.experimental_split_to_logical_devices(
|
||||
inputs, [1, 2, 4, 1])
|
||||
|
||||
// model() function will be executed on 8 logical devices with `inputs`
|
||||
// split 2 * 4 ways.
|
||||
output = model(inputs)
|
||||
return output
|
||||
|
||||
strategy.experimental_run_v2(step_fn, args=(next(iterator),))
|
||||
```
|
||||
Args:
|
||||
tensor: Input tensor to annotate.
|
||||
partition_dimensions: An unnested list of integers with the size equal to
|
||||
rank of `tensor` specifying how `tensor` will be partitioned. The
|
||||
product of all elements in `partition_dimensions` must be equal to the
|
||||
total number of logical devices per replica.
|
||||
|
||||
Raises:
|
||||
ValueError: 1) If the size of partition_dimensions does not equal to rank
|
||||
of `tensor` or 2) if product of elements of `partition_dimensions` does
|
||||
not match the number of logical devices per replica defined by the
|
||||
implementing DistributionStrategy's device specification or
|
||||
3) if a known size of `tensor` is not divisible by corresponding
|
||||
value in `partition_dimensions`.
|
||||
|
||||
Returns:
|
||||
Annotated tensor with idential value as `tensor`.
|
||||
"""
|
||||
return self._extended._experimental_split_to_logical_devices( # pylint: disable=protected-access
|
||||
tensor, partition_dimensions)
|
||||
|
||||
def experimental_replicate_to_logical_devices(self, tensor):
|
||||
"""Adds annotation that `tensor` will be replicated to all logical devices.
|
||||
|
||||
NOTE: This API is only supported in TPUStrategy for now.
|
||||
This adds an annotation to tensor `tensor` specifying that operations on
|
||||
`tensor` will be invoked on all logical devices.
|
||||
|
||||
```python
|
||||
# Initializing TPU system with 2 logical devices and 4 replicas.
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
|
||||
topology,
|
||||
computation_shape=[1, 1, 2],
|
||||
num_replicas=4)
|
||||
strategy = tf.distribute.experimental.TPUStrategy(
|
||||
resolver, device_assignment=device_assignment)
|
||||
|
||||
iterator = iter(inputs)
|
||||
|
||||
@tf.function()
|
||||
def step_fn(inputs):
|
||||
images, labels = inputs
|
||||
images = strategy.experimental_split_to_logical_devices(
|
||||
inputs, [1, 2, 4, 1])
|
||||
|
||||
// model() function will be executed on 8 logical devices with `inputs`
|
||||
// split 2 * 4 ways.
|
||||
output = model(inputs)
|
||||
|
||||
// For loss calculation, all logical devices share the same logits
|
||||
// and labels.
|
||||
labels = strategy.experimental_replicate_to_logical_devices(labels)
|
||||
output = strategy.experimental_replicate_to_logical_devices(output)
|
||||
loss = loss_fn(labels, output)
|
||||
|
||||
return loss
|
||||
|
||||
strategy.experimental_run_v2(step_fn, args=(next(iterator),))
|
||||
```
|
||||
Args:
|
||||
tensor: Input tensor to annotate.
|
||||
|
||||
Returns:
|
||||
Annotated tensor with idential value as `tensor`.
|
||||
"""
|
||||
return self._extended._experimental_replicate_to_logical_devices(tensor) # pylint: disable=protected-access
|
||||
|
||||
|
||||
# TF v1.x version has additional deprecated APIs
|
||||
@tf_export(v1=["distribute.Strategy"])
|
||||
class StrategyV1(Strategy):
|
||||
class StrategyV1(StrategyBase):
|
||||
"""A list of devices with a state & compute distribution policy.
|
||||
|
||||
See [the guide](https://www.tensorflow.org/guide/distribute_strategy)
|
||||
@ -1158,7 +1326,7 @@ class StrategyV1(Strategy):
|
||||
def reduce(self, reduce_op, value, axis=None):
|
||||
return super(StrategyV1, self).reduce(reduce_op, value, axis)
|
||||
|
||||
reduce.__doc__ = Strategy.reduce.__doc__
|
||||
reduce.__doc__ = StrategyBase.reduce.__doc__
|
||||
|
||||
def update_config_proto(self, config_proto):
|
||||
"""Returns a copy of `config_proto` modified for use with this strategy.
|
||||
@ -1487,6 +1655,19 @@ class StrategyExtendedV2(object):
|
||||
"""Validate `colocate_with_variable` argument to `colocate_vars_with`."""
|
||||
pass
|
||||
|
||||
def _experimental_assign_to_logical_device(self, tensor, logical_device_id):
|
||||
raise NotImplementedError("This method should be overriden by "
|
||||
"sub-classes which support model parallelism.")
|
||||
|
||||
def _experimental_split_to_logical_devices(self, tensor,
|
||||
partition_dimensions):
|
||||
raise NotImplementedError("This method should be overriden by "
|
||||
"sub-classes which support model parallelism.")
|
||||
|
||||
def _experimental_replicate_to_logical_devices(self, tensor):
|
||||
raise NotImplementedError("This method should be overriden by "
|
||||
"sub-classes which support model parallelism.")
|
||||
|
||||
def _make_dataset_iterator(self, dataset):
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
|
@ -25,6 +25,7 @@ import weakref
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
|
||||
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
|
||||
from tensorflow.python.autograph.impl import api as autograph
|
||||
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
|
||||
@ -500,6 +501,56 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
finally:
|
||||
self._logical_device_stack.pop()
|
||||
|
||||
def _experimental_assign_to_logical_device(self, tensor, logical_device_id):
|
||||
"""See `DistributionStrategy.experimental_assign_to_logical_device`."""
|
||||
num_logical_devices_per_replica = self._tpu_devices.shape[1]
|
||||
if (logical_device_id < 0 or
|
||||
logical_device_id >= num_logical_devices_per_replica):
|
||||
raise ValueError("`logical_core_id` to assign must be lower then total "
|
||||
"number of logical devices per replica. Received "
|
||||
"logical device id {} but there are only total of {} "
|
||||
"logical devices in replica.".format(
|
||||
logical_device_id, num_logical_devices_per_replica))
|
||||
return xla_sharding.assign_device(tensor, logical_device_id)
|
||||
|
||||
def _experimental_split_to_logical_devices(self, tensor,
|
||||
partition_dimensions):
|
||||
"""See `DistributionStrategy.experimental_split_to_logical_devices`."""
|
||||
num_logical_devices_per_replica = self._tpu_devices.shape[1]
|
||||
num_partition_splits = np.prod(partition_dimensions)
|
||||
input_shape = tensor.shape
|
||||
tensor_rank = len(input_shape)
|
||||
|
||||
if tensor_rank != len(partition_dimensions):
|
||||
raise ValueError("Length of `partition_dimensions` ({}) must be "
|
||||
"equal to the rank of `x` ({}).".format(
|
||||
len(partition_dimensions), tensor_rank))
|
||||
|
||||
for dim_index, dim_size in enumerate(input_shape):
|
||||
if dim_size is None:
|
||||
continue
|
||||
|
||||
split_size = partition_dimensions[dim_index]
|
||||
if dim_size % split_size != 0:
|
||||
raise ValueError("Tensor shape at dimension ({}) must be "
|
||||
"divisible by corresponding value specified "
|
||||
"by `partition_dimensions` ({}).".format(
|
||||
dim_index, split_size))
|
||||
|
||||
if num_partition_splits != num_logical_devices_per_replica:
|
||||
raise ValueError("Number of logical devices ({}) does not match the "
|
||||
"number of partition splits specified ({}).".format(
|
||||
num_logical_devices_per_replica,
|
||||
num_partition_splits))
|
||||
|
||||
tile_assignment = np.arange(num_partition_splits).reshape(
|
||||
partition_dimensions)
|
||||
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
|
||||
|
||||
def _experimental_replicate_to_logical_devices(self, tensor):
|
||||
"""See `DistributionStrategy.experimental_replicate_to_logical_devices`."""
|
||||
return xla_sharding.replicate(tensor, use_sharding_op=True)
|
||||
|
||||
def _experimental_initialize_system(self):
|
||||
"""Experimental method added to be used by Estimator.
|
||||
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.distribute.MirroredStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.mirrored_strategy.MirroredStrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.distribute.OneDeviceStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.one_device_strategy.OneDeviceStrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
|
@ -1,7 +1,7 @@
|
||||
path: "tensorflow.distribute.Strategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.distribute.experimental.CentralStorageStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.central_storage_strategy.CentralStorageStrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.distribute.experimental.MultiWorkerMirroredStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.distribute.experimental.ParameterServerStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy.ParameterServerStrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.distribute.experimental.TPUStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.tpu_strategy.TPUStrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
|
@ -2,6 +2,7 @@ path: "tensorflow.distribute.MirroredStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.mirrored_strategy.MirroredStrategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
@ -23,6 +24,10 @@ tf_class {
|
||||
name: "configure"
|
||||
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_assign_to_logical_device"
|
||||
argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_dataset"
|
||||
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -39,6 +44,10 @@ tf_class {
|
||||
name: "experimental_make_numpy_dataset"
|
||||
argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_replicate_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -47,6 +56,10 @@ tf_class {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -2,6 +2,7 @@ path: "tensorflow.distribute.OneDeviceStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.one_device_strategy.OneDeviceStrategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
@ -23,6 +24,10 @@ tf_class {
|
||||
name: "configure"
|
||||
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_assign_to_logical_device"
|
||||
argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_dataset"
|
||||
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -39,6 +44,10 @@ tf_class {
|
||||
name: "experimental_make_numpy_dataset"
|
||||
argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_replicate_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -47,6 +56,10 @@ tf_class {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -1,6 +1,7 @@
|
||||
path: "tensorflow.distribute.Strategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
@ -22,6 +23,10 @@ tf_class {
|
||||
name: "configure"
|
||||
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_assign_to_logical_device"
|
||||
argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_dataset"
|
||||
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -38,6 +43,10 @@ tf_class {
|
||||
name: "experimental_make_numpy_dataset"
|
||||
argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_replicate_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -46,6 +55,10 @@ tf_class {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -2,6 +2,7 @@ path: "tensorflow.distribute.experimental.CentralStorageStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.central_storage_strategy.CentralStorageStrategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
@ -23,6 +24,10 @@ tf_class {
|
||||
name: "configure"
|
||||
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_assign_to_logical_device"
|
||||
argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_dataset"
|
||||
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -39,6 +44,10 @@ tf_class {
|
||||
name: "experimental_make_numpy_dataset"
|
||||
argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_replicate_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -47,6 +56,10 @@ tf_class {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -2,6 +2,7 @@ path: "tensorflow.distribute.experimental.MultiWorkerMirroredStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
@ -23,6 +24,10 @@ tf_class {
|
||||
name: "configure"
|
||||
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_assign_to_logical_device"
|
||||
argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_dataset"
|
||||
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -39,6 +44,10 @@ tf_class {
|
||||
name: "experimental_make_numpy_dataset"
|
||||
argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_replicate_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -47,6 +56,10 @@ tf_class {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -2,6 +2,7 @@ path: "tensorflow.distribute.experimental.ParameterServerStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.parameter_server_strategy.ParameterServerStrategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
@ -23,6 +24,10 @@ tf_class {
|
||||
name: "configure"
|
||||
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_assign_to_logical_device"
|
||||
argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_dataset"
|
||||
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -39,6 +44,10 @@ tf_class {
|
||||
name: "experimental_make_numpy_dataset"
|
||||
argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_replicate_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -47,6 +56,10 @@ tf_class {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -2,6 +2,7 @@ path: "tensorflow.distribute.experimental.TPUStrategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.tpu_strategy.TPUStrategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "extended"
|
||||
@ -23,6 +24,10 @@ tf_class {
|
||||
name: "configure"
|
||||
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_assign_to_logical_device"
|
||||
argspec: "args=[\'self\', \'tensor\', \'logical_device_id\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_dataset"
|
||||
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -39,6 +44,10 @@ tf_class {
|
||||
name: "experimental_make_numpy_dataset"
|
||||
argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_replicate_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run"
|
||||
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -47,6 +56,10 @@ tf_class {
|
||||
name: "experimental_run_v2"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_split_to_logical_devices"
|
||||
argspec: "args=[\'self\', \'tensor\', \'partition_dimensions\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user