1901 lines
80 KiB
Python
1901 lines
80 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for variable store."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import gc
|
|
import threading
|
|
|
|
import numpy
|
|
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.eager import function
|
|
from tensorflow.python.eager import wrap_function
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.layers import core as core_layers
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import init_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import resource_variable_ops
|
|
from tensorflow.python.ops import state_ops
|
|
from tensorflow.python.ops import variable_scope
|
|
from tensorflow.python.ops import variables as variables_lib
|
|
from tensorflow.python.platform import test
|
|
from tensorflow.python.util import compat
|
|
from tensorflow.python.util import tf_inspect
|
|
|
|
|
|
def run_inside_wrap_function_in_eager_mode(graph_function):
  """Decorator to execute the same graph code in eager and graph modes.

  When eager execution is enabled, `graph_function` is wrapped with
  `wrap_function.wrap_function` (with the test instance as the only entry of
  the signature list) and the wrapped result is invoked. Otherwise
  `graph_function` is called directly with the test instance.

  Args:
    graph_function: python function containing graph code to be wrapped

  Returns:
    decorated function
  """

  def wrap_and_execute(self):
    if not context.executing_eagerly():
      # Graph mode: run the original function unchanged.
      graph_function(self)
      return
    # Eager mode: build the wrapped graph function and run it.
    wrap_function.wrap_function(graph_function, [self])()

  return wrap_and_execute
|
|
|
|
|
|
class VariableScopeTest(test.TestCase):
|
|
|
|
def tearDown(self):
  # Force a full collection, then require that nothing uncollectable is left.
  gc.collect()
  # gc.garbage will only contain uncollectable garbage, i.e. reference cycles
  # involving objects with __del__ defined.
  leftover_count = len(gc.garbage)
  self.assertEqual(0, leftover_count)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testGetVar(self):
  # Asking the default store twice for the same name yields the same object.
  store = variable_scope._get_default_variable_store()
  first = store.get_variable("v", [1])
  second = store.get_variable("v", [1])
  self.assertIs(first, second)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testResource(self):
  # use_resource=True must produce a ResourceVariable.
  vs = variable_scope._get_default_variable_store()
  v1 = vs.get_variable("v", [1], use_resource=True)
  # assertIsInstance reports the actual type on failure, unlike
  # assertTrue(isinstance(...)).
  self.assertIsInstance(v1, resource_variable_ops.ResourceVariable)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testNameExists(self):
  store = variable_scope._get_default_variable_store()
  # No check by default, so we can both create and get existing names.
  created = store.get_variable("v", [1])
  fetched = store.get_variable("v", [1])
  self.assertIs(created, fetched)

  # When reuse is False, we fail when variables are already there.
  store.get_variable("w", [1], reuse=False)  # That's ok.
  with self.assertRaises(ValueError):
    store.get_variable("v", [1], reuse=False)  # That fails.

  # When reuse is True, we fail when variables are new.
  store.get_variable("v", [1], reuse=True)  # That's ok.
  with self.assertRaises(ValueError):
    store.get_variable("u", [1], reuse=True)  # That fails.
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testNamelessStore(self):
  # Variables created directly on the store carry no scope prefix.
  store = variable_scope._get_default_variable_store()
  store.get_variable("v1", [2])
  store.get_variable("v2", [2])
  wanted = set("%s:0" % name for name in ["v1", "v2"])
  found = set(var.name for var in store._vars.values())
  self.assertEqual(wanted, found)
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# TypeError: Expected tf.group() expected Tensor arguments not 'None' with
|
|
# type '<type 'NoneType'>'
|
|
@test_util.run_in_graph_and_eager_modes
def testVarScopeInitializer(self):
  # A scope-level initializer applies both to a freshly named scope and to a
  # re-entered scope object.
  const_init = init_ops.constant_initializer(0.3)
  with variable_scope.variable_scope("tower0") as tower:
    with variable_scope.variable_scope("foo", initializer=const_init):
      in_foo = variable_scope.get_variable("v", [])
      self.evaluate(variables_lib.variables_initializer([in_foo]))
      self.assertAllClose(self.evaluate(in_foo.value()), 0.3)
    with variable_scope.variable_scope(tower, initializer=const_init):
      in_tower = variable_scope.get_variable("w", [])
      self.evaluate(variables_lib.variables_initializer([in_tower]))
      self.assertAllClose(self.evaluate(in_tower.value()), 0.3)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeConstraint(self):
  # A scope-level constraint is attached to variables created in that scope.
  def constraint(x):
    return 0. * x

  with variable_scope.variable_scope("tower1") as tower:
    with variable_scope.variable_scope("foo", constraint=constraint):
      in_foo = variable_scope.get_variable("v", [])
      self.assertEqual(in_foo.constraint, constraint)
    with variable_scope.variable_scope(tower, constraint=constraint):
      in_tower = variable_scope.get_variable("w", [])
      self.assertEqual(in_tower.constraint, constraint)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeNestingError(self):
  # Exiting "bb" while the later-entered "cc" is still open must raise;
  # exiting it once "cc" has closed must succeed.
  with variable_scope.variable_scope("aa"):
    manual_scope = variable_scope.variable_scope("bb")
    manual_scope.__enter__()
    with variable_scope.variable_scope("cc"):
      with self.assertRaises(RuntimeError):
        manual_scope.__exit__(None, None, None)
    manual_scope.__exit__(None, None, None)
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# TypeError: Fetch argument <tf.Variable 'string:0' shape=() dtype=string>
|
|
# has invalid type <class '...ResourceVariable'>, must be a string or Tensor.
|
|
# (Can not convert a ResourceVariable into a Tensor or Operation.)
|
|
@test_util.run_deprecated_v1
def testStringDefaultInitializer(self):
  # A string variable created without an initializer evaluates to b"".
  with self.cached_session():
    string_var = variable_scope.get_variable(
        "string", shape=[], dtype=dtypes.string)
    variables_lib.global_variables_initializer().run()
    fetched = compat.as_bytes(self.evaluate(string_var))
    self.assertAllEqual(fetched, b"")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeDType(self):
  # A scope-level dtype is used for variables created inside that scope.
  half = dtypes.float16
  with variable_scope.variable_scope("tower2") as tower:
    with variable_scope.variable_scope("foo", dtype=half):
      in_foo = variable_scope.get_variable("v", [])
      self.assertEqual(in_foo.dtype.base_dtype, half)
    with variable_scope.variable_scope(tower, dtype=half):
      in_tower = variable_scope.get_variable("w", [])
      self.assertEqual(in_tower.dtype.base_dtype, half)
|
|
|
|
def testGetVariableInGraphNestedUnderEagerContext(self):
  # get_variable inside a defun traced under eager mode must return exactly a
  # ResourceVariable.
  with context.eager_mode():

    @function.defun
    def f():
      created = variable_scope.get_variable("should_be_resource", [])
      self.assertEqual(type(created), resource_variable_ops.ResourceVariable)

    f()
|
|
|
|
def testEagerVariableStore(self):
  # Checks that an EagerVariableStore tracks the variables created under it,
  # splits them by trainability, and that copy() produces a store whose
  # variables have the same values but are separate instances.
  with context.eager_mode():
    store = variable_scope.EagerVariableStore()
    with store.as_default():
      v = variable_scope.get_variable("v", shape=(), trainable=True)
      w = variable_scope.get_variable("w", shape=(), trainable=False)

    # assertIn/assertNotIn print the container on failure, which
    # assertTrue(x in y) does not.
    self.assertIn(v, store.variables())
    self.assertIn(w, store.variables())
    self.assertIn(v, store.trainable_variables())
    self.assertNotIn(w, store.trainable_variables())
    self.assertNotIn(v, store.non_trainable_variables())
    self.assertIn(w, store.non_trainable_variables())

    # Test copying.
    new_store = store.copy()
    with new_store.as_default():
      new_v = variable_scope.get_variable("v")
      new_w = variable_scope.get_variable("w")
    self.assertEqual(new_v.numpy(), v.numpy())
    self.assertEqual(new_w.numpy(), w.numpy())
    self.assertIn(new_v, new_store.variables())
    self.assertIn(new_w, new_store.variables())
    self.assertIn(new_v, new_store.trainable_variables())
    self.assertNotIn(new_w, new_store.trainable_variables())
    self.assertNotIn(new_v, new_store.non_trainable_variables())
    self.assertIn(new_w, new_store.non_trainable_variables())

    # Check that variables are separate instances. A dedicated loop name is
    # used so the outer `v` is not rebound.
    for var in store.variables():
      var.assign(-1)
    for var in new_store.variables():
      var.assign(1)
    for var in store.variables():
      self.assertEqual(var.numpy(), -1)
    for var in new_store.variables():
      self.assertEqual(var.numpy(), 1)
|
|
|
|
def testEagerVariableStoreWithEagerDefun(self):
  # Two Dense layers sharing a name, the second built with _reuse=True, must
  # end up with the very same variable objects inside a defun.
  with context.eager_mode():

    @function.defun
    def f():
      inputs = constant_op.constant([[2.0]])
      first_layer = core_layers.Dense(
          1, name="my_dense", kernel_initializer=init_ops.ones_initializer())
      first_layer(inputs)  # create variables
      self.assertEqual(len(first_layer.variables), 2)
      first_a, first_b = first_layer.variables

      second_layer = core_layers.Dense(
          1,
          name="my_dense",
          kernel_initializer=init_ops.ones_initializer(),
          _reuse=True)
      second_layer(inputs)
      self.assertEqual(len(second_layer.variables), 2)
      second_a, second_b = second_layer.variables

      self.assertIs(first_a, second_a)
      self.assertIs(first_b, second_b)

    f()
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# obtaining different results in the eager case compared to the graph one
|
|
@test_util.run_in_graph_and_eager_modes
def testEagerVariablesStoreAddsToCollections(self):
  # Variables created under the store are expected in these graph collections.
  keys = ops.GraphKeys
  eager_store = variable_scope.EagerVariableStore()
  with eager_store.as_default():
    trainable = variable_scope.get_variable("v1", [], trainable=True)
    not_trainable = variable_scope.get_variable("v2", [], trainable=False)
    concat = variable_scope.get_variable(
        "v3", [], collections=[keys.CONCATENATED_VARIABLES])

    global_vars = ops.get_collection(keys.GLOBAL_VARIABLES)
    self.assertEqual(global_vars, [trainable, not_trainable])

    trainable_vars = ops.get_collection(keys.TRAINABLE_VARIABLES)
    self.assertEqual(trainable_vars, [trainable, concat])

    concatenated_vars = ops.get_collection(keys.CONCATENATED_VARIABLES)
    self.assertEqual(concatenated_vars, [concat])
|
|
|
|
def testEagerVariablesOutsideStoreNotAddedToCollections(self):
  # Without an EagerVariableStore, eager variables must leave the global and
  # trainable collections empty.
  with context.eager_mode():
    variable_scope.get_variable("v1", [], trainable=True)
    variable_scope.get_variable("v2", [], trainable=False)
    checked_keys = (ops.GraphKeys.GLOBAL_VARIABLES,
                    ops.GraphKeys.TRAINABLE_VARIABLES)
    for key in checked_keys:
      self.assertFalse(ops.get_collection(key))
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# TypeError: Expected tf.group() expected Tensor arguments not 'None' with
|
|
# type '<type 'NoneType'>'.
|
|
@test_util.run_in_graph_and_eager_modes
def testInitFromNonTensorValue(self):
  # A plain Python int is accepted as an initializer.
  scalar_var = variable_scope.get_variable(
      "v4", initializer=4, dtype=dtypes.int32)
  self.evaluate(variables_lib.variables_initializer([scalar_var]))
  self.assertAllClose(self.evaluate(scalar_var.value()), 4)

  # So is a numpy array.
  array_var = variable_scope.get_variable(
      "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
  self.evaluate(variables_lib.variables_initializer([array_var]))
  self.assertAllClose(self.evaluate(array_var.value()), [1, 2, 3])

  # A quirk to be revisited? The exception type depends on the execution mode.
  if context.executing_eagerly():
    expected_error = ValueError
  else:
    expected_error = TypeError
  with self.assertRaises(expected_error):
    variable_scope.get_variable("x4", initializer={})
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# InvalidArgumentError=: You must feed a value for placeholder tensor
|
|
# 'ReadVariableOp/resource' with dtype resource
|
|
@test_util.run_in_graph_and_eager_modes
def testInitFromNonInitializer(self):
  # For each dtype below, a variable created without an initializer must
  # evaluate equal to one created with an explicit zeros initializer.
  dtypes_under_test = [
      dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32,
      dtypes.int64, dtypes.bool
  ]

  # The index is baked into the variable names so every dtype gets its own
  # pair of variables.
  for index, dtype in enumerate(dtypes_under_test):
    default_initialized = variable_scope.get_variable(
        name="xx%d" % index, shape=(3, 4), dtype=dtype)
    zero_initialized = variable_scope.get_variable(
        name="yy%d" % index,
        shape=(3, 4),
        dtype=dtype,
        initializer=init_ops.zeros_initializer(dtype=dtype))

    self.evaluate(variables_lib.global_variables_initializer())
    self.assertAllEqual(
        self.evaluate(default_initialized.value()),
        self.evaluate(zero_initialized.value()))
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# InvalidArgumentError: /job:moo/replica:0/task:0/device:CPU:0 unknown device.
|
|
@test_util.run_deprecated_v1
def testVarScopeCachingDevice(self):
  # Checks which device the read of each variable lands on, depending on the
  # caching_device of the scope it was created in.
  with self.cached_session():
    moo_device = "/job:moo"
    with variable_scope.variable_scope("tower"):
      with variable_scope.variable_scope("caching", caching_device=moo_device):
        cached = variable_scope.get_variable("v", [])
        self.assertTrue(cached.value().device.startswith(moo_device))

        # A nested scope without its own setting keeps the parent's device.
        with variable_scope.variable_scope("child"):
          child_cached = variable_scope.get_variable("v", [])
          self.assertTrue(child_cached.value().device.startswith(moo_device))

        # An empty string turns caching off for the nested scope.
        with variable_scope.variable_scope("not_cached", caching_device=""):
          uncached = variable_scope.get_variable("v", [])
          self.assertFalse(uncached.value().device.startswith(moo_device))

        # A callable returning the op's own device also avoids "/job:moo".
        with variable_scope.variable_scope(
            "not_cached_identity_device",
            caching_device=lambda op: op.device):
          identity_cached = variable_scope.get_variable("v", [])
          self.assertFalse(
              identity_cached.value().device.startswith(moo_device))

        # The caching device can be changed on an already-open scope.
        with variable_scope.variable_scope("we_will_do_it_live") as live_scope:
          live_scope.set_caching_device("/job:live")
          live_cached = variable_scope.get_variable("v", [])
          self.assertTrue(live_cached.value().device.startswith("/job:live"))

      # Back in "tower", outside the "caching" scope.
      in_tower = variable_scope.get_variable("v", [])
      self.assertFalse(in_tower.value().device.startswith(moo_device))
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# AttributeError: Tensor.name is meaningless when eager execution is enabled.
|
|
@test_util.run_in_graph_and_eager_modes
def testVarScopeRegularizer(self):
  # Tracks how many REGULARIZATION_LOSSES entries exist as variables are
  # created under scopes with different (or disabled) regularizers.
  const_init = init_ops.constant_initializer(0.3)
  loss_key = ops.GraphKeys.REGULARIZATION_LOSSES

  def regularizer1(var):
    return math_ops.reduce_mean(var) + 0.1

  def regularizer2(var):
    return math_ops.reduce_mean(var) + 0.2

  with variable_scope.variable_scope(
      "tower3", regularizer=regularizer1) as tower:
    with variable_scope.variable_scope("foo", initializer=const_init):
      v = variable_scope.get_variable("v", [])
      self.evaluate(variables_lib.variables_initializer([v]))
      losses = ops.get_collection(loss_key)
      self.assertEqual(1, len(losses))
      self.assertAllClose(self.evaluate(losses[0]), 0.4)

    with variable_scope.variable_scope(tower, initializer=const_init) as vs:
      u = variable_scope.get_variable("u", [])
      vs.set_regularizer(regularizer2)
      w = variable_scope.get_variable("w", [])
      # Next 3 variable not regularized to test disabling regularization.
      x = variable_scope.get_variable(
          "x", [], regularizer=variable_scope.no_regularizer)
      with variable_scope.variable_scope(
          "baz", regularizer=variable_scope.no_regularizer):
        y = variable_scope.get_variable("y", [])
      vs.set_regularizer(variable_scope.no_regularizer)
      z = variable_scope.get_variable("z", [])

      # Check results.
      losses = ops.get_collection(loss_key)
      self.assertEqual(3, len(losses))
      self.evaluate(variables_lib.variables_initializer([u, w, x, y, z]))
      self.assertAllClose(self.evaluate(losses[0]), 0.4)
      self.assertAllClose(self.evaluate(losses[1]), 0.4)
      self.assertAllClose(self.evaluate(losses[2]), 0.5)

    with variable_scope.variable_scope("foo", reuse=True):
      # reuse=True is for now only supported when eager execution is disabled.
      if not context.executing_eagerly():
        # "v" is already there, reused
        v = variable_scope.get_variable("v", [])
        losses = ops.get_collection(loss_key)
        self.assertEqual(3, len(losses))  # No new loss added.
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# ValueError: Tensor-typed variable initializers must either be wrapped in an
|
|
# init_scope or callable...
|
|
@test_util.run_in_graph_and_eager_modes
def testInitializeFromValue(self):
  # Exercises get_variable with a constant tensor as the initializer.
  float_init = constant_op.constant(0.1)
  from_tensor = variable_scope.get_variable("v", initializer=float_init)
  self.evaluate(variables_lib.variables_initializer([from_tensor]))
  self.assertAllClose(self.evaluate(from_tensor.value()), 0.1)

  # We disallow explicit shape specification when initializer is constant.
  with self.assertRaisesRegexp(ValueError, "shape"):
    variable_scope.get_variable("u", [1], initializer=float_init)

  # Constant initializer can be passed through scopes if needed.
  with variable_scope.variable_scope("foo", initializer=float_init):
    from_scope = variable_scope.get_variable("v")
    self.evaluate(variables_lib.variables_initializer([from_scope]))
    self.assertAllClose(self.evaluate(from_scope.value()), 0.1)

  # Check that non-float32 initializer creates a non-float32 variable.
  int_init = constant_op.constant(1, dtype=dtypes.int32)
  int_var = variable_scope.get_variable("t", initializer=int_init)
  self.assertEqual(int_var.dtype.base_dtype, dtypes.int32)

  # Raise error if `initializer` dtype and `dtype` are not identical.
  with self.assertRaisesRegexp(ValueError, "don't match"):
    variable_scope.get_variable(
        "s", initializer=int_init, dtype=dtypes.float64)
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# TypeError: Fetch argument <tf.Variable 'v0:0' shape=(1,) dtype=float32> has
|
|
# invalid type <class '...ops.resource_variable_ops.ResourceVariable'>, must
|
|
# be a string or Tensor. (Can not convert a ResourceVariable into a Tensor or
|
|
# Operation.)
|
|
@test_util.run_deprecated_v1
def testControlDeps(self):
  # A variable created under control_dependencies on v0 must still be
  # initializable and readable while v0 itself is uninitialized.
  # The session object is never used directly, so it is not bound to a name.
  with self.cached_session():
    v0 = variable_scope.get_variable(
        "v0", [1], initializer=init_ops.constant_initializer(0))
    with ops.control_dependencies([v0.value()]):
      v1 = variable_scope.get_variable(
          "v1", [1], initializer=init_ops.constant_initializer(1))
      add = v1 + v0
    # v0 should be uninitialized.
    with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
      self.evaluate(v0)
    # We should be able to initialize and run v1 without initializing
    # v0, even if the variable was created with a control dep on v0.
    self.evaluate(v1.initializer)
    self.assertEqual(1, self.evaluate(v1))
    # v0 should still be uninitialized.
    with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
      self.evaluate(v0)
    with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
      self.evaluate(add)
    # If we initialize v0 we should be able to run 'add'.
    self.evaluate(v0.initializer)
    self.evaluate(add)
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# AssertionError: True is not false (last assertFalse)
|
|
@test_util.run_deprecated_v1
def testEnableResourceVariables(self):
  # Toggling the module-level switch changes whether VariableV1 produces a
  # ResourceVariable. The previous setting is restored even on failure.
  old = variable_scope._DEFAULT_USE_RESOURCE
  try:
    variable_scope.enable_resource_variables()
    # assertIsInstance/assertNotIsInstance report the offending type on
    # failure, unlike assertTrue/assertFalse(isinstance(...)).
    self.assertIsInstance(
        variables_lib.VariableV1(1.0), resource_variable_ops.ResourceVariable)
    variable_scope.disable_resource_variables()
    self.assertNotIsInstance(
        variables_lib.VariableV1(1.0), resource_variable_ops.ResourceVariable)
  finally:
    variable_scope._DEFAULT_USE_RESOURCE = old
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# TypeError: Fetch argument None has invalid type <type 'NoneType'>
|
|
@test_util.run_deprecated_v1
def testControlFlow(self):
  # Variables created inside either branch of a cond must be initializable
  # and readable while v0, which both branches read, is still uninitialized.
  # The session object is never used directly, so it is not bound to a name.
  with self.cached_session():
    v0 = variable_scope.get_variable(
        "v0", [], initializer=init_ops.constant_initializer(0))
    # Filled in by the branch functions so the variables they create can be
    # inspected after cond() has built both branches.
    var_dict = {}

    # Call get_variable in each of the cond clauses.
    def var_in_then_clause():
      v1 = variable_scope.get_variable(
          "v1", [1], initializer=init_ops.constant_initializer(1))
      var_dict["v1"] = v1
      return v1 + v0

    def var_in_else_clause():
      v2 = variable_scope.get_variable(
          "v2", [1], initializer=init_ops.constant_initializer(2))
      var_dict["v2"] = v2
      return v2 + v0

    add = control_flow_ops.cond(
        math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause)
    v1 = var_dict["v1"]
    v2 = var_dict["v2"]
    # We should be able to initialize and run v1 and v2 without initializing
    # v0, even if the variable was created with a control dep on v0.
    self.evaluate(v1.initializer)
    self.assertEqual([1], self.evaluate(v1))
    self.evaluate(v2.initializer)
    self.assertEqual([2], self.evaluate(v2))
    # v0 should still be uninitialized.
    with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
      self.evaluate(v0)
    # We should not be able to run 'add' yet.
    with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
      self.evaluate(add)
    # If we initialize v0 we should be able to run 'add'.
    self.evaluate(v0.initializer)
    self.evaluate(add)
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# TypeError: Expected tf.group() expected Tensor arguments not 'None' with
|
|
# type '<type 'NoneType'>'.
|
|
@test_util.run_in_graph_and_eager_modes
def testGetVariableScope(self):
  # Test the get_variable_scope() function and setting properties of result.
  init = init_ops.constant_initializer(0.3)
  with variable_scope.variable_scope("bar"):
    new_init1 = variable_scope.get_variable_scope().initializer
    # assertIsNone checks identity with None instead of relying on `==`.
    self.assertIsNone(new_init1)
    # Check that we can set initializer like this.
    variable_scope.get_variable_scope().set_initializer(init)
    v = variable_scope.get_variable("v", [])
    self.evaluate(variables_lib.variables_initializer([v]))
    self.assertAllClose(self.evaluate(v.value()), 0.3)
    if not context.executing_eagerly():
      # Check that we can set reuse.
      variable_scope.get_variable_scope().reuse_variables()
      with self.assertRaises(ValueError):  # Fail, w does not exist yet.
        variable_scope.get_variable("w", [1])
  # Check that the set initializer goes away.
  new_init = variable_scope.get_variable_scope().initializer
  self.assertIsNone(new_init)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScope(self):
  # Checks the variable scope name and the name scope opened inside it.
  with variable_scope.variable_scope("tower4") as tower:
    self.assertEqual(tower.name, "tower4")
    with ops.name_scope("scope") as name:
      self.assertEqual(name, "tower4/scope/")

  # Nested string scopes concatenate their names.
  with variable_scope.variable_scope("tower5"):
    with variable_scope.variable_scope("bar") as nested:
      self.assertEqual(nested.name, "tower5/bar")
      with ops.name_scope("scope") as name:
        self.assertEqual(name, "tower5/bar/scope/")

  # Re-entering a scope object keeps its variable scope name, while the name
  # scope is nested under the enclosing one.
  with variable_scope.variable_scope("tower6"):
    with variable_scope.variable_scope(tower, reuse=True) as shared:
      self.assertEqual(shared.name, "tower4")
      with ops.name_scope("scope") as name:
        self.assertEqual(name, "tower6/tower4/scope/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeNameScope(self):
  # Checks the name scopes produced by entering and re-entering variable
  # scopes. The re-entry checks only run when not executing eagerly.
  with ops.name_scope("testVarScopeNameScope1"):
    with variable_scope.variable_scope("tower") as tower:
      with ops.name_scope("scope2") as name:
        self.assertEqual(name, "testVarScopeNameScope1/tower/scope2/")
    if not context.executing_eagerly():
      # Re-entering acts like another "tower".
      with variable_scope.variable_scope(tower):
        with ops.name_scope("scope2") as name:
          self.assertEqual(name, "testVarScopeNameScope1/tower_1/scope2/")
      # Re-entering by string acts the same.
      with variable_scope.variable_scope("tower"):
        with ops.name_scope("scope2") as name:
          self.assertEqual(name, "testVarScopeNameScope1/tower_2/scope2/")

  with ops.name_scope("testVarScopeNameScope2"):
    with variable_scope.variable_scope("tower"):
      with ops.name_scope("scope2") as name:
        self.assertEqual(name, "testVarScopeNameScope2/tower/scope2/")
    if not context.executing_eagerly():
      with variable_scope.variable_scope(tower):
        with ops.name_scope("scope2") as name:
          self.assertEqual(name, "testVarScopeNameScope2/tower_1/scope2/")

  # Entering the root variable scope adds nothing to the name scope.
  root_scope = variable_scope.get_variable_scope()
  with ops.name_scope("testVarScopeNameScope3"):
    with variable_scope.variable_scope(root_scope):
      with ops.name_scope("scope2") as name:
        self.assertEqual(name, "testVarScopeNameScope3/scope2/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeOriginalNameScope(self):
  # Checks original_name_scope of a variable scope and the names produced
  # when that name scope is re-entered.
  with self.cached_session():
    with ops.name_scope("scope1"):
      with variable_scope.variable_scope("tower") as tower:
        self.assertEqual(tower.original_name_scope, "scope1/tower/")
        with ops.name_scope("scope2") as name:
          self.assertEqual(name, "scope1/tower/scope2/")

    with ops.name_scope("scope2"):
      with variable_scope.variable_scope(tower) as reentered:
        # Re-entering preserves original name scope.
        self.assertEqual(reentered.original_name_scope, "scope1/tower/")
        with ops.name_scope("foo") as name:
          self.assertEqual(name, "scope2/tower/foo/")
      # Test re-entering original name scope.
      with ops.name_scope(tower.original_name_scope):
        with ops.name_scope("bar") as bar_name:
          self.assertEqual(bar_name, "scope1/tower/bar/")

    # A second "bar" under the original name scope gets a "_1" suffix.
    with ops.name_scope("scope2"):
      with variable_scope.variable_scope(tower):
        with ops.name_scope(tower.original_name_scope):
          with ops.name_scope("bar") as bar_name:
            self.assertEqual(bar_name, "scope1/tower/bar_1/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeObjectReuse(self):
  # Checks the `reuse` flag seen when re-entering a captured scope object,
  # first for a scope created with reuse=True, then with reuse=False.
  with self.cached_session():
    with variable_scope.variable_scope("jump", reuse=True) as reusing:
      pass

    with variable_scope.variable_scope(reusing) as entered:
      self.assertTrue(entered.reuse)

    with variable_scope.variable_scope(reusing, reuse=True) as entered:
      self.assertTrue(entered.reuse)

    with variable_scope.variable_scope(reusing, reuse=False) as entered:
      self.assertTrue(entered.reuse)  # Inherited, cannot be undone.

    with variable_scope.variable_scope("jump", reuse=False) as not_reusing:
      pass

    with variable_scope.variable_scope(not_reusing) as entered:
      self.assertFalse(entered.reuse)

    with variable_scope.variable_scope(not_reusing, reuse=True) as entered:
      self.assertTrue(entered.reuse)

    with variable_scope.variable_scope(not_reusing, reuse=False) as entered:
      self.assertFalse(entered.reuse)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeGetOrCreateReuse(self):
  # With AUTO_REUSE the same scope/variable name can be requested repeatedly.
  with self.cached_session():

    def test_value(value):
      const = constant_op.constant(value)
      with variable_scope.variable_scope(
          "testVarScopeGetOrCreateReuse_bar",
          reuse=variable_scope.AUTO_REUSE):
        state_ops.assign(variable_scope.get_variable("var", []), const)
      with variable_scope.variable_scope(
          "testVarScopeGetOrCreateReuse_bar",
          reuse=variable_scope.AUTO_REUSE):
        variable_scope.get_variable("var", [])
      self.assertEqual(value, self.evaluate(const))

    # The first call creates the variable; it is reused hereafter.
    for value in (42., 13., 17.):
      test_value(value)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarOpScope(self):
  # Checks variable names and name scopes for variable_scope(name,
  # default_name, values), with and without an explicit name.
  with self.cached_session():
    with ops.name_scope("testVarOpScope1"):
      with variable_scope.variable_scope("tower", "default", []):
        created = variable_scope.get_variable("w", [])
        self.assertEqual(created.name, "tower/w:0")
        with ops.name_scope("testVarOpScope2") as name:
          self.assertEqual(name, "testVarOpScope1/tower/testVarOpScope2/")
      # Same explicit name again: "tower/w" already exists, so creation fails,
      # while the name scope is uniquified to "tower_1".
      with variable_scope.variable_scope("tower", "default", []):
        with self.assertRaises(ValueError):
          variable_scope.get_variable("w", [])
        with ops.name_scope("testVarOpScope2") as name:
          self.assertEqual(name, "testVarOpScope1/tower_1/testVarOpScope2/")

    with ops.name_scope("testVarOpScope2"):
      with variable_scope.variable_scope(None, "default", []):
        created = variable_scope.get_variable("w", [])
        self.assertEqual(created.name, "default/w:0")
        with ops.name_scope("testVarOpScope2") as name:
          self.assertEqual(name, "testVarOpScope2/default/testVarOpScope2/")
      # With only a default name, the variable scope itself is uniquified.
      with variable_scope.variable_scope(None, "default", []):
        created = variable_scope.get_variable("w", [])
        self.assertEqual(created.name, "default_1/w:0")
        with ops.name_scope("testVarOpScope2") as name:
          self.assertEqual(name, "testVarOpScope2/default_1/testVarOpScope2/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
  # Opens default-named scopes in the listed order and checks the name of the
  # variable created in a nested "layer" scope each time.
  cases = [
      ("defaultScope1", "defaultScope1/layer/w:0"),
      ("defaultScope1", "defaultScope1_1/layer/w:0"),
      ("defaultScope", "defaultScope/layer/w:0"),
      ("defaultScope1", "defaultScope1_2/layer/w:0"),
  ]
  with self.cached_session():
    for default_name, expected_name in cases:
      with variable_scope.variable_scope(None, default_name):
        with variable_scope.variable_scope(None, "layer"):
          created = variable_scope.get_variable("w", [])
          self.assertEqual(created.name, expected_name)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarOpScopeUniqueNamesWithJump(self):
  # Default-named child scopes keep counting up even after the parent scope
  # object is re-entered in between.
  with self.cached_session():
    with variable_scope.variable_scope("default") as parent:
      with variable_scope.variable_scope(None, "layer"):
        created = variable_scope.get_variable("w", [])
        self.assertEqual(created.name, "default/layer/w:0")
      with variable_scope.variable_scope(None, "layer"):
        created = variable_scope.get_variable("w", [])
        self.assertEqual(created.name, "default/layer_1/w:0")
      with variable_scope.variable_scope(parent):
        pass
      # No matter the jump in the middle, unique numbering continues.
      with variable_scope.variable_scope(None, "layer"):
        created = variable_scope.get_variable("w", [])
        self.assertEqual(created.name, "default/layer_2/w:0")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarOpScopeReuse(self):
  # Re-entering "outer" with reuse=True returns the same variable names while
  # the name scope becomes "outer_1".
  with self.cached_session():
    with variable_scope.variable_scope("outer") as outer:
      with variable_scope.variable_scope("tower", "default", []):
        created = variable_scope.get_variable("w", [])
        self.assertEqual(created.name, "outer/tower/w:0")
        with ops.name_scope("scope2") as name:
          self.assertEqual(name, "outer/tower/scope2/")
      with variable_scope.variable_scope(None, "default", []):
        created = variable_scope.get_variable("w", [])
        self.assertEqual(created.name, "outer/default/w:0")
        with ops.name_scope("scope2") as name:
          self.assertEqual(name, "outer/default/scope2/")

    with variable_scope.variable_scope(outer, reuse=True) as outer:
      with variable_scope.variable_scope("tower", "default", []):
        reused = variable_scope.get_variable("w", [])
        self.assertEqual(reused.name, "outer/tower/w:0")
        with ops.name_scope("scope2") as name:
          self.assertEqual(name, "outer_1/tower/scope2/")
      with variable_scope.variable_scope(None, "default", []):
        reused = variable_scope.get_variable("w", [])
        self.assertEqual(reused.name, "outer/default/w:0")
        with ops.name_scope("scope2") as name:
          self.assertEqual(name, "outer_1/default/scope2/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeGetVar(self):
  """Exercises get_variable() creation, reuse and mismatch errors.

  Covers, all under the "root" scope:
    * creating "v" in "towerA" and "towerB";
    * getting the same object back when "towerA" is re-entered with
      reuse=True, by scope object or by name, including from inside the
      unrelated nested scopes "foo/bar";
    * ValueError when "towerA" is re-entered by name without reuse, when a
      missing variable is requested under reuse, and when the requested
      shape or dtype differs from the existing variable.
  """
  with self.cached_session():
    with variable_scope.variable_scope("root"):
      with variable_scope.variable_scope("towerA") as tower_a:
        va = variable_scope.get_variable("v", [1])
        self.assertEqual(va.name, "root/towerA/v:0")

      with variable_scope.variable_scope(tower_a, reuse=True):
        va2 = variable_scope.get_variable("v", [1])
        self.assertIs(va2, va)

      with variable_scope.variable_scope("towerB"):
        vb = variable_scope.get_variable("v", [1])
        self.assertEqual(vb.name, "root/towerB/v:0")

      # Re-opening "towerA" by name without reuse must not hand back "v".
      with self.assertRaises(ValueError):
        with variable_scope.variable_scope("towerA"):
          variable_scope.get_variable("v", [1])

      with variable_scope.variable_scope("towerA", reuse=True):
        va2 = variable_scope.get_variable("v", [1])
        self.assertIs(va2, va)

      with variable_scope.variable_scope("foo"):
        with variable_scope.variable_scope("bar"):
          v = variable_scope.get_variable("v", [1])
          self.assertEqual(v.name, "root/foo/bar/v:0")
          # Passing the captured scope object jumps back to "root/towerA".
          with variable_scope.variable_scope(tower_a, reuse=True):
            va3 = variable_scope.get_variable("v", [1])
            self.assertIs(va, va3)

      # "baz/v" was never created, so requesting it under reuse fails.
      with self.assertRaises(ValueError):
        with variable_scope.variable_scope(tower_a, reuse=True):
          with variable_scope.variable_scope("baz"):
            variable_scope.get_variable("v", [1])

      with self.assertRaises(ValueError) as exc:
        with variable_scope.variable_scope(tower_a, reuse=True):
          variable_scope.get_variable("v", [2])  # Different shape.
      # assertIn prints the actual message on failure, unlike comparing a
      # boolean against True.
      self.assertIn("shape", str(exc.exception))

      with self.assertRaises(ValueError) as exc:
        with variable_scope.variable_scope(tower_a, reuse=True):
          variable_scope.get_variable("v", [1], dtype=dtypes.int32)
      self.assertIn("dtype", str(exc.exception))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeOuterScope(self):
  """Checks names when a closed "outer" scope is re-entered by object.

  Variable names are expected under "outer/" on both re-entries, while the
  name scopes are expected under "outer_1/" and then "outer_2/".
  """

  def check(expected_var_name, expected_name_scope):
    # Gets "w" in the current variable scope, then opens name scope "scope2".
    w = variable_scope.get_variable("w", [])
    self.assertEqual(w.name, expected_var_name)
    with ops.name_scope("scope2") as sc2:
      self.assertEqual(sc2, expected_name_scope)

  with self.cached_session():
    with variable_scope.variable_scope("outer") as outer:
      pass
    with variable_scope.variable_scope(outer):
      check("outer/w:0", "outer_1/scope2/")
      with variable_scope.variable_scope("default"):
        check("outer/default/w:0", "outer_1/default/scope2/")

    with variable_scope.variable_scope(outer, reuse=True):
      check("outer/w:0", "outer_2/scope2/")
      with variable_scope.variable_scope("default", reuse=True):
        check("outer/default/w:0", "outer_2/default/scope2/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarScopeNestedOuterScope(self):
  """Checks names when "outer" is re-entered by object from inside itself."""

  def check(expected_var_name, expected_name_scope):
    # Gets "w" in the current variable scope, then opens name scope "scope2".
    w = variable_scope.get_variable("w", [])
    self.assertEqual(w.name, expected_var_name)
    with ops.name_scope("scope2") as sc2:
      self.assertEqual(sc2, expected_name_scope)

  with self.cached_session():
    with variable_scope.variable_scope("outer") as outer:
      with variable_scope.variable_scope(outer):
        check("outer/w:0", "outer/outer/scope2/")
      with variable_scope.variable_scope("default"):
        check("outer/default/w:0", "outer/default/scope2/")

      with variable_scope.variable_scope(outer, reuse=True):
        check("outer/w:0", "outer/outer_1/scope2/")
      with variable_scope.variable_scope("default", reuse=True):
        check("outer/default/w:0", "outer/default_1/scope2/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarOpScopeReuseParam(self):
  """Checks reuse requested per inner scope and via reuse_variables().

  The second pass re-enters "outer" without reuse; "tower" is then opened
  with reuse=True, and outer.reuse_variables() is called before the
  default-named scope is opened.
  """

  def check(expected_var_name, expected_name_scope):
    # Gets "w" in the current variable scope, then opens name scope "scope2".
    w = variable_scope.get_variable("w", [])
    self.assertEqual(w.name, expected_var_name)
    with ops.name_scope("scope2") as sc2:
      self.assertEqual(sc2, expected_name_scope)

  with self.cached_session():
    with variable_scope.variable_scope("outer") as outer:
      with variable_scope.variable_scope("tower", "default", []):
        check("outer/tower/w:0", "outer/tower/scope2/")
      with variable_scope.variable_scope(None, "default", []):
        check("outer/default/w:0", "outer/default/scope2/")

    with variable_scope.variable_scope(outer) as outer:
      with variable_scope.variable_scope("tower", "default", reuse=True):
        check("outer/tower/w:0", "outer_1/tower/scope2/")
      outer.reuse_variables()
      with variable_scope.variable_scope(None, "default", []):
        check("outer/default/w:0", "outer_1/default/scope2/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarOpScopeReuseError(self):
  """Expects ValueError when a name-less scope is used with reuse=True.

  The error may come from opening the scope or from the get_variable() call
  inside it; the test only requires that one of them raises.
  """
  with self.cached_session(), self.assertRaises(ValueError):
    with variable_scope.variable_scope(None, "default", reuse=True):
      w = variable_scope.get_variable("w", [])
      self.assertEqual(w.name, "outer/tower/w:0")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarOpScopeOuterScope(self):
  """Checks names when "outer" is re-entered with a default name supplied.

  Same shape as the plain outer-scope test, but every variable_scope() call
  also passes "default" as default_name, and the inner scopes are opened by
  default name only.
  """

  def check(expected_var_name, expected_name_scope):
    # Gets "w" in the current variable scope, then opens name scope "scope2".
    w = variable_scope.get_variable("w", [])
    self.assertEqual(w.name, expected_var_name)
    with ops.name_scope("scope2") as sc2:
      self.assertEqual(sc2, expected_name_scope)

  with self.cached_session():
    with variable_scope.variable_scope("outer") as outer:
      pass
    with variable_scope.variable_scope(outer, "default", []):
      check("outer/w:0", "outer_1/scope2/")
      with variable_scope.variable_scope(None, "default", []):
        check("outer/default/w:0", "outer_1/default/scope2/")

    with variable_scope.variable_scope(outer, "default", reuse=True):
      check("outer/w:0", "outer_2/scope2/")
      outer.reuse_variables()
      with variable_scope.variable_scope(None, "default", []):
        check("outer/default/w:0", "outer_2/default/scope2/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testVarOpScopeNestedOuterScope(self):
  """Checks names when "outer" is re-entered from inside and from outside."""

  def check(expected_var_name, expected_name_scope):
    # Gets "w" in the current variable scope, then opens name scope "scope2".
    w = variable_scope.get_variable("w", [])
    self.assertEqual(w.name, expected_var_name)
    with ops.name_scope("scope2") as sc2:
      self.assertEqual(sc2, expected_name_scope)

  with self.cached_session():
    with variable_scope.variable_scope("outer") as outer:
      with variable_scope.variable_scope(outer, "default", []):
        check("outer/w:0", "outer/outer/scope2/")
      with variable_scope.variable_scope(None, "default", []):
        check("outer/default/w:0", "outer/default/scope2/")

    with variable_scope.variable_scope(outer, "default", reuse=True):
      check("outer/w:0", "outer_1/scope2/")
      with variable_scope.variable_scope(None, "default", []):
        check("outer/default/w:0", "outer_1/default/scope2/")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testBasicWhenAuxiliaryNameScopeIsFalse(self):
  """Checks that auxiliary_name_scope=False leaves the name scope untouched.

  Variables are expected to pick up the variable-scope prefix while constants
  created in the same block keep the enclosing name scope only.
  """

  def var_name(name):
    # Name of a scalar variable obtained in the current variable scope.
    return variable_scope.get_variable(name, []).name

  def const_name(name):
    # Name of a constant op created in the current name scope.
    return constant_op.constant([], name=name).name

  with self.cached_session():
    with variable_scope.variable_scope(
        "scope", auxiliary_name_scope=False) as scope:
      self.assertEqual(scope.original_name_scope, "")
      self.assertEqual(var_name("w"), "scope/w:0")
      self.assertEqual(const_name("c"), "c:0")
    with variable_scope.variable_scope(scope, auxiliary_name_scope=False):
      self.assertEqual(scope.original_name_scope, "")
      self.assertEqual(var_name("w1"), "scope/w1:0")
      self.assertEqual(const_name("c1"), "c1:0")
    # Recheck: new name scope is NOT created before
    with ops.name_scope("scope"):
      self.assertEqual(const_name("c"), "scope/c:0")

    with variable_scope.variable_scope("outer"):
      with variable_scope.variable_scope(
          "inner", auxiliary_name_scope=False) as inner:
        self.assertEqual(inner.original_name_scope, "outer/")
        self.assertEqual(var_name("w"), "outer/inner/w:0")
        self.assertEqual(const_name("c"), "outer/c:0")
      with variable_scope.variable_scope(
          inner, auxiliary_name_scope=False) as inner1:
        self.assertEqual(inner1.original_name_scope, "outer/")
        self.assertEqual(var_name("w1"), "outer/inner/w1:0")
        self.assertEqual(const_name("c1"), "outer/c1:0")
      # Recheck: new name scope is NOT created before
      with ops.name_scope("inner"):
        self.assertEqual(const_name("c"), "outer/inner/c:0")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
  """Same check as the basic case, for scopes created from default_name."""

  def var_name(name):
    # Name of a scalar variable obtained in the current variable scope.
    return variable_scope.get_variable(name, []).name

  def const_name(name):
    # Name of a constant op created in the current name scope.
    return constant_op.constant([], name=name).name

  with self.cached_session():
    with variable_scope.variable_scope(
        None, default_name="default", auxiliary_name_scope=False) as scope:
      self.assertEqual(scope.original_name_scope, "")
      self.assertEqual(var_name("w"), "default/w:0")
      self.assertEqual(const_name("c"), "c:0")
    # Recheck: new name scope is NOT created before
    with ops.name_scope("default"):
      self.assertEqual(const_name("c"), "default/c:0")

    with variable_scope.variable_scope("outer"):
      with variable_scope.variable_scope(
          None, default_name="default",
          auxiliary_name_scope=False) as inner:
        self.assertEqual(inner.original_name_scope, "outer/")
        self.assertEqual(var_name("w"), "outer/default/w:0")
        self.assertEqual(const_name("c"), "outer/c:0")
      # Recheck: new name scope is NOT created before
      with ops.name_scope("default"):
        self.assertEqual(const_name("c"), "outer/default/c:0")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
  """Re-enters the root variable scope without opening a name scope.

  Done once at top level and once from inside "outer"; in the latter case
  the variable is expected un-prefixed while the constant keeps "outer/".
  """

  def var_name(name):
    # Name of a scalar variable obtained in the current variable scope.
    return variable_scope.get_variable(name, []).name

  def const_name(name):
    # Name of a constant op created in the current name scope.
    return constant_op.constant([], name=name).name

  with self.cached_session():
    root_scope = variable_scope.get_variable_scope()
    with variable_scope.variable_scope(
        root_scope, auxiliary_name_scope=False) as scope:
      self.assertEqual(scope.original_name_scope, "")
      self.assertEqual(var_name("w"), "w:0")
      self.assertEqual(const_name("c"), "c:0")

    with variable_scope.variable_scope("outer"):
      with variable_scope.variable_scope(
          root_scope, auxiliary_name_scope=False) as inner:
        self.assertEqual(inner.original_name_scope, "")
        self.assertEqual(var_name("w1"), "w1:0")
        self.assertEqual(const_name("c1"), "outer/c1:0")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testAuxiliaryNameScopeIsInvalid(self):
  """Expects TypeError for a non-boolean auxiliary_name_scope argument.

  Tried with a default-named scope, a scope given by name, and a scope given
  as a previously captured scope object.
  """

  def assert_rejected(name_or_scope, **scope_kwargs):
    # Opening the scope with auxiliary_name_scope="invalid" must raise.
    with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
      with variable_scope.variable_scope(
          name_or_scope, auxiliary_name_scope="invalid", **scope_kwargs):
        pass

  with self.cached_session():
    assert_rejected(None, default_name="scope")
    assert_rejected("scope")

    with variable_scope.variable_scope("scope") as scope:
      pass
    assert_rejected(scope)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testReuseScopeWithoutNameScopeCollision(self):
  """Re-enters "outer/inner" and restores its original name scope by hand.

  The scope is re-entered with auxiliary_name_scope=False and its
  original_name_scope is passed to ops.name_scope(), first from top level
  and then from inside the unrelated scope "another".
  """
  # Github issue: #13429

  def var_name(name):
    # Name of a scalar variable obtained in the current variable scope.
    return variable_scope.get_variable(name, []).name

  def const_name(name):
    # Name of a constant op created in the current name scope.
    return constant_op.constant([], name=name).name

  with self.cached_session():
    with variable_scope.variable_scope("outer"):
      with variable_scope.variable_scope("inner") as inner:
        pass

    with variable_scope.variable_scope(
        inner, auxiliary_name_scope=False) as scope:
      with ops.name_scope(scope.original_name_scope):
        self.assertEqual(var_name("w"), "outer/inner/w:0")
        self.assertEqual(const_name("c"), "outer/inner/c:0")
      with ops.name_scope("inner"):
        self.assertEqual(const_name("c"), "inner/c:0")

    with variable_scope.variable_scope("another"):
      with variable_scope.variable_scope(
          inner, auxiliary_name_scope=False) as scope1:
        with ops.name_scope(scope1.original_name_scope):
          self.assertEqual(var_name("w1"), "outer/inner/w1:0")
          self.assertEqual(const_name("c1"), "outer/inner/c1:0")
        with ops.name_scope("inner"):
          self.assertEqual(const_name("c"), "another/inner/c:0")
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# obtaining different results in the eager case compared to the graph one
|
|
# (different assertions failing after wrapping, in both execution modes)
|
|
@test_util.run_in_graph_and_eager_modes
def testGetLocalVar(self):
  """Checks naming, collections and reuse for get_local_variable()."""
  # Check that local variable respects naming.
  with variable_scope.variable_scope("outer") as outer:
    with variable_scope.variable_scope(outer, "default", []):
      local_var = variable_scope.get_local_variable(
          "w", [], collections=["foo"])
      self.assertEqual(local_var.name, "outer/w:0")

  # The collection and reuse checks below are only run in graph mode.
  if context.executing_eagerly():
    return

  # Since variable is local, it should be in the local variable collection
  # but not the trainable collection.
  local_collection = ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
  self.assertIn(local_var, local_collection)
  self.assertIn(local_var, ops.get_collection("foo"))
  trainable_collection = ops.get_collection(
      ops.GraphKeys.TRAINABLE_VARIABLES)
  self.assertNotIn(local_var, trainable_collection)
  # Check that local variable respects `reuse`.
  with variable_scope.variable_scope(outer, "default", reuse=True):
    reused = variable_scope.get_local_variable("w", [])
    self.assertEqual(reused.name, "outer/w:0")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testSignatureGetVarVsGetLocalVar(self):
  """get_{local,}variable() must take the same list of args."""
  # Element 0 of the argspec is the list of positional argument names.
  global_spec = tf_inspect.getargspec(variable_scope.get_variable)
  local_spec = tf_inspect.getargspec(variable_scope.get_local_variable)
  self.assertEqual(global_spec[0], local_spec[0])
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testGetVarWithDevice(self):
  """Checks the ops seen by a device function while variables are created.

  The device function records (name, dtype attr) for every op whose type is
  one of the variable op types, and always returns "/device:GPU:0".
  """
  graph = ops.Graph()
  seen_variable_ops = []

  def record_variable_ops(op):
    if op.type in ("Variable", "VariableV2", "VarHandleOp"):
      seen_variable_ops.append((op.name, op.get_attr("dtype")))
    return "/device:GPU:0"

  with graph.as_default(), ops.device(record_variable_ops):
    variable_scope.get_variable("x", (100, 200))
    variable_scope.get_variable(
        "y", dtype=dtypes.int64, initializer=numpy.arange(73))
  self.assertEqual(seen_variable_ops[0], ("x", dtypes.float32))
  self.assertEqual(seen_variable_ops[1], ("y", dtypes.int64))
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# obtaining different results in the eager case compared to the graph one
|
|
@test_util.run_deprecated_v1
def testGetCollection(self):
  """Checks scope.get_collection() for two scopes with a shared prefix.

  The same pair of variables (one trainable, one not) is created at top
  level, in "testGetCollection_foo_" and in "testGetCollection_foo". Each
  named scope is expected to report only its own variables; the current
  top-level scope is expected to report all of them.
  """
  trainable_key = ops.GraphKeys.TRAINABLE_VARIABLES
  global_key = ops.GraphKeys.GLOBAL_VARIABLES

  def create_pair():
    # One trainable and one non-trainable scalar in the current scope.
    variable_scope.get_variable("testGetCollection_a", [])
    variable_scope.get_variable("testGetCollection_b", [], trainable=False)

  def names_in(scope, key):
    return [v.name for v in scope.get_collection(key)]

  with self.cached_session():
    create_pair()
    with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
      create_pair()
      self.assertEqual(
          names_in(scope1, trainable_key),
          ["testGetCollection_foo_/testGetCollection_a:0"])
      self.assertEqual(
          names_in(scope1, global_key),
          [
              "testGetCollection_foo_/testGetCollection_a:0",
              "testGetCollection_foo_/testGetCollection_b:0",
          ])
    with variable_scope.variable_scope("testGetCollection_foo") as scope2:
      create_pair()
      self.assertEqual(
          names_in(scope2, trainable_key),
          ["testGetCollection_foo/testGetCollection_a:0"])
      self.assertEqual(
          names_in(scope2, global_key),
          [
              "testGetCollection_foo/testGetCollection_a:0",
              "testGetCollection_foo/testGetCollection_b:0",
          ])
    scope = variable_scope.get_variable_scope()
    self.assertEqual(
        names_in(scope, global_key),
        [
            "testGetCollection_a:0",
            "testGetCollection_b:0",
            "testGetCollection_foo_/testGetCollection_a:0",
            "testGetCollection_foo_/testGetCollection_b:0",
            "testGetCollection_foo/testGetCollection_a:0",
            "testGetCollection_foo/testGetCollection_b:0",
        ])
    self.assertEqual(
        names_in(scope, trainable_key),
        [
            "testGetCollection_a:0",
            "testGetCollection_foo_/testGetCollection_a:0",
            "testGetCollection_foo/testGetCollection_a:0",
        ])
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# obtaining different results in the eager case compared to the graph one
|
|
@test_util.run_deprecated_v1
def testGetTrainableVariablesWithGetVariable(self):
  """Checks scope.trainable_variables() for get_variable()-made variables.

  Variables are added one at a time with different `trainable` and
  `synchronization` arguments, and the scope's trainable list is checked
  after each group.
  """
  on_read = variable_scope.VariableSynchronization.ON_READ
  on_write = variable_scope.VariableSynchronization.ON_WRITE
  name_b = "testGetTrainableVariables_foo/testGetTrainableVariables_b:0"
  name_e = "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
  name_f = "testGetTrainableVariables_foo/testGetTrainableVariables_f:0"

  with self.cached_session():
    variable_scope.get_variable("testGetTrainableVariables_a", [])
    with variable_scope.variable_scope(
        "testGetTrainableVariables_foo") as scope:

      def trainable_names():
        return [v.name for v in scope.trainable_variables()]

      variable_scope.get_variable("testGetTrainableVariables_b", [])
      variable_scope.get_variable(
          "testGetTrainableVariables_c", [], trainable=False)

      # sync `ON_READ` sets trainable=False
      variable_scope.get_variable(
          "testGetTrainableVariables_d", [], synchronization=on_read)
      self.assertEqual(trainable_names(), [name_b])

      # An explicit trainable=True is expected to win over `ON_READ`.
      variable_scope.get_variable(
          "testGetTrainableVariables_e", [],
          synchronization=on_read,
          trainable=True)
      self.assertEqual(trainable_names(), [name_b, name_e])

      # All other sync values set trainable=True
      variable_scope.get_variable(
          "testGetTrainableVariables_f", [], synchronization=on_write)
      self.assertEqual(trainable_names(), [name_b, name_e, name_f])
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# obtaining different results in the eager case compared to the graph one
|
|
@test_util.run_deprecated_v1
def testGetTrainableVariablesWithVariable(self):
  """Checks scope.trainable_variables() for variable()-made variables.

  Mirrors the get_variable() version of this test, but creates every
  variable through variable_scope.variable(1.0, name=...).
  """
  on_read = variable_scope.VariableSynchronization.ON_READ
  on_write = variable_scope.VariableSynchronization.ON_WRITE
  name_b = "testGetTrainableVariables_foo/testGetTrainableVariables_b:0"
  name_e = "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
  name_f = "testGetTrainableVariables_foo/testGetTrainableVariables_f:0"

  with self.cached_session():
    variable_scope.variable(1.0, name="testGetTrainableVariables_a")
    with variable_scope.variable_scope(
        "testGetTrainableVariables_foo") as scope:

      def trainable_names():
        return [v.name for v in scope.trainable_variables()]

      variable_scope.variable(1.0, name="testGetTrainableVariables_b")
      variable_scope.variable(
          1.0, name="testGetTrainableVariables_c", trainable=False)

      # sync `ON_READ` sets trainable=False
      variable_scope.variable(
          1.0, name="testGetTrainableVariables_d", synchronization=on_read)
      self.assertEqual(trainable_names(), [name_b])

      # An explicit trainable=True is expected to win over `ON_READ`.
      variable_scope.variable(
          1.0,
          name="testGetTrainableVariables_e",
          synchronization=on_read,
          trainable=True)
      self.assertEqual(trainable_names(), [name_b, name_e])

      # All other sync values set trainable=True
      variable_scope.variable(
          1.0, name="testGetTrainableVariables_f", synchronization=on_write)
      self.assertEqual(trainable_names(), [name_b, name_e, name_f])
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# obtaining different results in the eager case compared to the graph one
|
|
@test_util.run_deprecated_v1
def testGetGlobalVariables(self):
  """Checks that scope.global_variables() lists only the scope's own."""
  expected = ["testGetGlobalVariables_foo/testGetGlobalVariables_b:0"]
  with self.cached_session():
    variable_scope.get_variable("testGetGlobalVariables_a", [])
    with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
      variable_scope.get_variable("testGetGlobalVariables_b", [])
      found = [v.name for v in scope.global_variables()]
      self.assertEqual(found, expected)
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# obtaining different results in the eager case compared to the graph one
|
|
@test_util.run_deprecated_v1
def testGetLocalVariables(self):
  """Checks that scope.local_variables() lists only the scope's locals.

  "a" (top level) and "b" (in "foo") are placed in the LOCAL_VARIABLES
  collection; "c" (in "foo") uses the default collections.
  """
  local_only = [ops.GraphKeys.LOCAL_VARIABLES]
  with self.cached_session():
    variable_scope.get_variable("a", [], collections=local_only)
    with variable_scope.variable_scope("foo") as scope:
      variable_scope.get_variable("b", [], collections=local_only)
      variable_scope.get_variable("c", [])
      found = [v.name for v in scope.local_variables()]
      self.assertEqual(found, ["foo/b:0"])
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testGetVariableWithRefDtype(self):
  """Passes one variable's dtype as the dtype argument for another."""
  source_var = variable_scope.get_variable(
      "v", shape=[3, 4], dtype=dtypes.float32)
  # Ensure it is possible to do get_variable with a _ref dtype passed in.
  variable_scope.get_variable("w", shape=[5, 6], dtype=source_var.dtype)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testGetVariableWithInitializerWhichTakesNoArgs(self):
  """Accepts a zero-argument callable as initializer with no shape given."""
  created = variable_scope.get_variable("foo", initializer=lambda: [2])
  self.assertEqual(created.name, "foo:0")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testGetVariableWithInitializerWhichTakesOptionalArgs(self):
  """Accepts an initializer callable whose only argument has a default."""
  created = variable_scope.get_variable(
      "foo", initializer=lambda x=True: [2])
  self.assertEqual(created.name, "foo:0")
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testGetVariableWithInitializerWhichTakesUnprovidedArgsAndNoShape(self):
  """Expects ValueError for a one-required-arg initializer with no shape."""
  # Pattern matched against the error message; kept verbatim.
  expected_message = (
      "The initializer passed is not valid. It should be a callable with no "
      "arguments and the shape should not be provided or an instance of "
      "`tf.keras.initializers.*' and `shape` should be fully defined.")
  with self.assertRaisesRegexp(ValueError, expected_message):
    variable_scope.get_variable("foo", initializer=lambda x: [2])
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testTwoGraphs(self):
  """Expects "_" to be rejected as a scope name under two nested graphs."""
  with self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name"):
    outer_graph = ops.Graph()
    inner_graph = ops.Graph()
    with outer_graph.as_default(), inner_graph.as_default():
      with variable_scope.variable_scope("_"):
        pass
|
|
|
|
|
|
def axis0_into1_partitioner(shape=None, **unused_kwargs):
  """Returns a list of ones with one entry per dimension of `shape`.

  Args:
    shape: Sized object describing the variable shape; only its length is
      used.
    **unused_kwargs: Ignored.

  Returns:
    A list `[1, 1, ..., 1]` of length `len(shape)`.
  """
  rank = len(shape)
  return [1 for _ in range(rank)]
|
|
|
|
|
|
def axis0_into2_partitioner(shape=None, **unused_kwargs):
  """Returns a partition list that splits only axis 0, into two parts.

  Args:
    shape: Sized object describing the variable shape; only its length is
      used. Must have at least one dimension.
    **unused_kwargs: Ignored.

  Returns:
    A list of length `len(shape)` equal to `[2, 1, ..., 1]`.
  """
  rank = len(shape)
  partitions = [1 for _ in range(rank)]
  partitions[0] = 2
  return partitions
|
|
|
|
|
|
def axis0_into3_partitioner(shape=None, **unused_kwargs):
  """Returns a partition list that splits only axis 0, into three parts.

  Args:
    shape: Sized object describing the variable shape; only its length is
      used. Must have at least one dimension.
    **unused_kwargs: Ignored.

  Returns:
    A list of length `len(shape)` equal to `[3, 1, ..., 1]`.
  """
  rank = len(shape)
  partitions = [1 for _ in range(rank)]
  partitions[0] = 3
  return partitions
|
|
|
|
|
|
class VariableScopeWithPartitioningTest(test.TestCase):
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# obtaining different results in the eager case compared to the graph one
|
|
@test_util.run_deprecated_v1
|
|
def testResultNameMatchesRequested(self):
|
|
with variable_scope.variable_scope(
|
|
"scope0", partitioner=axis0_into2_partitioner):
|
|
v = variable_scope.get_variable("name0", shape=(3, 1, 1))
|
|
self.assertEqual(v.name, "scope0/name0")
|
|
v_concat = v.as_tensor()
|
|
self.assertEqual(v_concat.name, "scope0/name0:0")
|
|
variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
|
self.assertIn("scope0/name0/part_0:0", [x.name for x in variables])
|
|
self.assertIn("scope0/name0/part_1:0", [x.name for x in variables])
|
|
self.assertNotIn("scope0/name0/part_2:0", [x.name for x in variables])
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
@run_inside_wrap_function_in_eager_mode
|
|
def testBreaksIfPartitioningChanges(self):
|
|
with variable_scope.variable_scope(
|
|
"scope0", partitioner=axis0_into2_partitioner):
|
|
variable_scope.get_variable("name0", shape=(3, 1, 1))
|
|
|
|
with variable_scope.variable_scope(
|
|
"scope0", partitioner=axis0_into3_partitioner, reuse=True):
|
|
with self.assertRaisesRegexp(
|
|
ValueError,
|
|
"Trying to reuse partitioned variable .* but specified partitions "
|
|
".* and found partitions .*"):
|
|
variable_scope.get_variable("name0", shape=(3, 1, 1))
|
|
|
|
with variable_scope.variable_scope(
|
|
"scope0", partitioner=axis0_into1_partitioner, reuse=True):
|
|
with self.assertRaisesRegexp(
|
|
ValueError,
|
|
"Trying to reuse partitioned variable .* but specified partitions "
|
|
".* and found partitions .*"):
|
|
variable_scope.get_variable("name0", shape=(3, 1, 1))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
@run_inside_wrap_function_in_eager_mode
|
|
def testReturnsExistingConcatenatedValueIfReuse(self):
|
|
with variable_scope.variable_scope(
|
|
"scope0", partitioner=axis0_into2_partitioner):
|
|
v_concat = variable_scope.get_variable("name0", shape=(3, 1, 1))
|
|
variable_scope.get_variable_scope().reuse_variables()
|
|
v_concat_2 = variable_scope.get_variable("name0", shape=(3, 1, 1))
|
|
self.assertEqual(v_concat, v_concat_2)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
@run_inside_wrap_function_in_eager_mode
|
|
def testAllowsReuseWithoutPartitioner(self):
|
|
with variable_scope.variable_scope(
|
|
"scope0", partitioner=axis0_into2_partitioner):
|
|
v = variable_scope.get_variable("name0", shape=(3, 1, 1))
|
|
with variable_scope.variable_scope("scope0", reuse=True):
|
|
v_reused = variable_scope.get_variable("name0")
|
|
self.assertIs(v, v_reused)
|
|
|
|
def testNoReuseInEagerByDefault(self):
|
|
with context.eager_mode():
|
|
with variable_scope.variable_scope(
|
|
"scope0", partitioner=axis0_into2_partitioner):
|
|
v1 = variable_scope.get_variable("name0", shape=(3, 1, 1))
|
|
v2 = variable_scope.get_variable("name0", shape=(3, 1, 1))
|
|
self.assertIsNot(v1, v2)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
@run_inside_wrap_function_in_eager_mode
|
|
def testPropagatePartitionerOnReopening(self):
|
|
with variable_scope.variable_scope(
|
|
"scope0", partitioner=axis0_into2_partitioner) as vs:
|
|
self.assertEqual(axis0_into2_partitioner, vs.partitioner)
|
|
with variable_scope.variable_scope(vs) as vs1:
|
|
self.assertEqual(axis0_into2_partitioner, vs1.partitioner)
|
|
|
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
|
# obtaining different results in the eager case compared to the graph one
|
|
@test_util.run_deprecated_v1
|
|
def testScalarIgnoresPartitioner(self):
|
|
with variable_scope.variable_scope(
|
|
"scope0", partitioner=axis0_into2_partitioner):
|
|
v = variable_scope.get_variable("name0", shape=())
|
|
self.assertEqual(v.name, "scope0/name0:0")
|
|
variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
|
self.assertIn("scope0/name0:0", [x.name for x in variables])
|
|
|
|
def _testPartitionConcatenatesAlongCorrectAxis(self, use_resource):
|
|
def _part_axis_0(**unused_kwargs):
|
|
return (2, 1, 1)
|
|
|
|
def _part_axis_1(**unused_kwargs):
|
|
return (1, 2, 1)
|
|
|
|
with variable_scope.variable_scope("root", use_resource=use_resource):
|
|
v0 = variable_scope.get_variable(
|
|
"n0", shape=(2, 2, 2), partitioner=_part_axis_0)
|
|
v1 = variable_scope.get_variable(
|
|
"n1", shape=(2, 2, 2), partitioner=_part_axis_1)
|
|
|
|
self.assertEqual(v0.get_shape(), (2, 2, 2))
|
|
self.assertEqual(v1.get_shape(), (2, 2, 2))
|
|
|
|
n0_0 = list(v0)[0]
|
|
n0_1 = list(v0)[1]
|
|
self.assertEqual(n0_0.get_shape(), (1, 2, 2))
|
|
self.assertEqual(n0_1.get_shape(), (1, 2, 2))
|
|
|
|
n1_0 = list(v1)[0]
|
|
n1_1 = list(v1)[1]
|
|
self.assertEqual(n1_0.get_shape(), (2, 1, 2))
|
|
self.assertEqual(n1_1.get_shape(), (2, 1, 2))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testPartitionConcatenatesAlongCorrectAxis(self):
  """Runs the shared partition-axis check with `use_resource=False`."""
  use_resource = False
  self._testPartitionConcatenatesAlongCorrectAxis(use_resource=use_resource)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testPartitionConcatenatesAlongCorrectAxisResource(self):
  """Runs the shared partition-axis check with `use_resource=True`."""
  use_resource = True
  self._testPartitionConcatenatesAlongCorrectAxis(use_resource=use_resource)
|
|
|
|
def testPartitionConcatenatesAlongCorrectAxisResourceInEager(self):
  """Runs the shared partition-axis check inside `context.eager_mode()`."""
  eager_context = context.eager_mode()
  with eager_context:
    self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True)
|
|
|
|
|
|
class VariableScopeWithCustomGetterTest(test.TestCase):
  """Tests for `custom_getter` handling and variable creator scopes."""

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testNonCallableGetterFails(self):
    """A non-callable `custom_getter` raises ValueError on both entry paths."""
    # Passed through the enclosing variable scope.
    with self.assertRaisesRegexp(ValueError,
                                 r"custom_getter .* not callable:"):
      with variable_scope.variable_scope("scope0", custom_getter=3):
        variable_scope.get_variable("name0")
    # Passed directly to get_variable.
    with self.assertRaisesRegexp(ValueError,
                                 r"custom_getter .* not callable:"):
      variable_scope.get_variable("name0", custom_getter=3)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testNoSideEffectsWithIdentityCustomGetter(self):
    """A pass-through getter does not change which variables are returned."""
    # Single-element list so the nested function can mutate the counter.
    called = [0]

    def custom_getter(getter, *args, **kwargs):
      called[0] += 1
      return getter(*args, **kwargs)

    with variable_scope.variable_scope(
        "scope", custom_getter=custom_getter) as scope:
      v = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope(scope, reuse=True):
      v2 = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope("new_scope") as new_scope:
      v3 = variable_scope.get_variable("v3", [1])
    with variable_scope.variable_scope(
        new_scope, reuse=True, custom_getter=custom_getter):
      v4 = variable_scope.get_variable("v3", [1])

    self.assertIs(v, v2)
    self.assertIs(v3, v4)
    self.assertEqual(3, called[0])  # skipped one in the first new_scope

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testSynchronizationAndAggregationWithCustomGetter(self):
    """The getter receives the synchronization and aggregation kwargs."""
    called = [0]
    # These two names are rebound below; the closure reads the current value
    # at call time, so it always checks against the latest expectation.
    synchronization = variable_scope.VariableSynchronization.AUTO
    aggregation = variable_scope.VariableAggregation.NONE

    def custom_getter(getter, *args, **kwargs):
      called[0] += 1

      # Verify synchronization and aggregation kwargs are as expected.
      self.assertEqual(kwargs["synchronization"], synchronization)
      self.assertEqual(kwargs["aggregation"], aggregation)
      return getter(*args, **kwargs)

    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      variable_scope.get_variable("v", [1])
    self.assertEqual(1, called[0])

    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      synchronization = variable_scope.VariableSynchronization.ON_READ
      aggregation = variable_scope.VariableAggregation.MEAN
      variable_scope.get_variable(
          "v1", [1], synchronization=synchronization, aggregation=aggregation)

    self.assertEqual(2, called[0])

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testCustomGetterWithReuse(self):
    """The getter can branch on the `reuse` kwarg it is handed."""
    # Custom getter can choose to behave differently on reused variables.
    def custom_getter(getter, *args, **kwargs):
      var = getter(*args, **kwargs)
      if kwargs["reuse"]:
        # This can be used, e.g., for changing the caching device if needed.
        return array_ops.identity(var, name="reused")
      else:
        return array_ops.identity(var, name="not_reused")

    with variable_scope.variable_scope(
        "scope", custom_getter=custom_getter) as scope:
      v = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope(scope, reuse=True):
      v2 = variable_scope.get_variable("v", [1])

    self.assertEqual(v.name, "not_reused:0")
    self.assertEqual(v2.name, "reused:0")

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # ValueError: Fetch argument <tf.Tensor 'custom_getter/add:0' shape=(1, 2, 3)
  # dtype=float32> cannot be interpreted as a Tensor. (Tensor
  # Tensor("custom_getter/add:0", shape=(1, 2, 3), dtype=float32) is not an
  # element of this graph.)
  @test_util.run_deprecated_v1
  def testGetterThatCreatesTwoVariablesAndSumsThem(self):
    """A getter may create two variables and return their sum."""

    def custom_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/0" % name, *args, **kwargs)
      g_1 = getter("%s/1" % name, *args, **kwargs)
      with ops.name_scope("custom_getter"):
        return g_0 + g_1

    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      v = variable_scope.get_variable("v", [1, 2, 3])

    self.assertEqual([1, 2, 3], v.get_shape())
    true_vars = variables_lib.trainable_variables()
    self.assertEqual(2, len(true_vars))
    self.assertEqual("scope/v/0:0", true_vars[0].name)
    self.assertEqual("scope/v/1:0", true_vars[1].name)
    self.assertEqual("custom_getter/add:0", v.name)
    # The session object itself is not needed; the block only has to be
    # active while the initializer runs and the tensors are evaluated.
    with self.cached_session():
      variables_lib.global_variables_initializer().run()
      np_vars, np_v = self.evaluate([true_vars, v])
      self.assertAllClose(np_v, sum(np_vars))

  # TODO(mihaimaruseac): Not converted to use wrap_function because of
  # ValueError: Fetch argument <tf.Tensor 'sum_getter_2/add:0' shape=(1, 2, 3)
  # dtype=float32> cannot be interpreted as a Tensor. (Tensor
  # Tensor("sum_getter_2/add:0", shape=(1, 2, 3), dtype=float32) is not an
  # element of this graph.)
  @test_util.run_deprecated_v1
  def testNestedCustomGetters(self):
    """Getters from nested scopes compose; eight variables are created."""

    def sum_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/sum_0" % name, *args, **kwargs)
      g_1 = getter("%s/sum_1" % name, *args, **kwargs)
      with ops.name_scope("sum_getter"):
        return g_0 + g_1

    def prod_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/prod_0" % name, *args, **kwargs)
      g_1 = getter("%s/prod_1" % name, *args, **kwargs)
      with ops.name_scope("prod_getter"):
        return g_0 * g_1

    with variable_scope.variable_scope("prod_scope", custom_getter=prod_getter):
      with variable_scope.variable_scope("sum_scope", custom_getter=sum_getter):
        with variable_scope.variable_scope(
            "inner_sum_scope", custom_getter=sum_getter):
          # take sums of sums of products
          v = variable_scope.get_variable("v", [1, 2, 3])

    self.assertEqual([1, 2, 3], v.get_shape())
    true_vars = variables_lib.trainable_variables()
    self.assertEqual(8, len(true_vars))
    template = (
        "prod_scope/sum_scope/inner_sum_scope/v/sum_%d/sum_%d/prod_%d:0")
    self.assertEqual(template % (0, 0, 0), true_vars[0].name)
    self.assertEqual(template % (0, 0, 1), true_vars[1].name)
    self.assertEqual(template % (0, 1, 0), true_vars[2].name)
    self.assertEqual(template % (0, 1, 1), true_vars[3].name)
    self.assertEqual(template % (1, 0, 0), true_vars[4].name)
    self.assertEqual(template % (1, 0, 1), true_vars[5].name)
    self.assertEqual(template % (1, 1, 0), true_vars[6].name)
    self.assertEqual(template % (1, 1, 1), true_vars[7].name)

    with self.cached_session():
      variables_lib.global_variables_initializer().run()
      np_vars, np_v = self.evaluate([true_vars, v])
      # take sums of sums of products
      self.assertAllClose(
          np_v, (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3])) + (
              (np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7]))))

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVariableCreator(self):
    """Creator scopes chain and see the kwargs passed to variable creation."""
    variable_names = []

    def creator_a(next_creator, **kwargs):
      variable_names.append(kwargs.get("name", ""))
      return next_creator(**kwargs)

    def creator_b(next_creator, **kwargs):
      kwargs["name"] = "forced_name"
      return next_creator(**kwargs)

    with variable_scope.variable_creator_scope(creator_a):
      with variable_scope.variable_creator_scope(creator_b):
        variable_scope.variable(1.0, name="one_name")

    # creator_a recorded the name that creator_b overwrote, not "one_name".
    self.assertEqual(variable_names[0], "forced_name")

    called = [False]

    def creator_c(next_creator, **kwargs):
      called[0] = True
      self.assertEqual(kwargs["synchronization"],
                       variable_scope.VariableSynchronization.ON_WRITE)
      self.assertEqual(kwargs["aggregation"],
                       variable_scope.VariableAggregation.MEAN)
      return next_creator(**kwargs)

    with variable_scope.variable_creator_scope(creator_c):
      variable_scope.get_variable(
          "v", [],
          synchronization=variable_scope.VariableSynchronization.ON_WRITE,
          aggregation=variable_scope.VariableAggregation.MEAN)
    self.assertTrue(called[0])

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testVariableCreatorNestingError(self):
    """Exiting an outer creator scope before the inner one raises."""

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    # Save the state so we can clean up at the end.
    graph = ops.get_default_graph()
    old_creator_stack = graph._variable_creator_stack

    try:
      # Enter the outer scope by hand so it can be exited out of order,
      # while the inner `with` scope is still active.
      scope = variable_scope.variable_creator_scope(creator)
      scope.__enter__()
      with variable_scope.variable_creator_scope(creator):
        with self.assertRaises(RuntimeError):
          scope.__exit__(None, None, None)
    finally:
      graph._variable_creator_stack = old_creator_stack
|
|
|
|
|
|
class PartitionInfoTest(test.TestCase):
  """Exercises `variable_scope._PartitionInfo` validation and helpers."""

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testConstructorChecks(self):
    """Constructor raises on ill-typed or inconsistent shape/offset pairs."""
    make_info = variable_scope._PartitionInfo

    # Invalid arg types.
    ill_typed = (
        (None, [0, 1]),
        ([0, 1], None),
        ("foo", [0, 1]),
        ([0, 1], "foo"),
    )
    for shape, offset in ill_typed:
      with self.assertRaises(TypeError):
        make_info(full_shape=shape, var_offset=offset)

    inconsistent = (
        # full_shape and var_offset must have same length.
        ([0, 1], [0]),
        # Offset must always be less than shape.
        ([1, 1], [0, 1]),
    )
    for shape, offset in inconsistent:
      with self.assertRaises(ValueError):
        make_info(full_shape=shape, var_offset=offset)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testSingleOffset(self):
    """Checks `single_offset` with a non-zero and an all-zero var_offset."""
    make_info = variable_scope._PartitionInfo

    sliced = make_info(full_shape=[9, 3], var_offset=[4, 0])
    self.assertEqual(4, sliced.single_offset([1, 3]))

    # Tests when the variable isn't partitioned at all.
    whole = make_info(full_shape=[9, 3], var_offset=[0, 0])
    self.assertEqual(0, whole.single_offset([9, 3]))

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testSingleSliceDim(self):
    """Checks `single_slice_dim` error cases and two valid slice shapes."""
    make_info = variable_scope._PartitionInfo
    info = make_info(full_shape=[9, 3], var_offset=[4, 0])

    # Invalid shape.
    with self.assertRaises(TypeError):
      info.single_slice_dim(None)

    rejected_shapes = (
        # Rank of shape differs from full_shape.
        [1, 2, 3],
        # Shape is too large given var_offset (4+6 > 9).
        [6, 3],
        # Multiple possible slice dim from shape.
        [1, 1],
    )
    for shape in rejected_shapes:
      with self.assertRaises(ValueError):
        info.single_slice_dim(shape)

    info = make_info(full_shape=[9, 3], var_offset=[0, 0])
    self.assertEqual(1, info.single_slice_dim([9, 2]))
    info = make_info(full_shape=[9, 3], var_offset=[4, 0])
    self.assertEqual(0, info.single_slice_dim([2, 3]))
|
|
|
|
|
|
class VariableScopeMultithreadedTest(test.TestCase):
  """Tests for entering variable scopes from more than one thread.

  NOTE(review): the worker functions run in plain `threading.Thread`
  objects, and an exception raised inside such a thread is not re-raised in
  the thread that calls `join()`. A failing assertion inside `thread_fn`
  may therefore not fail the test; consider a thread wrapper that reports
  failures back to the main thread.
  """

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testTwoThreadsDisjointScopeEntry(self):
    """Two threads enter scope "foo" one after the other on one graph."""

    def thread_fn(i, graph):
      with graph.as_default():
        with variable_scope.variable_scope("foo"):
          if i == 0:
            v = variable_scope.get_variable("v", [])
            self.assertEqual("foo/v:0", v.name)
          else:
            # Any thread after the first one should fail to create variable
            # with the same name.
            with self.assertRaises(ValueError):
              variable_scope.get_variable("v", [])

    graph = ops.get_default_graph()
    threads = [
        threading.Thread(target=thread_fn, args=(
            i,
            graph,
        )) for i in range(2)
    ]

    threads[0].start()
    # Allow thread 0 to finish before starting thread 1.
    threads[0].join()
    threads[1].start()
    threads[1].join()

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testTwoThreadsNestedScopeEntry(self):
    """Two threads are inside scope "foo" at the same time on one graph."""

    def thread_fn(i, graph, run_event, pause_event):
      with graph.as_default():
        with variable_scope.variable_scope("foo"):
          if i == 0:
            v = variable_scope.get_variable("v", [])
            self.assertEqual("foo/v:0", v.name)
          else:
            # Any thread after the first one should fail to create variable
            # with the same name.
            with self.assertRaises(ValueError):
              variable_scope.get_variable("v", [])
          # Signal the main thread, then stay inside the scope until it
          # allows this thread to continue.
          pause_event.set()
          run_event.wait()

    graph = ops.get_default_graph()
    run_events = [threading.Event() for _ in range(2)]
    pause_events = [threading.Event() for _ in range(2)]
    threads = [
        threading.Thread(
            target=thread_fn, args=(i, graph, run_events[i], pause_events[i]))
        for i in range(2)
    ]

    # Start first thread.
    threads[0].start()
    pause_events[0].wait()
    # Start next thread once the first thread has paused.
    threads[1].start()
    pause_events[1].wait()
    # Resume both threads.
    run_events[0].set()
    run_events[1].set()
    threads[0].join()
    threads[1].join()

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testReenterMainScope(self):
    """A worker thread re-enters the scope object opened by the main thread."""

    def thread_fn(graph, main_thread_scope):
      with graph.as_default():
        # Variable created with main scope will have prefix "main".
        with variable_scope.variable_scope(main_thread_scope):
          with variable_scope.variable_scope("foo"):
            v = variable_scope.get_variable("v", [])
            self.assertEqual("main/foo/v:0", v.name)

        # Variable created outside main scope will not have prefix "main".
        with variable_scope.variable_scope("bar"):
          v = variable_scope.get_variable("v", [])
          self.assertEqual("bar/v:0", v.name)

    graph = ops.get_default_graph()
    with variable_scope.variable_scope("main") as main_thread_scope:
      thread = threading.Thread(
          target=thread_fn, args=(graph, main_thread_scope))
      thread.start()
      thread.join()
|
|
|
|
|
|
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
|