Move distribute.py, distribution_strategy_context.py, and device_util.py

from training/ to distribute/.

PiperOrigin-RevId: 222761376
This commit is contained in:
A. Unique TensorFlower 2018-11-25 18:01:45 -08:00 committed by TensorFlower Gardener
parent b8ac6cb249
commit e14e62133c
36 changed files with 2118 additions and 2023 deletions

View File

@ -27,13 +27,13 @@ cuda_py_test(
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op", "//tensorflow/python:constant_op",
"//tensorflow/python:device_util",
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:training", "//tensorflow/python:training",
"//tensorflow/python:variable_scope", "//tensorflow/python:variable_scope",
"//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:device_util",
"//tensorflow/python/distribute:values", "//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",
@ -49,7 +49,7 @@ py_library(
srcs = ["mirrored_strategy.py"], srcs = ["mirrored_strategy.py"],
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
"//tensorflow/python:distribute", "//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:mirrored_strategy", "//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:values", "//tensorflow/python/distribute:values",
], ],
@ -114,10 +114,10 @@ py_library(
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:distribute",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values", "//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
@ -156,11 +156,11 @@ py_library(
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op", "//tensorflow/python:constant_op",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:layers", "//tensorflow/python:layers",
"//tensorflow/python:training", "//tensorflow/python:training",
"//tensorflow/python:variables", "//tensorflow/python:variables",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/eager:backprop", "//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",
@ -181,10 +181,10 @@ py_library(
":tpu_strategy", ":tpu_strategy",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/optimizer_v2:training", "//tensorflow/contrib/optimizer_v2:training",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:training", "//tensorflow/python:training",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],
@ -229,11 +229,11 @@ cuda_py_test(
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op", "//tensorflow/python:constant_op",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:layers", "//tensorflow/python:layers",
"//tensorflow/python:state_ops", "//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope", "//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:values", "//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",

View File

@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -31,7 +32,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
# TODO(yuefengz): support in-graph replication. # TODO(yuefengz): support in-graph replication.

View File

@ -53,11 +53,11 @@ from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2
from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.training import adagrad from tensorflow.python.training import adagrad
from tensorflow.python.training import adam from tensorflow.python.training import adam
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop from tensorflow.python.training import rmsprop
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect

View File

@ -29,6 +29,7 @@ from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values as value_lib from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -37,7 +38,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.training import device_util
def _make_per_replica(values, devices, regroup=False): def _make_per_replica(values, devices, regroup=False):

View File

@ -22,13 +22,13 @@ from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import combinations
from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import values as value_lib from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.training import device_util
class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):

View File

@ -28,6 +28,7 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.estimator import run_config from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import training from tensorflow.python.estimator import training
from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import dnn_linear_combined
@ -46,7 +47,6 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import distribution_strategy_context as ds_context
class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase):

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import functools import functools
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.training import distribute as distribute_lib
# pylint: disable=protected-access,invalid-name # pylint: disable=protected-access,invalid-name

View File

@ -29,6 +29,8 @@ from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
@ -48,8 +50,6 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context as ds_context
from tensorflow.python.training import gradient_descent from tensorflow.python.training import gradient_descent
from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib

View File

@ -20,13 +20,13 @@ from __future__ import print_function
import six import six
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest from tensorflow.python.util import nest

View File

@ -22,6 +22,8 @@ import copy
from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -32,8 +34,6 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter from tensorflow.python.training import device_setter
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest from tensorflow.python.util import nest
_LOCAL_CPU = "/device:CPU:0" _LOCAL_CPU = "/device:CPU:0"

View File

@ -28,6 +28,8 @@ from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
@ -46,8 +48,6 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context as ds_context
from tensorflow.python.training import training_util from tensorflow.python.training import training_util
CHIEF = run_config.TaskType.CHIEF CHIEF = run_config.TaskType.CHIEF

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
@ -33,7 +34,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.training import distribution_strategy_context as ds_context
from tensorflow.python.training import optimizer from tensorflow.python.training import optimizer

View File

@ -29,6 +29,8 @@ from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -41,8 +43,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest from tensorflow.python.util import nest

View File

@ -25,6 +25,8 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import values from tensorflow.python.distribute import values
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import test from tensorflow.python.eager import test
@ -39,8 +41,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest from tensorflow.python.util import nest

View File

@ -3596,9 +3596,7 @@ py_library(
srcs = ["training/device_util.py"], srcs = ["training/device_util.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":device", "//tensorflow/python/distribute:device_util",
":framework_ops",
"//tensorflow/python/eager:context",
], ],
) )
@ -3610,35 +3608,7 @@ py_library(
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":array_ops", "//tensorflow/python/distribute:distribute_lib",
":constant_op",
":control_flow_ops",
":device_util",
":dtypes",
":framework_ops",
":platform",
":resource_variable_ops",
":state_ops",
":util",
":variable_scope",
"//tensorflow/python/data",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/ops/losses",
"//tensorflow/tools/docs:doc_controls",
],
)
py_test(
name = "distribute_test",
size = "small",
srcs = ["training/distribute_test.py"],
srcs_version = "PY2AND3",
deps = [
":client_testlib",
":constant_op",
":distribute",
":dtypes",
":variable_scope",
], ],
) )
@ -4627,7 +4597,6 @@ cuda_py_tests(
"training/basic_loops_test.py", "training/basic_loops_test.py",
"training/coordinator_test.py", "training/coordinator_test.py",
"training/device_setter_test.py", "training/device_setter_test.py",
"training/device_util_test.py",
"training/ftrl_test.py", "training/ftrl_test.py",
"training/gradient_descent_test.py", "training/gradient_descent_test.py",
"training/learning_rate_decay_test.py", "training/learning_rate_decay_test.py",

View File

@ -50,6 +50,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":cross_device_utils", ":cross_device_utils",
":device_util",
":reduce_util", ":reduce_util",
":values", ":values",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
@ -58,8 +59,6 @@ py_library(
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops", "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"@six_archive//:six", "@six_archive//:six",
], ],
@ -83,6 +82,67 @@ py_library(
], ],
) )
py_library(
name = "device_util",
srcs = ["device_util.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:device",
"//tensorflow/python:framework_ops",
"//tensorflow/python/eager:context",
],
)
cuda_py_test(
name = "device_util_test",
srcs = ["device_util_test.py"],
additional_deps = [
":device_util",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
],
)
py_library(
name = "distribute_lib",
srcs = [
"distribute_lib.py",
"distribution_strategy_context.py",
],
srcs_version = "PY2AND3",
deps = [
":device_util",
":reduce_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/data",
"//tensorflow/python/ops/losses",
"//tensorflow/tools/docs:doc_controls",
],
)
py_test(
name = "distribute_lib_test",
size = "small",
srcs = ["distribute_lib_test.py"],
srcs_version = "PY2AND3",
deps = [
":distribute_lib",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:variable_scope",
],
)
py_library( py_library(
name = "distribute_config", name = "distribute_config",
srcs = [ srcs = [
@ -144,6 +204,8 @@ py_library(
srcs = ["mirrored_strategy.py"], srcs = ["mirrored_strategy.py"],
deps = [ deps = [
":cross_device_ops", ":cross_device_ops",
":device_util",
":distribute_lib",
":multi_worker_util", ":multi_worker_util",
":reduce_util", ":reduce_util",
":shared_variable_creator", ":shared_variable_creator",
@ -153,8 +215,6 @@ py_library(
"//tensorflow/python:constant_op", "//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_ops",
"//tensorflow/python:device", "//tensorflow/python:device",
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tensorflow",
@ -195,12 +255,12 @@ cuda_py_test(
additional_deps = [ additional_deps = [
":input_ops", ":input_ops",
"//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:io_ops", "//tensorflow/python:io_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python:util", "//tensorflow/python:util",
], ],
tags = [ tags = [
@ -271,11 +331,11 @@ py_library(
name = "values", name = "values",
srcs = ["values.py"], srcs = ["values.py"],
deps = [ deps = [
":device_util",
":distribute_lib",
":input_ops", ":input_ops",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops", "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training", "//tensorflow/python:training",

View File

@ -23,6 +23,7 @@ import six
from tensorflow.python.client import device_lib from tensorflow.python.client import device_lib
from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values as value_lib from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -31,7 +32,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_util
def check_destinations(destinations): def check_destinations(destinations):

View File

@ -0,0 +1,97 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Device-related support functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
def canonicalize(d, default=None):
"""Canonicalize device string.
If d has missing components, the rest would be deduced from the `default`
argument or from '/replica:0/task:0/device:CPU:0'. For example:
If d = '/cpu:0', default='/job:worker/task:1', it returns
'/job:worker/replica:0/task:1/device:CPU:0'.
If d = '/cpu:0', default='/job:worker', it returns
'/job:worker/replica:0/task:0/device:CPU:0'.
If d = '/gpu:0', default=None, it returns
'/replica:0/task:0/device:GPU:0'.
Note: This uses "job:localhost" as the default if executing eagerly.
Args:
d: a device string.
default: a string for default device if d doesn't have all components.
Returns:
a canonicalized device string.
"""
d = tf_device.DeviceSpec.from_string(d)
assert d.device_type is None or d.device_type == d.device_type.upper(), (
"Device type '%s' must be all-caps." % (d.device_type,))
# Fill in missing device fields using defaults.
result = tf_device.DeviceSpec(
replica=0, task=0, device_type="CPU", device_index=0)
if context.executing_eagerly():
result.job = "localhost"
if default:
result.merge_from(tf_device.DeviceSpec.from_string(default))
result.merge_from(d)
return result.to_string()
def resolve(d):
"""Canonicalize `d` with current device as default."""
return canonicalize(d, default=current())
class _FakeNodeDef(object):
"""A fake NodeDef for _FakeOperation."""
def __init__(self):
self.op = ""
self.name = ""
class _FakeOperation(object):
"""A fake Operation object to pass to device functions."""
def __init__(self):
self.device = ""
self.type = ""
self.name = ""
self.node_def = _FakeNodeDef()
def _set_device(self, device):
self.device = ops._device_string(device) # pylint: disable=protected-access
def current():
"""Return a string (not canonicalized) for the current device."""
# TODO(josh11b): Work out how this function interacts with ops.colocate_with.
ctx = context.context()
if ctx.executing_eagerly():
d = ctx.device_name
else:
op = _FakeOperation()
ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access
d = op.device
return d

View File

@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import device_util
class DeviceUtilTest(test.TestCase): class DeviceUtilTest(test.TestCase):

File diff suppressed because it is too large Load Diff

View File

@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import distribution_strategy_context
class _TestReplicaContext(distribute_lib.ReplicaContext): class _TestReplicaContext(distribute_lib.ReplicaContext):

View File

@ -0,0 +1,236 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to get distribution strategy related contexts."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# There is a circular dependency between this and `distribute` module. So we
# load it lazily to workaround this.
distribute_lib = LazyLoader(
"distribute_lib", globals(),
"tensorflow.python.distribute.distribute_lib")
# ------------------------------------------------------------------------------
# Internal API for setting the current thread mode as being either in a
# replica or cross-replica context for a particular distribution strategy.
class _ThreadMode(object):
def __init__(self, dist, cross, replica):
self.distribution_strategy = dist
self.cross_replica_context = cross
self.replica_context = replica
class _CrossReplicaThreadMode(_ThreadMode):
def __init__(self, distribution_strategy):
_ThreadMode.__init__(
self, distribution_strategy, distribution_strategy, None)
class _InReplicaThreadMode(_ThreadMode):
def __init__(self, replica_ctx):
_ThreadMode.__init__(
self, replica_ctx.distribution_strategy, None, replica_ctx)
def _push_per_thread_mode(context):
ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access
def _pop_per_thread_mode():
ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access
class _DefaultReplicaThreadMode(_ThreadMode):
"""Type of default value returned by `_get_per_thread_mode()`.
Used when the thread-local stack is empty.
"""
def __init__(self):
_ThreadMode.__init__(self, _get_default_distribution_strategy(), None,
_get_default_replica_context())
def _get_per_thread_mode():
try:
return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access
except (AttributeError, IndexError):
return _get_default_replica_mode()
# ------------------------------------------------------------------------------
# Public API for accessing the current thread mode
@tf_export("distribute.get_replica_context")
def get_replica_context():
"""Returns the current `tf.distribute.ReplicaContext` or `None`.
Returns `None` if in a cross-replica context.
Note that execution:
1. starts in the default (single-replica) replica context (this function
will return the default `ReplicaContext` object);
2. switches to cross-replica context (in which case this will return
`None`) when entering a `with tf.distribute.Strategy.scope():` block;
3. switches to a (non-default) replica context inside
`extended.call_for_each_replica(fn, ...)`;
4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
inside `merge_fn` you are back in the cross-replica context (and again
this function will return `None`).
Note that you can also go directly from step 1 to 4 to switch to a
cross-replica context for the default `tf.distribute.Strategy`. You may
also switch from the cross-replica context of 4 to a replica context by
calling `extended.call_for_each_replica()`, jumping back to step 3.
Most `tf.distribute.Strategy` methods may only be executed in
a cross-replica context, in a replica context you should use the
`ReplicaContext` API instead.
Returns:
The current `ReplicaContext` object when in a replica context scope,
else `None`.
Within a particular block, exactly one of these two things will be true:
* `get_replica_context()` returns non-`None`, or
* `tf.distribute.is_cross_replica_context()` returns True.
"""
return _get_per_thread_mode().replica_context
def get_cross_replica_context():
"""Returns the current tf.distribute.Strategy if in a cross-replica context.
DEPRECATED: Please use `in_cross_replica_context()` and
`get_distribution_strategy()` instead.
Note that execution:
1. starts in the default (single-replica) replica context;
2. switches to cross-replica context when entering a
`with tf.distribute.Strategy.scope():` block;
3. switches to a (non-default) replica context inside
`call_for_each_replica(fn, ...)`;
4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then
inside `merge_fn` you are back in the cross-replica context.
Note that you can also go directly from step 1 to 4 to switch to a
cross-replica context for the default `tf.distribute.Strategy`. You may
also switch from the cross-replica context of 4 to a replica context by
calling `call_for_each_replica()`, jumping back to step 3.
Most `tf.distribute.Strategy` methods may only be executed in
a cross-replica context.
Returns:
Returns the current `tf.distribute.Strategy` object in a cross-replica
context, or `None`.
Exactly one of `get_replica_context()` and `get_cross_replica_context()`
will return `None` in a particular block.
"""
return _get_per_thread_mode().cross_replica_context
@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
"""Returns True if in a cross-replica context.
See `tf.distribute.get_replica_context` for details.
Returns:
True if in a cross-replica context (`get_replica_context()` returns
`None`), or False if in a replica context (`get_replica_context()` returns
non-`None`).
"""
return _get_per_thread_mode().cross_replica_context is not None
@tf_export("distribute.get_strategy")
def get_distribution_strategy():
"""Returns the current `tf.distribute.Strategy` object.
Typically only used in a cross-replica context:
```
if tf.distribute.in_cross_replica_context():
strategy = tf.distribute.get_strategy()
...
```
Returns:
A `tf.distribute.Strategy` object. Inside a
`with distribution_strategy.scope()` block, it returns
`distribution_strategy`, otherwise it returns the default
(single-replica) `tf.distribute.Strategy` object.
"""
return _get_per_thread_mode().distribution_strategy
@tf_export("distribute.has_strategy")
def has_distribution_strategy():
"""Return if there is a current non-default `tf.distribute.Strategy`.
Returns:
True if inside a `with strategy.scope():`.
"""
return get_distribution_strategy() is not _get_default_distribution_strategy()
# ------------------------------------------------------------------------------
# Defaults that are used when no distribution strategy is explicitly created.
# We create them lazily in a function so that we can workaround the circular
# dependency on distribute_lib. See lazy loader at the top of this file.
_defaults = {
"distribution_strategy": None,
"replica_context": None,
"replica_mode": None
}
def _get_default_distribution_strategy():
if _defaults["distribution_strategy"] is None:
_defaults["distribution_strategy"] = (
distribute_lib._DefaultDistributionStrategy()) # pylint: disable=protected-access
return _defaults["distribution_strategy"]
def _get_default_replica_context():
if _defaults["replica_context"] is None:
_defaults["replica_context"] = distribute_lib.ReplicaContext(
_get_default_distribution_strategy(), replica_id_in_sync_group=0)
return _defaults["replica_context"]
def _get_default_replica_mode():
if _defaults["replica_mode"] is None:
_defaults["replica_mode"] = _DefaultReplicaThreadMode()
return _defaults["replica_mode"]

View File

@ -25,6 +25,8 @@ import threading
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import shared_variable_creator from tensorflow.python.distribute import shared_variable_creator
@ -40,8 +42,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.training import coordinator from tensorflow.python.training import coordinator
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest from tensorflow.python.util import nest

View File

@ -30,6 +30,9 @@ import six
from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -42,9 +45,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import saver from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest from tensorflow.python.util import nest

View File

@ -37,7 +37,7 @@ _TENSORFLOW_DOC_SOURCES = {
'app': DocSource(docstring_module_name='platform.app'), 'app': DocSource(docstring_module_name='platform.app'),
'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'), 'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
'compat': DocSource(docstring_module_name='util.compat'), 'compat': DocSource(docstring_module_name='util.compat'),
'distribute': DocSource(docstring_module_name='training.distribute'), 'distribute': DocSource(docstring_module_name='distribute.distribute_lib'),
'distributions': DocSource( 'distributions': DocSource(
docstring_module_name='ops.distributions.distributions'), docstring_module_name='ops.distributions.distributions'),
'errors': DocSource(docstring_module_name='framework.errors'), 'errors': DocSource(docstring_module_name='framework.errors'),

View File

@ -12,86 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Device-related support functions.""" """Deprecated, please use ../distribute/device_util.py."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.eager import context # pylint: disable=wildcard-import
from tensorflow.python.framework import device as tf_device from tensorflow.python.distribute.device_util import *
from tensorflow.python.framework import ops
def canonicalize(d, default=None):
"""Canonicalize device string.
If d has missing components, the rest would be deduced from the `default`
argument or from '/replica:0/task:0/device:CPU:0'. For example:
If d = '/cpu:0', default='/job:worker/task:1', it returns
'/job:worker/replica:0/task:1/device:CPU:0'.
If d = '/cpu:0', default='/job:worker', it returns
'/job:worker/replica:0/task:0/device:CPU:0'.
If d = '/gpu:0', default=None, it returns
'/replica:0/task:0/device:GPU:0'.
Note: This uses "job:localhost" as the default if executing eagerly.
Args:
d: a device string.
default: a string for default device if d doesn't have all components.
Returns:
a canonicalized device string.
"""
d = tf_device.DeviceSpec.from_string(d)
assert d.device_type is None or d.device_type == d.device_type.upper(), (
"Device type '%s' must be all-caps." % (d.device_type,))
# Fill in missing device fields using defaults.
result = tf_device.DeviceSpec(
replica=0, task=0, device_type="CPU", device_index=0)
if context.executing_eagerly():
result.job = "localhost"
if default:
result.merge_from(tf_device.DeviceSpec.from_string(default))
result.merge_from(d)
return result.to_string()
def resolve(d):
"""Canonicalize `d` with current device as default."""
return canonicalize(d, default=current())
class _FakeNodeDef(object):
"""A fake NodeDef for _FakeOperation."""
def __init__(self):
self.op = ""
self.name = ""
class _FakeOperation(object):
"""A fake Operation object to pass to device functions."""
def __init__(self):
self.device = ""
self.type = ""
self.name = ""
self.node_def = _FakeNodeDef()
def _set_device(self, device):
self.device = ops._device_string(device) # pylint: disable=protected-access
def current():
"""Return a string (not canonicalized) for the current device."""
# TODO(josh11b): Work out how this function interacts with ops.colocate_with.
ctx = context.context()
if ctx.executing_eagerly():
d = ctx.device_name
else:
op = _FakeOperation()
ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access
d = op.device
return d

File diff suppressed because it is too large Load Diff

View File

@ -12,225 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Utility to get distribution strategy related contexts.""" """Deprecated, please use ../distribute/distribution_strategy_context.py."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.framework import ops # pylint: disable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.distribute.distribution_strategy_context import *
from tensorflow.python.util.tf_export import tf_export
# There is a circular dependency between this and `distribute` module. So we
# load it lazily to workaround this.
distribute_lib = LazyLoader(
"distribute_lib", globals(),
"tensorflow.python.training.distribute")
# ------------------------------------------------------------------------------
# Internal API for setting the current thread mode as being either in a
# replica or cross-replica context for a particular distribution strategy.
class _ThreadMode(object):
def __init__(self, dist, cross, replica):
self.distribution_strategy = dist
self.cross_replica_context = cross
self.replica_context = replica
class _CrossReplicaThreadMode(_ThreadMode):
def __init__(self, distribution_strategy):
_ThreadMode.__init__(
self, distribution_strategy, distribution_strategy, None)
class _InReplicaThreadMode(_ThreadMode):
def __init__(self, replica_ctx):
_ThreadMode.__init__(
self, replica_ctx.distribution_strategy, None, replica_ctx)
def _push_per_thread_mode(context):
ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access
def _pop_per_thread_mode():
ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access
class _DefaultReplicaThreadMode(_ThreadMode):
"""Type of default value returned by `_get_per_thread_mode()`.
Used when the thread-local stack is empty.
"""
def __init__(self):
_ThreadMode.__init__(self, _get_default_distribution_strategy(), None,
_get_default_replica_context())
def _get_per_thread_mode():
try:
return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access
except (AttributeError, IndexError):
return _get_default_replica_mode()
# ------------------------------------------------------------------------------
# Public API for accessing the current thread mode
@tf_export("distribute.get_replica_context")
def get_replica_context():
"""Returns the current `tf.distribute.ReplicaContext` or `None`.
Returns `None` if in a cross-replica context.
Note that execution:
1. starts in the default (single-replica) replica context (this function
will return the default `ReplicaContext` object);
2. switches to cross-replica context (in which case this will return
`None`) when entering a `with tf.distribute.Strategy.scope():` block;
3. switches to a (non-default) replica context inside
`extended.call_for_each_replica(fn, ...)`;
4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
inside `merge_fn` you are back in the cross-replica context (and again
this function will return `None`).
Note that you can also go directly from step 1 to 4 to switch to a
cross-replica context for the default `tf.distribute.Strategy`. You may
also switch from the cross-replica context of 4 to a replica context by
calling `extended.call_for_each_replica()`, jumping back to step 3.
Most `tf.distribute.Strategy` methods may only be executed in
a cross-replica context, in a replica context you should use the
`ReplicaContext` API instead.
Returns:
The current `ReplicaContext` object when in a replica context scope,
else `None`.
Within a particular block, exactly one of these two things will be true:
* `get_replica_context()` returns non-`None`, or
* `tf.distribute.is_cross_replica_context()` returns True.
"""
return _get_per_thread_mode().replica_context
def get_cross_replica_context():
"""Returns the current tf.distribute.Strategy if in a cross-replica context.
DEPRECATED: Please use `in_cross_replica_context()` and
`get_distribution_strategy()` instead.
Note that execution:
1. starts in the default (single-replica) replica context;
2. switches to cross-replica context when entering a
`with tf.distribute.Strategy.scope():` block;
3. switches to a (non-default) replica context inside
`call_for_each_replica(fn, ...)`;
4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then
inside `merge_fn` you are back in the cross-replica context.
Note that you can also go directly from step 1 to 4 to switch to a
cross-replica context for the default `tf.distribute.Strategy`. You may
also switch from the cross-replica context of 4 to a replica context by
calling `call_for_each_replica()`, jumping back to step 3.
Most `tf.distribute.Strategy` methods may only be executed in
a cross-replica context.
Returns:
Returns the current `tf.distribute.Strategy` object in a cross-replica
context, or `None`.
Exactly one of `get_replica_context()` and `get_cross_replica_context()`
will return `None` in a particular block.
"""
return _get_per_thread_mode().cross_replica_context
@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
"""Returns True if in a cross-replica context.
See `tf.distribute.get_replica_context` for details.
Returns:
True if in a cross-replica context (`get_replica_context()` returns
`None`), or False if in a replica context (`get_replica_context()` returns
non-`None`).
"""
return _get_per_thread_mode().cross_replica_context is not None
@tf_export("distribute.get_strategy")
def get_distribution_strategy():
"""Returns the current `tf.distribute.Strategy` object.
Typically only used in a cross-replica context:
```
if tf.distribute.in_cross_replica_context():
strategy = tf.distribute.get_strategy()
...
```
Returns:
A `tf.distribute.Strategy` object. Inside a
`with distribution_strategy.scope()` block, it returns
`distribution_strategy`, otherwise it returns the default
(single-replica) `tf.distribute.Strategy` object.
"""
return _get_per_thread_mode().distribution_strategy
@tf_export("distribute.has_strategy")
def has_distribution_strategy():
"""Return if there is a current non-default `tf.distribute.Strategy`.
Returns:
True if inside a `with strategy.scope():`.
"""
return get_distribution_strategy() is not _get_default_distribution_strategy()
# ------------------------------------------------------------------------------
# Defaults that are used when no distribution strategy is explicitly created.
# We create them lazily in a function so that we can workaround the circular
# dependency on distribute_lib. See lazy loader at the top of this file.
_defaults = {
"distribution_strategy": None,
"replica_context": None,
"replica_mode": None
}
def _get_default_distribution_strategy():
if _defaults["distribution_strategy"] is None:
_defaults["distribution_strategy"] = (
distribute_lib._DefaultDistributionStrategy()) # pylint: disable=protected-access
return _defaults["distribution_strategy"]
def _get_default_replica_context():
if _defaults["replica_context"] is None:
_defaults["replica_context"] = distribute_lib.ReplicaContext(
_get_default_distribution_strategy(), replica_id_in_sync_group=0)
return _defaults["replica_context"]
def _get_default_replica_mode():
if _defaults["replica_mode"] is None:
_defaults["replica_mode"] = _DefaultReplicaThreadMode()
return _defaults["replica_mode"]

View File

@ -1,6 +1,6 @@
path: "tensorflow.distribute.InputContext" path: "tensorflow.distribute.InputContext"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputContext\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "input_pipeline_id" name: "input_pipeline_id"

View File

@ -1,6 +1,6 @@
path: "tensorflow.distribute.ReplicaContext" path: "tensorflow.distribute.ReplicaContext"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "devices" name: "devices"

View File

@ -1,6 +1,6 @@
path: "tensorflow.distribute.StrategyExtended" path: "tensorflow.distribute.StrategyExtended"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.DistributionStrategyExtended\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "experimental_between_graph" name: "experimental_between_graph"

View File

@ -1,6 +1,6 @@
path: "tensorflow.distribute.Strategy" path: "tensorflow.distribute.Strategy"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.DistributionStrategy\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "between_graph" name: "between_graph"

View File

@ -1,6 +1,6 @@
path: "tensorflow.distribute.InputContext" path: "tensorflow.distribute.InputContext"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputContext\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "input_pipeline_id" name: "input_pipeline_id"

View File

@ -1,6 +1,6 @@
path: "tensorflow.distribute.ReplicaContext" path: "tensorflow.distribute.ReplicaContext"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "devices" name: "devices"

View File

@ -1,6 +1,6 @@
path: "tensorflow.distribute.StrategyExtended" path: "tensorflow.distribute.StrategyExtended"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.DistributionStrategyExtended\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "experimental_between_graph" name: "experimental_between_graph"

View File

@ -1,6 +1,6 @@
path: "tensorflow.distribute.Strategy" path: "tensorflow.distribute.Strategy"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.DistributionStrategy\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "between_graph" name: "between_graph"