From a861238b039d890cb908767a06f013e637b26f84 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 3 Nov 2020 11:13:26 -0800 Subject: [PATCH] Fork keras related sharded_variable_test to keras/distribute. TF distribute shouldn't rely on any keras code. PiperOrigin-RevId: 340483958 Change-Id: I4c3774dce1e914dc1f257d13117420a3fb9b3406 --- tensorflow/python/distribute/BUILD | 4 - .../distribute/sharded_variable_test.py | 79 ------------- tensorflow/python/keras/distribute/BUILD | 15 +++ .../keras/distribute/sharded_variable_test.py | 111 ++++++++++++++++++ 4 files changed, 126 insertions(+), 83 deletions(-) create mode 100644 tensorflow/python/keras/distribute/sharded_variable_test.py diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 75792a00935..dac29b1c15e 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1135,10 +1135,6 @@ tf_py_test( name = "sharded_variable_test", size = "small", srcs = ["sharded_variable_test.py"], - tags = [ - # depend through //third_party/tensorflow/python:extra_py_tests_deps. - "ignore_for_dep=third_party.tensorflow.python.keras.engine.base_layer", - ], deps = [ ":sharded_variable", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/distribute/sharded_variable_test.py b/tensorflow/python/distribute/sharded_variable_test.py index 8b88d7b016e..a020a85de2d 100644 --- a/tensorflow/python/distribute/sharded_variable_test.py +++ b/tensorflow/python/distribute/sharded_variable_test.py @@ -30,11 +30,9 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec -from tensorflow.python.keras.engine import base_layer from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test from tensorflow.python.saved_model import loader @@ -387,83 +385,6 @@ class ShardedVariableTest(test.TestCase): self.assertLen(model._checkpoint_dependencies, 1) self.assertEqual(model._checkpoint_dependencies[0].ref, model.w) - def test_keras_layer_setattr(self): - - class Layer(base_layer.Layer): - - def __init__(self): - super().__init__() - variables1 = [ - variables_lib.Variable([0]), - variables_lib.Variable([1]), - ] - variables2 = [ - variables_lib.Variable([2], trainable=False), - variables_lib.Variable([3], trainable=False), - ] - self.w = sharded_variable.ShardedVariable(variables1) - self.b = sharded_variable.ShardedVariable(variables2) - - layer = Layer() - - self.assertLen(layer.trainable_weights, 2) - self.assertEqual(layer.trainable_weights[0], [0]) - self.assertEqual(layer.trainable_weights[1], [1]) - self.assertLen(layer.non_trainable_weights, 2) - self.assertEqual(layer.non_trainable_weights[0], [2]) - self.assertEqual(layer.non_trainable_weights[1], [3]) - self.assertAllEqual(layer.weights, - layer.trainable_weights + layer.non_trainable_weights) - self.assertAllEqual(layer.trainable_weights, layer.trainable_variables) - self.assertAllEqual(layer.weights, layer.variables) - - checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies) - self.assertEqual(checkpoint_deps, set([layer.w, layer.b])) - - def test_keras_layer_add_weight(self): - - class Layer(base_layer.Layer): - - def __init__(self): - super().__init__() - self.w = self.add_weight( - shape=(2,), initializer=lambda shape, dtype: [0, 1], trainable=True) - self.b = self.add_weight( - shape=(2,), - initializer=lambda shape, dtype: [2, 3], - trainable=False) - - def sharded_variable_creator(next_creator, **kwargs): - v1_value = kwargs['initial_value']()[0:1] - v2_value = kwargs['initial_value']()[1:] - - kwargs['initial_value'] = v1_value - kwargs['shape'] = (1,) - v1 = next_creator(**kwargs) - - kwargs['initial_value'] = v2_value - kwargs['shape'] = (1,) - v2 = next_creator(**kwargs) - - return sharded_variable.ShardedVariable([v1, v2]) - - with variable_scope.variable_creator_scope(sharded_variable_creator): - layer = Layer() - - self.assertLen(layer.trainable_weights, 2) - self.assertEqual(layer.trainable_weights[0], [0]) - self.assertEqual(layer.trainable_weights[1], [1]) - self.assertLen(layer.non_trainable_weights, 2) - self.assertEqual(layer.non_trainable_weights[0], [2]) - self.assertEqual(layer.non_trainable_weights[1], [3]) - self.assertAllEqual(layer.weights, - layer.trainable_weights + layer.non_trainable_weights) - self.assertAllEqual(layer.trainable_weights, layer.trainable_variables) - self.assertAllEqual(layer.weights, layer.variables) - - checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies) - self.assertEqual(checkpoint_deps, set([layer.w, layer.b])) - def test_embedding_lookup(self): v = [ variables_lib.Variable([[1., 2.], [3., 4.]]), diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index fd77ad278b5..0d80f36917d 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -931,3 +931,18 @@ tf_py_test( "//tensorflow/python/compat:v2_compat", ], ) + +tf_py_test( + name = "sharded_variable_test", + size = "small", + srcs = ["sharded_variable_test.py"], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/distribute:sharded_variable", + "//tensorflow/python/keras/engine:base_layer", + ], +) diff --git a/tensorflow/python/keras/distribute/sharded_variable_test.py b/tensorflow/python/keras/distribute/sharded_variable_test.py new file mode 100644 index 00000000000..381a40298a8 --- /dev/null +++ b/tensorflow/python/keras/distribute/sharded_variable_test.py @@ -0,0 +1,111 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ShardedVariable.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.compat import v2_compat +from tensorflow.python.distribute import sharded_variable +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import test + + +class ShardedVariableTest(test.TestCase): + + def test_keras_layer_setattr(self): + + class Layer(base_layer.Layer): + + def __init__(self): + super().__init__() + variables1 = [ + variables_lib.Variable([0]), + variables_lib.Variable([1]), + ] + variables2 = [ + variables_lib.Variable([2], trainable=False), + variables_lib.Variable([3], trainable=False), + ] + self.w = sharded_variable.ShardedVariable(variables1) + self.b = sharded_variable.ShardedVariable(variables2) + + layer = Layer() + + self.assertLen(layer.trainable_weights, 2) + self.assertEqual(layer.trainable_weights[0], [0]) + self.assertEqual(layer.trainable_weights[1], [1]) + self.assertLen(layer.non_trainable_weights, 2) + self.assertEqual(layer.non_trainable_weights[0], [2]) + self.assertEqual(layer.non_trainable_weights[1], [3]) + self.assertAllEqual(layer.weights, + layer.trainable_weights + layer.non_trainable_weights) + self.assertAllEqual(layer.trainable_weights, layer.trainable_variables) + self.assertAllEqual(layer.weights, layer.variables) + + checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies) + self.assertEqual(checkpoint_deps, set([layer.w, layer.b])) + + def test_keras_layer_add_weight(self): + + class Layer(base_layer.Layer): + + def __init__(self): + super().__init__() + self.w = self.add_weight( + shape=(2,), initializer=lambda shape, dtype: [0, 1], trainable=True) + self.b = self.add_weight( + shape=(2,), + initializer=lambda shape, dtype: [2, 3], + trainable=False) + + def sharded_variable_creator(next_creator, **kwargs): + v1_value = kwargs['initial_value']()[0:1] + v2_value = kwargs['initial_value']()[1:] + + kwargs['initial_value'] = v1_value + kwargs['shape'] = (1,) + v1 = next_creator(**kwargs) + + kwargs['initial_value'] = v2_value + kwargs['shape'] = (1,) + v2 = next_creator(**kwargs) + + return sharded_variable.ShardedVariable([v1, v2]) + + with variable_scope.variable_creator_scope(sharded_variable_creator): + layer = Layer() + + self.assertLen(layer.trainable_weights, 2) + self.assertEqual(layer.trainable_weights[0], [0]) + self.assertEqual(layer.trainable_weights[1], [1]) + self.assertLen(layer.non_trainable_weights, 2) + self.assertEqual(layer.non_trainable_weights[0], [2]) + self.assertEqual(layer.non_trainable_weights[1], [3]) + self.assertAllEqual(layer.weights, + layer.trainable_weights + layer.non_trainable_weights) + self.assertAllEqual(layer.trainable_weights, layer.trainable_variables) + self.assertAllEqual(layer.weights, layer.variables) + + checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies) + self.assertEqual(checkpoint_deps, set([layer.w, layer.b])) + + +if __name__ == '__main__': + v2_compat.enable_v2_behavior() + test.main()