924 lines
35 KiB
Python
924 lines
35 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for tf.py."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import functools
|
|
import operator
|
|
|
|
from absl.testing import parameterized
|
|
import numpy as np
|
|
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.eager import function
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import errors_impl
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_shape
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import gen_state_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import random_ops
|
|
from tensorflow.python.ops import resource_variable_ops
|
|
from tensorflow.python.ops import variables
|
|
from tensorflow.python.platform import test
|
|
from tensorflow.python.training import gradient_descent
|
|
from tensorflow.python.util import compat
|
|
|
|
|
|
class VariablesTestCase(test.TestCase, parameterized.TestCase):
  """Tests for `variables.Variable` and `variables.VariableV1`."""

  @test_util.run_deprecated_v1
  def testDistributeStrategy(self):
    """A variable created outside any strategy has no distribute strategy."""
    v = variables.VariableV1(0.0)
    self.assertIsNone(v._distribute_strategy)

  @test_util.run_v1_only("b/120545219")
  def testInitialization(self):
    """Checks default names/shapes and that reads fail before initialization."""
    with self.cached_session():
      var0 = variables.VariableV1(0.0)
      self.assertEqual("Variable:0", var0.name)
      self.assertEqual("Variable", var0._shared_name)
      self.assertEqual([], var0.get_shape())
      self.assertEqual([], var0.shape)

      var1 = variables.VariableV1(1.1)
      self.assertEqual("Variable_1:0", var1.name)
      self.assertEqual("Variable_1", var1._shared_name)
      self.assertEqual([], var1.get_shape())
      self.assertEqual([], var1.shape)

      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(var0)

      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(var1)

      self.evaluate(variables.global_variables_initializer())

      self.assertAllClose(0.0, self.evaluate(var0))
      self.assertAllClose(1.1, self.evaluate(var1))

  @test_util.run_v1_only("b/120545219")
  def testInitializationOrder(self):
    """Variables initialized from other variables' values see those values."""
    with self.cached_session():
      rnd = variables.Variable(random_ops.random_uniform([3, 6]), name="rnd")
      self.assertEqual("rnd:0", rnd.name)
      self.assertEqual([3, 6], rnd.get_shape())
      self.assertEqual([3, 6], rnd.shape)

      dep = variables.Variable(rnd.initialized_value(), name="dep")
      self.assertEqual("dep:0", dep.name)
      self.assertEqual([3, 6], dep.get_shape())
      self.assertEqual([3, 6], dep.shape)

      # Currently have to set the shape manually for Add.
      added_val = rnd.initialized_value() + dep.initialized_value() + 2.0
      added_val.set_shape(rnd.get_shape())

      depdep = variables.Variable(added_val, name="depdep")
      self.assertEqual("depdep:0", depdep.name)
      self.assertEqual([3, 6], depdep.get_shape())
      self.assertEqual([3, 6], depdep.shape)

      self.evaluate(variables.global_variables_initializer())

      self.assertAllClose(self.evaluate(rnd), self.evaluate(dep))
      self.assertAllClose(
          self.evaluate(rnd) + self.evaluate(dep) + 2.0, self.evaluate(depdep))

  @test_util.run_deprecated_v1
  def testCyclicInitializer(self):
    """A while-loop initial value is returned unchanged by the guard helper."""
    with self.cached_session():
      cyclic = control_flow_ops.while_loop(
          cond=lambda i: i < 10,
          body=lambda i: i + 1,
          loop_vars=(constant_op.constant(0),))
      initial_value = variables._try_guard_against_uninitialized_dependencies(
          "test", cyclic)
      self.assertIs(initial_value, cyclic)

  def testIterable(self):
    """Iterating over a variable (scalar or vector) raises TypeError."""
    with self.assertRaisesRegex(TypeError, "not iterable"):
      for _ in variables.Variable(0.0):
        pass
    with self.assertRaisesRegex(TypeError, "not iterable"):
      for _ in variables.Variable([0.0, 1.0]):
        pass

  @test_util.run_deprecated_v1
  def testAssignments(self):
    """assign_add / assign_sub / assign return and store the new value."""
    with self.cached_session():
      var = variables.Variable(0.0)
      plus_one = var.assign_add(1.0)
      minus_one = var.assign_sub(2.0)
      four = var.assign(4.0)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(0.0, self.evaluate(var))

      self.assertAllClose(1.0, self.evaluate(plus_one))
      self.assertAllClose(1.0, self.evaluate(var))

      self.assertAllClose(-1.0, self.evaluate(minus_one))
      self.assertAllClose(-1.0, self.evaluate(var))

      self.assertAllClose(4.0, self.evaluate(four))
      self.assertAllClose(4.0, self.evaluate(var))

  @test_util.run_deprecated_v1
  def testResourceAssignments(self):
    """Same as testAssignments, for ResourceVariable, checking the stored value."""
    with self.session(use_gpu=True):
      var = resource_variable_ops.ResourceVariable(0.0)
      plus_one = var.assign_add(1.0)
      minus_one = var.assign_sub(2.0)
      four = var.assign(4.0)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(0.0, self.evaluate(var))

      self.evaluate(plus_one)
      self.assertAllClose(1.0, self.evaluate(var))

      self.evaluate(minus_one)
      self.assertAllClose(-1.0, self.evaluate(var))

      self.evaluate(four)
      self.assertAllClose(4.0, self.evaluate(var))

  def testAssignDifferentShapesEagerNotAllowed(self):
    """In eager mode, assigning a value of a different shape raises."""
    with context.eager_mode():
      var = variables.Variable(np.zeros(shape=[1, 1]))
      with self.assertRaisesRegex(ValueError,
                                  "Shapes.*and.*are incompatible"):
        var.assign(np.zeros(shape=[2, 2]))

  @test_util.disable_tfrt("Graph is not supported yet. b/156187905")
  @test_util.run_in_graph_and_eager_modes
  def testAssignDifferentShapesAllowed(self):
    """A variable declared with an unknown shape accepts a new shape."""
    var = variables.Variable(np.zeros(shape=[1, 1]),
                             shape=tensor_shape.TensorShape(None))
    self.evaluate(variables.global_variables_initializer())
    self.assertAllEqual(np.zeros(shape=[1, 1]), var.read_value())
    self.evaluate(var.assign(np.zeros(shape=[2, 2])))
    self.assertAllEqual(np.zeros(shape=[2, 2]), var.read_value())

  @test_util.disable_tfrt("GetHostSize() is not expected to be called with "
                          "string type. b/156761465")
  def testZeroSizeStringAssign(self):
    """Assigning a zero-length string variable's value back to it works."""
    with self.cached_session():
      array = variables.VariableV1(
          initial_value=array_ops.zeros((0,), dtype=dtypes.string),
          name="foo",
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES])
      self.evaluate(variables.local_variables_initializer())
      old_value = array.value()
      copy_op = array.assign(old_value)
      self.assertEqual([], list(self.evaluate(copy_op)))

  def _countUpToTest(self, dtype):
    """Shared body: count_up_to(3) increments the variable, then raises.

    Args:
      dtype: integer dtype used for the counter variable.
    """
    with self.cached_session():
      zero = constant_op.constant(0, dtype=dtype)
      var = variables.Variable(zero)
      count_up_to = var.count_up_to(3)

      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(0, self.evaluate(var))

      # count_up_to returns the pre-increment value.
      self.assertEqual(0, self.evaluate(count_up_to))
      self.assertEqual(1, self.evaluate(var))

      self.assertEqual(1, self.evaluate(count_up_to))
      self.assertEqual(2, self.evaluate(var))

      self.assertEqual(2, self.evaluate(count_up_to))
      self.assertEqual(3, self.evaluate(var))

      # Once the limit is reached every further run fails and leaves var as-is.
      with self.assertRaisesOpError("Reached limit of 3"):
        self.evaluate(count_up_to)
      self.assertEqual(3, self.evaluate(var))

      with self.assertRaisesOpError("Reached limit of 3"):
        self.evaluate(count_up_to)
      self.assertEqual(3, self.evaluate(var))

  @test_util.run_deprecated_v1
  def testCountUpToInt32(self):
    self._countUpToTest(dtypes.int32)

  @test_util.run_deprecated_v1
  def testCountUpToInt64(self):
    self._countUpToTest(dtypes.int64)

  @test_util.run_v1_only("b/120545219")
  def testControlDepsNone(self):
    """Variable ops ignore an enclosing control_dependencies scope."""
    with self.cached_session():
      c = constant_op.constant(1.0)
      with ops.control_dependencies([c]):
        # d gets the control dep.
        d = constant_op.constant(2.0)
        # variables do not.
        var_x = variables.VariableV1(2.0)
      self.assertEqual([c.op], d.op.control_inputs)
      self.assertEqual([], var_x.initializer.control_inputs)
      self.assertEqual([], var_x.value().op.control_inputs)
      self.assertEqual([], var_x._ref().op.control_inputs)  # pylint: disable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testControlFlow(self):
    """Variables created inside cond branches initialize independently."""
    with self.cached_session():
      v0 = variables.Variable(0, name="v0")
      var_dict = {}

      # Create a variable in each of the cond branches.
      def var_in_then_clause():
        v1 = variables.Variable(1, name="v1")
        var_dict["v1"] = v1
        return v1 + v0

      def var_in_else_clause():
        v2 = variables.Variable(2, name="v2")
        var_dict["v2"] = v2
        return v2 + v0

      add = control_flow_ops.cond(
          math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause)
      v1 = var_dict["v1"]
      v2 = var_dict["v2"]
      # We should be able to initialize and run v1 and v2 without initializing
      # v0, even if the variable was created with a control dep on v0.
      self.evaluate(v1.initializer)
      self.assertEqual([1], self.evaluate(v1))
      self.evaluate(v2.initializer)
      self.assertEqual([2], self.evaluate(v2))
      # v0 should still be uninitialized.
      with self.assertRaisesRegex(errors_impl.OpError, "uninitialized"):
        self.evaluate(v0)
      # We should not be able to run 'add' yet.
      with self.assertRaisesRegex(errors_impl.OpError, "uninitialized"):
        self.evaluate(add)
      # If we initialize v0 we should be able to run 'add'.
      self.evaluate(v0.initializer)
      self.evaluate(add)

  @test_util.run_v1_only("b/120545219")
  def testControlFlowInitialization(self):
    """Expects an error if an initializer is in a control-flow scope."""
    def cond(i, _):
      return i < 10

    def body(i, _):
      zero = array_ops.zeros([], dtype=dtypes.int32)
      v = variables.Variable(initial_value=zero)
      return (i + 1, v.read_value())

    with self.assertRaisesRegex(ValueError, "inside a control-flow"):
      control_flow_ops.while_loop(cond, body, [0, 0])

  @test_util.run_deprecated_v1
  def testUseVariableAsTensor(self):
    """Variables can be passed directly to ops expecting tensors."""
    with self.cached_session():
      var_x = variables.Variable(2.0)
      var_y = variables.Variable(3.0)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(2.0, self.evaluate(var_x))
      self.assertAllClose(3.0, self.evaluate(var_y))
      self.assertAllClose(5.0, self.evaluate(math_ops.add(var_x, var_y)))

  @test_util.run_deprecated_v1
  def testZeroSizeVarSameAsConst(self):
    """matmul with a zero-size variable matches matmul with a zero-size const."""
    with self.cached_session():
      zero_size_var = variables.Variable(array_ops.zeros([0, 2]))
      zero_size_const = array_ops.ones([2, 0])
      variable_mul = math_ops.matmul(zero_size_const, zero_size_var)
      const_mul = math_ops.matmul(
          zero_size_const, zero_size_const, transpose_b=True)
      self.evaluate(variables.global_variables_initializer())
      variable_output = self.evaluate(variable_mul)
      self.assertAllClose(self.evaluate(const_mul), variable_output)
      self.assertAllClose([[0., 0.], [0., 0.]], variable_output)

  @test_util.run_deprecated_v1
  def testCachingDevice(self):
    """caching_device changes where value() is placed, not var.device."""
    with self.cached_session():
      var = variables.Variable(2.0)
      self.assertEqual(var.device, var.initialized_value().device)

      var_cached = variables.Variable(2.0, caching_device="/job:foo")
      self.assertFalse(var_cached.device.startswith("/job:foo"))
      self.assertTrue(var_cached.value().device.startswith("/job:foo"))

  @test_util.run_deprecated_v1
  def testCollections(self):
    """trainable / explicit collections control collection membership."""
    with self.cached_session():
      var_x = variables.VariableV1(2.0)
      var_y = variables.VariableV1(2.0, trainable=False)
      var_z = variables.VariableV1(2.0, trainable=True)
      var_t = variables.VariableV1(
          2.0,
          trainable=True,
          collections=[
              ops.GraphKeys.TRAINABLE_VARIABLES, ops.GraphKeys.GLOBAL_VARIABLES
          ])
      self.assertEqual([var_x, var_y, var_z, var_t],
                       variables.global_variables())
      self.assertEqual([var_x, var_z, var_t], variables.trainable_variables())

  @test_util.run_deprecated_v1
  def testCollectionsWithScope(self):
    """Collection queries can be filtered by name scope."""
    with self.cached_session():
      with ops.name_scope("scope_1"):
        var_x = variables.VariableV1(2.0)
      with ops.name_scope("scope_2"):
        var_y = variables.VariableV1(2.0)

      self.assertEqual([var_x, var_y], variables.global_variables())
      self.assertEqual([var_x], variables.global_variables("scope_1"))
      self.assertEqual([var_y], variables.global_variables("scope_2"))

      self.assertEqual([var_x, var_y], variables.trainable_variables())
      self.assertEqual([var_x], variables.trainable_variables("scope_1"))
      self.assertEqual([var_y], variables.trainable_variables("scope_2"))

  def testOperatorWrapping(self):
    """Variable.__add__ carries Tensor.__add__'s functools wrapper metadata."""
    for attr in functools.WRAPPER_ASSIGNMENTS:
      self.assertEqual(
          getattr(variables.Variable.__add__, attr),
          getattr(ops.Tensor.__add__, attr))

  @test_util.run_deprecated_v1
  def testOperators(self):
    """Python arithmetic, comparison, logical, slicing and matmul operators."""
    with self.cached_session():
      var_f = variables.Variable([2.0])
      add = var_f + 0.0
      radd = 1.0 + var_f
      sub = var_f - 1.0
      rsub = 1.0 - var_f
      mul = var_f * 10.0
      rmul = 10.0 * var_f
      div = var_f / 10.0
      rdiv = 10.0 / var_f
      lt = var_f < 3.0
      rlt = 3.0 < var_f
      le = var_f <= 2.0
      rle = 2.0 <= var_f
      gt = var_f > 3.0
      rgt = 3.0 > var_f
      ge = var_f >= 2.0
      rge = 2.0 >= var_f
      neg = -var_f
      abs_v = abs(var_f)

      var_i = variables.Variable([20])
      mod = var_i % 7
      rmod = 103 % var_i

      var_b = variables.Variable([True, False])
      and_v = operator.and_(var_b, [True, True])
      or_v = operator.or_(var_b, [False, True])
      xor_v = operator.xor(var_b, [False, False])
      invert_v = ~var_b

      rnd = np.random.rand(4, 4).astype("f")
      var_t = variables.Variable(rnd)
      slice_v = var_t[2, 0:0]

      var_m = variables.Variable([[2.0, 3.0]])
      matmul = var_m.__matmul__([[10.0], [20.0]])
      rmatmul = var_m.__rmatmul__([[10.0], [20.0]])

      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose([2.0], self.evaluate(add))
      self.assertAllClose([3.0], self.evaluate(radd))
      self.assertAllClose([1.0], self.evaluate(sub))
      self.assertAllClose([-1.0], self.evaluate(rsub))
      self.assertAllClose([20.0], self.evaluate(mul))
      self.assertAllClose([20.0], self.evaluate(rmul))
      self.assertAllClose([0.2], self.evaluate(div))
      self.assertAllClose([5.0], self.evaluate(rdiv))
      self.assertAllClose([-2.0], self.evaluate(neg))
      self.assertAllClose([2.0], self.evaluate(abs_v))
      self.assertAllClose([True], self.evaluate(lt))
      self.assertAllClose([False], self.evaluate(rlt))
      self.assertAllClose([True], self.evaluate(le))
      self.assertAllClose([True], self.evaluate(rle))
      self.assertAllClose([False], self.evaluate(gt))
      self.assertAllClose([True], self.evaluate(rgt))
      self.assertAllClose([True], self.evaluate(ge))
      self.assertAllClose([True], self.evaluate(rge))

      self.assertAllClose([6], self.evaluate(mod))
      self.assertAllClose([3], self.evaluate(rmod))

      self.assertAllClose([True, False], self.evaluate(and_v))
      self.assertAllClose([True, True], self.evaluate(or_v))
      self.assertAllClose([True, False], self.evaluate(xor_v))
      self.assertAllClose([False, True], self.evaluate(invert_v))

      self.assertAllClose(rnd[2, 0:0], self.evaluate(slice_v))

      self.assertAllClose([[80.0]], self.evaluate(matmul))
      self.assertAllClose([[20.0, 30.0], [40.0, 60.0]], self.evaluate(rmatmul))

  @test_util.run_deprecated_v1
  def testSession(self):
    """A variable can be initialized and read inside a cached session."""
    with self.cached_session():
      var = variables.Variable([1, 12])
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose([1, 12], self.evaluate(var))

  @test_util.run_v1_only("b/120545219")
  def testColocation(self):
    """An assign op is colocated with its variable regardless of device scope."""
    with ops.device("/job:ps"):
      var = variables.VariableV1(0, name="v")
    with ops.device("/job:worker/task:7"):
      assign_op = var.assign(1)
    self.assertDeviceEqual("/job:ps", assign_op.device)
    self.assertEqual([b"loc:@v"], assign_op.op.colocation_groups())

  @test_util.run_v1_only("b/120545219")
  def testInitializerFunction(self):
    """A callable initial value is supported and can feed other variables."""
    value = [[-42], [133.7]]
    shape = [2, 1]
    with self.cached_session():
      initializer = lambda: constant_op.constant(value)

      v1 = variables.Variable(initializer, dtype=dtypes.float32)
      self.assertEqual(shape, v1.get_shape())
      self.assertEqual(shape, v1.shape)
      self.assertAllClose(value, self.evaluate(v1.initial_value))
      with self.assertRaises(errors_impl.FailedPreconditionError):
        self.evaluate(v1)

      v2 = variables.Variable(
          math_ops.negative(v1.initialized_value()), dtype=dtypes.float32)
      self.assertEqual(v1.get_shape(), v2.get_shape())
      self.assertEqual(v1.shape, v2.shape)
      self.assertAllClose(np.negative(value), self.evaluate(v2.initial_value))

      with self.assertRaises(errors_impl.FailedPreconditionError):
        self.evaluate(v2)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(np.negative(value), self.evaluate(v2))

  def testConstraintArg(self):
    """A callable constraint is stored; a non-callable one (0) raises."""
    constraint = lambda x: x
    v = variables.Variable(
        lambda: constant_op.constant(1.),
        constraint=constraint)
    self.assertEqual(v.constraint, constraint)

    with self.assertRaises(ValueError):
      variables.Variable(
          lambda: constant_op.constant(1.),
          constraint=0)

  @test_util.run_v1_only("b/120545219")
  def testNoRefDataRace(self):
    """Chained initialized_value() dependencies produce consistent values."""
    with self.cached_session():
      a = variables.Variable([1, 2, 3], dtype=dtypes.float32)
      b = variables.Variable(a.initialized_value() + 2)
      c = variables.Variable(b.initialized_value() + 2)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(self.evaluate(a), [1, 2, 3])
      self.assertAllEqual(self.evaluate(b), [3, 4, 5])
      self.assertAllEqual(self.evaluate(c), [5, 6, 7])

  @test_util.run_deprecated_v1
  def testInitializerFunctionDevicePlacement(self):
    """Ops built by a callable initializer are colocated with the variable."""
    with self.cached_session():
      initializer = lambda: constant_op.constant(42.0)
      with ops.device("/cpu:100"):
        v1 = variables.Variable(initializer, dtype=dtypes.float32, name="v1")
      expected_device = "/device:CPU:100"
      expected_group_v1 = [b"loc:@v1"]
      self.assertEqual(expected_device, v1.op.device)
      self.assertEqual(expected_group_v1, v1.op.colocation_groups())
      for i in v1.initializer.inputs:
        self.assertEqual(expected_group_v1, i.op.colocation_groups())

      v2 = variables.Variable(initializer, dtype=dtypes.float32, name="v2")
      expected_group_v2 = [b"loc:@v2"]
      self.assertEqual(expected_group_v2, v2.op.colocation_groups())
      for i in v2.initializer.inputs:
        self.assertEqual(expected_group_v2, i.op.colocation_groups())

  @test_util.run_v1_only("b/120545219")
  def testVariableDefInitializedInstances(self):
    """Variables rebuilt from a VariableDef proto support initialized_value."""
    with ops.Graph().as_default(), self.cached_session():
      v_def = variables.Variable(
          initial_value=constant_op.constant(3.0)).to_proto()

    with ops.Graph().as_default(), self.cached_session():
      # v describes a VariableDef-based variable without an initial value.
      v = variables.Variable(variable_def=v_def)
      self.assertEqual(3.0, self.evaluate(v.initialized_value()))

      # initialized_value should not rerun the initializer_op if the variable
      # has already been initialized elsewhere.
      self.evaluate(v.assign(1.0))
      self.assertEqual(1.0, self.evaluate(v.initialized_value()))

    v_def.ClearField("initial_value_name")
    with ops.Graph().as_default(), self.cached_session():
      # Restoring a legacy VariableDef proto that does not have
      # initial_value_name set should still work.
      v = variables.Variable(variable_def=v_def)
      # We should also be able to re-export the variable to a new meta graph.
      self.assertProtoEquals(v_def, v.to_proto())
      # But attempts to use initialized_value will result in errors.
      with self.assertRaises(ValueError):
        self.evaluate(v.initialized_value())

  def testTrainableInProto(self):
    """The trainable flag survives a to_proto() / variable_def round trip."""
    with ops.Graph().as_default():
      non_trainable_variable = variables.Variable(
          trainable=False,
          initial_value=constant_op.constant(10.0))
      self.assertEqual(
          False,
          variables.Variable(variable_def=non_trainable_variable.to_proto())
          .trainable)
      trainable_variable = variables.Variable(
          trainable=True,
          initial_value=constant_op.constant(10.0))
      self.assertEqual(
          True,
          variables.Variable(variable_def=trainable_variable.to_proto())
          .trainable)

  def testSynchronizationAndAggregationSaved(self):
    """synchronization/aggregation survive a to_proto() round trip."""
    with ops.Graph().as_default():
      original_variable = variables.Variable(
          initial_value=constant_op.constant(10.0),
          synchronization=variables.VariableSynchronization.NONE,
          aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)
      self.assertEqual(variables.VariableSynchronization.NONE,
                       original_variable.synchronization)
      self.assertEqual(variables.VariableAggregation.ONLY_FIRST_REPLICA,
                       original_variable.aggregation)

      laundered = variables.Variable(
          variable_def=original_variable.to_proto())
      self.assertEqual(
          variables.VariableSynchronization.NONE,
          laundered.synchronization)
      self.assertEqual(variables.VariableAggregationV2.ONLY_FIRST_REPLICA,
                       laundered.aggregation)

  @test_util.run_deprecated_v1
  def testLoad(self):
    """Variable.load overwrites the stored value."""
    with self.cached_session():
      var = variables.Variable(np.zeros((5, 5), np.float32))
      self.evaluate(variables.global_variables_initializer())
      var.load(np.ones((5, 5), np.float32))

      self.assertAllClose(np.ones((5, 5), np.float32), self.evaluate(var))

  @test_util.run_v1_only("b/120545219")
  def testRepr(self):
    """repr() shows name, shape and the ref dtype of a V1 variable."""
    var = variables.VariableV1(np.zeros((5, 5), np.float32), name="noop")
    self.assertEqual(
        "<tf.Variable 'noop:0' shape=(5, 5) dtype=float32_ref>",
        repr(var))

  def testVariableNamesPreserveNameScopesWithDefun(self):
    """Name scopes inside a defun are reflected in the variable name."""
    @function.defun
    def create_variable():
      with ops.name_scope("foo"):
        v = variables.Variable(0.0, name="bar")
      self.assertEqual(v.name, "foo/bar:0")

    with ops.get_default_graph().as_default():
      create_variable()

  @parameterized.parameters(variables.VariableV1, variables.Variable)
  def testTrainableVariable(self, cls):
    """ON_READ sync makes a variable non-trainable unless trainable is given."""
    v1 = cls(1.0)
    self.assertEqual(True, v1.trainable)

    v2 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ)
    self.assertEqual(False, v2.trainable)

    v3 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
             trainable=True)
    self.assertEqual(True, v3.trainable)

    v4 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
             trainable=False)
    self.assertEqual(False, v4.trainable)
|
|
|
|
|
|
class IsInitializedTest(test.TestCase):
  """Tests for `variables.report_uninitialized_variables`."""

  def testNoVars(self):
    """With no variables, the report is empty."""
    with ops.Graph().as_default(), self.cached_session():
      uninited = variables.report_uninitialized_variables()
      self.assertEqual(0, self.evaluate(uninited).size)

  def testAssertVariablesInitialized(self):
    """All variables are reported until the global initializer runs."""
    with ops.Graph().as_default(), self.cached_session():
      v = variables.Variable([1, 2], name="v")
      w = variables.Variable([3, 4], name="w")
      _ = v, w
      uninited = variables.report_uninitialized_variables()
      self.assertAllEqual(np.array([b"v", b"w"]), self.evaluate(uninited))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(0, self.evaluate(uninited).size)

  @test_util.run_v1_only("b/120545219")
  def testVariableList(self):
    """Variables drop out of the report as each initializer is run."""
    with ops.Graph().as_default(), self.cached_session():
      v = variables.VariableV1([1, 2], name="v")
      w = variables.VariableV1([3, 4], name="w")
      uninited = variables.report_uninitialized_variables()
      self.assertAllEqual(np.array([b"v", b"w"]), self.evaluate(uninited))
      self.evaluate(w.initializer)
      self.assertAllEqual(np.array([b"v"]), self.evaluate(uninited))
      self.evaluate(v.initializer)
      self.assertEqual(0, self.evaluate(uninited).size)

  def testZeroSizeVarInitialized(self):
    """An initialized zero-size variable is not reported."""
    with ops.Graph().as_default(), self.cached_session():
      v = variables.Variable(array_ops.zeros([0, 2]), name="v")
      uninited = variables.report_uninitialized_variables()
      self.evaluate(v.initializer)  # not strictly necessary
      self.assertEqual(0, self.evaluate(uninited).size)

  def testTrainingWithZeroSizeVar(self):
    """An optimizer step works when the objective involves a zero-size var."""
    with ops.Graph().as_default(), self.cached_session():
      a = variables.Variable(array_ops.zeros([0, 2]))
      b = variables.Variable(array_ops.ones([2, 2]))
      objective = math_ops.reduce_sum(b + math_ops.matmul(
          a, a, transpose_a=True))
      self.evaluate(variables.global_variables_initializer())
      do_opt = gradient_descent.GradientDescentOptimizer(0.1).minimize(
          objective)
      self.evaluate([do_opt])
      # d(objective)/db is all ones, so one step with lr=0.1 gives 1 - 0.1.
      self.assertAllClose([[0.9, 0.9], [0.9, 0.9]], self.evaluate(b))
|
|
|
|
|
|
@test_util.run_v1_only("b/120545219")
class ObsoleteIsInitializedTest(test.TestCase):
  """Tests for `variables.assert_variables_initialized`."""

  def testNoVars(self):
    """With no variables, assert_variables_initialized returns None."""
    with ops.Graph().as_default():
      self.assertIsNone(variables.assert_variables_initialized())

  def testVariables(self):
    """The check fails until all global variables are initialized."""
    with ops.Graph().as_default(), self.cached_session():
      v = variables.VariableV1([1, 2])
      w = variables.VariableV1([3, 4])
      _ = v, w
      inited = variables.assert_variables_initialized()
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(inited)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(inited)

  def testVariableList(self):
    """With an explicit var_list, only the listed variables are checked."""
    with ops.Graph().as_default(), self.cached_session():
      v = variables.VariableV1([1, 2])
      w = variables.VariableV1([3, 4])
      inited = variables.assert_variables_initialized([v])
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(inited.op)
      # Initializing w (not in the list) must not satisfy the check.
      self.evaluate(w.initializer)
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(inited.op)
      self.evaluate(v.initializer)
      self.evaluate(inited.op)
|
|
|
|
|
|
class PartitionedVariableTest(test.TestCase):
  """Tests for `variables.PartitionedVariable`."""

  def testPartitionedVariable(self):
    """Partitions are ordered by slice offset and concatenate to full shape."""
    with ops.Graph().as_default():
      v0 = variables.Variable([0])
      v1 = variables.Variable([1])
      v0._set_save_slice_info(
          variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
      v1._set_save_slice_info(
          variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
      partitions = [2]

      # Pass variable_list as [v1, v0] to ensure they are properly
      # re-sorted to [v0, v1] based on their slice info offsets.
      partitioned_variable = variables.PartitionedVariable(
          name="two_vars",
          shape=[2],
          dtype=v0.dtype,
          variable_list=[v1, v0],
          partitions=partitions)

      concatenated = ops.convert_to_tensor(partitioned_variable)
      num_partitions = len(partitioned_variable)
      iterated_partitions = list(partitioned_variable)
      self.assertEqual(2, num_partitions)
      self.assertEqual([v0, v1], iterated_partitions)
      self.assertEqual([2], partitioned_variable.get_shape())
      self.assertEqual([2], partitioned_variable.shape)
      self.assertEqual([2], concatenated.get_shape())
      self.assertEqual([2], concatenated.shape)

  def testPartitionedVariableFailures(self):
    """Invalid variable lists, shapes or partition counts raise ValueError."""
    with ops.Graph().as_default():
      with self.assertRaisesRegex(ValueError, "empty"):
        variables.PartitionedVariable(
            name="fail",
            shape=2,
            dtype=dtypes.int32,
            variable_list=[],
            partitions=[])

      with self.assertRaisesRegex(ValueError, "must have a save_slice_info"):
        v0 = variables.Variable([0])
        partitions = [1]
        variables.PartitionedVariable(
            name="two_vars",
            shape=[1],
            dtype=v0.dtype,
            variable_list=[v0],
            partitions=partitions)

      with self.assertRaisesRegex(ValueError, "full shapes must match"):
        v0 = variables.Variable([0])
        v1 = variables.Variable([1])
        v0._set_save_slice_info(
            variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
        v1._set_save_slice_info(
            variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
        partitions = [2]

        # Slice info says the full shape is [2], but [3] is requested.
        variables.PartitionedVariable(
            name="two_vars",
            shape=[3],
            dtype=v0.dtype,
            variable_list=[v1, v0],
            partitions=partitions)

      with self.assertRaisesRegex(ValueError, "must be positive"):
        v0 = variables.Variable([0])
        v0._set_save_slice_info(
            variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
        partitions = [0]

        variables.PartitionedVariable(
            name="two_vars",
            shape=[2],
            dtype=v0.dtype,
            variable_list=[v0],
            partitions=partitions)

  def testPartitionedVariableAssignments(self):
    """assign / assign_add / assign_sub apply element-wise to each partition."""
    with ops.Graph().as_default(), self.cached_session():
      v0 = variables.Variable(initial_value=[0.0])
      v1 = variables.Variable(initial_value=[1.0])
      v2 = variables.Variable(initial_value=[20.0])
      v3 = variables.Variable(initial_value=[30.0])
      v0._set_save_slice_info(
          variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
      v1._set_save_slice_info(
          variables.Variable.SaveSliceInfo(v1.name, [2], [1], [1]))
      v2._set_save_slice_info(
          variables.Variable.SaveSliceInfo(v2.name, [2], [0], [1]))
      v3._set_save_slice_info(
          variables.Variable.SaveSliceInfo(v3.name, [2], [1], [1]))

      partitions = [2]

      # Both variable lists are already in slice-offset order: v0/v2 hold
      # offset 0 and v1/v3 hold offset 1 of their two-element variables.
      pv_0 = variables.PartitionedVariable(
          name="two_vars",
          shape=[2],
          dtype=v0.dtype,
          variable_list=[v0, v1],
          partitions=partitions)

      pv_1 = variables.PartitionedVariable(
          name="two_vars",
          shape=[2],
          dtype=v0.dtype,
          variable_list=[v2, v3],
          partitions=partitions)

      deltas_a = constant_op.constant([1.0, 2.0])
      deltas_b = constant_op.constant([3.0, 4.0])
      ones = array_ops.ones([2])
      plus_delta = pv_0.assign_add(deltas_a)
      minus_delta = pv_0.assign_sub(deltas_b)
      assign_ones = pv_0.assign(ones)

      c_0 = constant_op.constant([2.0])
      c_1 = constant_op.constant([3.0])
      assign_list = pv_1.assign([c_0, c_1])
      assign_part_value = pv_1.assign_add(assign_ones)
      assign_part_var = pv_1.assign_sub(pv_0)
      self.evaluate(variables.global_variables_initializer())

      # Each evaluate() below runs the fetched assign op, so the expected
      # values reflect the cumulative effect of the evaluate calls so far.
      self.assertEqual([1.0], self.evaluate(plus_delta[0]))
      self.assertEqual([1.0], self.evaluate(v0))
      self.assertEqual([3.0], self.evaluate(plus_delta[1]))
      self.assertEqual([3.0], self.evaluate(v1))

      self.assertEqual([-2.0], self.evaluate(minus_delta[0]))
      self.assertEqual([-2.0], self.evaluate(v0))
      self.assertEqual([-1.0], self.evaluate(minus_delta[1]))
      self.assertEqual([-1.0], self.evaluate(v1))

      self.assertEqual([1.0], self.evaluate(assign_ones[0]))
      self.assertEqual([1.0], self.evaluate(v0))
      self.assertEqual([1.0], self.evaluate(assign_ones[1]))
      self.assertEqual([1.0], self.evaluate(v1))

      self.assertEqual([2.0], self.evaluate(assign_list[0]))
      self.assertEqual([2.0], self.evaluate(v2))
      self.assertEqual([3.0], self.evaluate(assign_list[1]))
      self.assertEqual([3.0], self.evaluate(v3))

      self.assertEqual([3.0], self.evaluate(assign_part_value[0]))
      self.assertEqual([3.0], self.evaluate(v2))
      self.assertEqual([4.0], self.evaluate(assign_part_value[1]))
      self.assertEqual([4.0], self.evaluate(v3))

      self.assertEqual([2.0], self.evaluate(assign_part_var[0]))
      self.assertEqual([2.0], self.evaluate(v2))
      self.assertEqual([3.0], self.evaluate(assign_part_var[1]))
      self.assertEqual([3.0], self.evaluate(v3))
|
|
|
|
|
|
class VariableContainerTest(test.TestCase):
  """Checks the `container` attribute recorded on variable ops."""

  def testContainer(self):
    """Each variable records the innermost enclosing container name."""
    with ops.Graph().as_default():
      outside = variables.Variable([0])
      with ops.container("l1"):
        in_l1 = variables.Variable([1])
        with ops.container("l2"):
          in_l2 = variables.Variable([2])
          # An explicit container argument takes precedence over the scope.
          explicit_l3 = gen_state_ops.variable(
              shape=[1],
              dtype=dtypes.float32,
              name="VariableInL3",
              container="l3",
              shared_name="")
        back_in_l1 = variables.Variable([3])
      outside_again = variables.Variable([4])

      expectations = (
          ("", outside),
          ("l1", in_l1),
          ("l2", in_l2),
          ("l3", explicit_l3),
          ("l1", back_in_l1),
          ("", outside_again),
      )
      for container_name, var in expectations:
        self.assertEqual(
            compat.as_bytes(container_name), var.op.get_attr("container"))
|
|
|
|
|
|
class AggregationModesTest(test.TestCase):
  """VariableAggregation and VariableAggregationV2 members must interoperate."""

  def testV1V2Equal(self):
    """Matching members compare equal in both directions and hash alike."""
    old_enum = variables.VariableAggregation
    new_enum = variables.VariableAggregationV2

    pairs = (
        (old_enum.NONE, new_enum.NONE),
        (old_enum.SUM, new_enum.SUM),
        (old_enum.MEAN, new_enum.MEAN),
        (old_enum.ONLY_FIRST_REPLICA, new_enum.ONLY_FIRST_REPLICA),
        # The V1 ONLY_FIRST_TOWER member must match V2's ONLY_FIRST_REPLICA.
        (old_enum.ONLY_FIRST_TOWER, new_enum.ONLY_FIRST_REPLICA),
    )

    # V1 == V2.
    for old, new in pairs:
      self.assertEqual(old, new)
    # V2 == V1 (equality must be symmetric).
    for old, new in pairs:
      self.assertEqual(new, old)
    # Equal members must also hash identically.
    for old, new in pairs:
      self.assertEqual(hash(old), hash(new))
|
|
|
|
if __name__ == "__main__":
  # Run all test cases in this module via TensorFlow's test entry point.
  test.main()
|