diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index 4f10b4761e3..168d2507e34 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -662,9 +662,8 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) { } const TpuCompilationCacheKey key = CreateCompilationCacheKey( - function_.name(), metadata_.function_library_fingerprint(), - /*mlir_module=*/"", guaranteed_constants, dynamic_shapes, metadata_, - *mesh_state); + function_.name(), metadata_.function_library_fingerprint(), mlir_module_, + guaranteed_constants, dynamic_shapes, metadata_, *mesh_state); // Process-wide cache of TPU executables. TpuCompilationCacheInterface* cache; diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index f3f98fe50de..356fb3a7a9f 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -657,6 +657,21 @@ tpu_py_test( ], ) +tpu_py_test( + name = "tpu_strategy_compilation_test", + srcs = ["tpu_strategy_compilation_test.py"], + disable_experimental = True, + disable_mlir_bridge = False, + python_version = "PY3", + tags = ["no_oss"], + deps = [ + ":tpu_strategy", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", + "//tensorflow/python/eager:remote", + "//tensorflow/python/eager:test", + ], +) + # Used only by estimator. py_library( name = "estimator_training", diff --git a/tensorflow/python/distribute/tpu_strategy_compilation_test.py b/tensorflow/python/distribute/tpu_strategy_compilation_test.py new file mode 100644 index 00000000000..ed61c063a4f --- /dev/null +++ b/tensorflow/python/distribute/tpu_strategy_compilation_test.py @@ -0,0 +1,87 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TPUStrategy in regards to compiling programs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.distribute import tpu_strategy as tpu_lib +from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver +from tensorflow.python.eager import def_function +from tensorflow.python.eager import remote +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import flags +from tensorflow.python.tpu import tpu_strategy_util + +FLAGS = flags.FLAGS +flags.DEFINE_string("tpu", "", "Name of TPU to connect to.") +flags.DEFINE_string("project", None, "Name of GCP project with TPU.") +flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.") + + +def get_tpu_cluster_resolver(): + resolver = tpu_cluster_resolver.TPUClusterResolver( + tpu=FLAGS.tpu, + zone=FLAGS.zone, + project=FLAGS.project, + ) + return resolver + + +def get_tpu_strategy(): + resolver = get_tpu_cluster_resolver() + remote.connect_to_cluster(resolver) + tpu_strategy_util.initialize_tpu_system(resolver) + strategy = tpu_lib.TPUStrategyV2(resolver) + return strategy + + +# TODO(b/158494076): Merge this test back into TPUStrategy tests +# (tpu_strategy_test) once MLIR bridge is enabled by default. +class TPUStrategyCompilationTest(test.TestCase): + + def test_functions_compile_same_signature(self): + """Tests compiling different functions with the same signature.""" + strategy = get_tpu_strategy() + + @def_function.function + def return_one(): + + def computation(): + return constant_op.constant(1) + + return strategy.run(computation) + + @def_function.function + def return_two(): + + def computation(): + return constant_op.constant(2) + + return strategy.run(computation) + + expected_result_ones = [1 for _ in range(0, strategy.num_replicas_in_sync)] + self.assertAllEqual(expected_result_ones, + strategy.experimental_local_results(return_one())) + + expected_result_twos = [2 for _ in range(0, strategy.num_replicas_in_sync)] + self.assertAllEqual(expected_result_twos, + strategy.experimental_local_results(return_two())) + + +if __name__ == "__main__": + test.main()