# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

from absl.testing import parameterized
import numpy as np
import six

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
from tensorflow.python.training.server_lib import ClusterSpec


class SingleWorkerTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(SingleWorkerTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(SingleWorkerTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionBasic(self):
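    # The function mixes an op pinned to the local client with one pinned to
    # the remote worker; the runtime partitions it across both devices.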

    @def_function.function
    def basic(i):
      with ops.device('/job:localhost/replica:0/task:0/cpu:0'):
        a = constant_op.constant([2]) + i
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        b = constant_op.constant([1])

      return a + b

    self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
    self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionVariable(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def with_variable(i):
      return i + variable_b

    self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        c = variable_b + 1
      return c, i + variable_b

    self.assertAllEqual(
        remote_output(constant_op.constant([1]))[0].numpy(), 2)

  def testMultiDeviceFunctionAmbiguousDevice(self):
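    # A bare 'cpu:0' inside the function matches both the client CPU and the
    # worker CPU once the call is placed on the worker, so the output cannot
    # be assigned to exactly one device and the call is expected to fail.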

    @def_function.function
    def ambiguous_device(i):
      with ops.device('cpu:0'):
        return i + constant_op.constant([2])

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        ambiguous_device(constant_op.constant([2])).numpy()

    self.assertIn('the output node must match exactly one device',
                  cm.exception.message)

  def testStreaming(self):
    """A mini stress test for streaming - issuing many RPCs back to back."""
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 2])
      y = array_ops.zeros([2, 2])
      num_iters = 200
      for _ in range(num_iters):
        y = x + y
        # Ask for y's shape roughly once every 10 additions, on average.
        # This exercises the logic in TensorHandle that waits for remote
        # shapes.
        if random.randint(1, 10) == 1:
          _ = y.shape
    np.testing.assert_array_equal(
        [[num_iters, num_iters], [num_iters, num_iters]], y.numpy())

  def testShapeError_OpByOp(self):
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 3])
      y = array_ops.zeros([2, 2])
      with self.assertRaises(errors.InvalidArgumentError) as cm:
        math_ops.matmul(x, y)

    self.assertIn('Dimensions must be equal', cm.exception.message)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testShapeError_Function(self):

    @def_function.function
    def matmul_func(x, y):
      return math_ops.matmul(x, y)

    x = array_ops.ones([2, 3])
    y = array_ops.zeros([2, 2])

    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      with self.assertRaises(ValueError) as cm:
        matmul_func(x, y)

    if six.PY2:
      self.assertIn('Dimensions must be equal', cm.exception.message)
    else:
      self.assertIn('Dimensions must be equal', cm.exception.args[0])

  def testClientVariable(self):
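    # The variable lives on the client; the worker-placed function reads it
    # through an explicit localhost device scope.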
    var = variables.Variable(initial_value=0)

    @def_function.function
    def func():
      with ops.device('/job:localhost/task:0'):
        read = var.read_value()
      return read + 1

    with ops.device('/job:worker/task:0'):
      self.assertAllEqual(func(), 1)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testRemoteCall(self):
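    # functional_ops.remote_call runs a concrete function on the given target
    # device (here the remote worker) from a function traced on the client.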

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def _remote_fn(x):
      return constant_op.constant(1) + x

    remote_fn = _remote_fn.get_concrete_function()

    @def_function.function
    def func(x):
      return functional_ops.remote_call(
          args=[x],
          Tout=[dtypes.int32],
          f=remote_fn,
          target='/job:worker/task:0')

    with ops.device('/job:localhost/task:0'):
      self.assertAllEqual(func(constant_op.constant(1)), [2])


class RemoteAsyncTest(test.TestCase):

  def setUp(self):
    super(RemoteAsyncTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(RemoteAsyncTest, self).tearDown()

    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def test_out_of_range_with_while_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    while True:
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
      except (errors.OutOfRangeError, errors.InternalError):
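        # In async execution the iterator's out-of-range status may be
        # reported later and can surface as an InternalError, so both are
        # caught before clearing the pending error.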
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_for_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    for i in range(num_steps):
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
        if i == num_steps - 1:
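          # Flush pending async execution on the last step so a deferred
          # OutOfRangeError surfaces inside the try block.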
          context.async_wait()
      except errors.OutOfRangeError:
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_async_scope(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    try:
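      # async_scope() runs the enclosed steps asynchronously and surfaces any
      # deferred error when the scope exits.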
      with context.async_scope():
        for _ in range(num_steps):
          with ops.device('/job:worker/task:0'):
            train_step(iterator)
    except errors.OutOfRangeError:
      context.async_clear_error()

    self.assertAllEqual(v.numpy(), 4.0)


class MultiWorkersTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiWorkersTest, self).setUp()

    workers, _ = test_util.create_local_cluster(3, 0)
    remote.connect_to_remote_host(
        [workers[0].target, workers[1].target, workers[2].target])

  def tearDown(self):
    super(MultiWorkersTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  @test_util.eager_lazy_remote_copy_on_and_off
  def testReturnRemoteArgument(self):

    @def_function.function
    def local_func(i):
      return i

    with ops.device('/job:worker/replica:0/task:0'):
      x = constant_op.constant([2, 1])

    with ops.device('/job:worker/replica:0/task:1'):
      self.assertAllEqual(local_func(x), [2, 1])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionOnLocalDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testMultiDeviceFunctionWithPackedVariable(self):
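    # Pack the per-task resource handles into a single handle on the
    # COMPOSITE device; a multi-device function then reads the matching
    # component on each task.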
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
      var0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      var1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = ops.pack_eager_tensors([var0.handle, var1.handle])
    self.assertEqual(packed_var.device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
    self.assertEqual(packed_var.backing_device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')

    @def_function.function
    def add_variables():
      with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
        read0 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)
      with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
        read1 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)

      return read0 + read1

    # Run the function on a remote device
    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(add_variables().numpy(), 3.0)

    # Run the function without an explicit device scope (on the local client)
    self.assertAllEqual(add_variables().numpy(), 3.0)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable([1.0])

    @def_function.function
    def remote_function(i):
      x = array_ops.ones([1000, 1000])
      for _ in range(1, 1000):
        x = x * x
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    @def_function.function
    def remote_function2(i):
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    # Runs the first function:
    # - on a remote device
    # - needs a remote input
    # - has side effects (mutates variable_b)
    # - runs much slower
    with ops.device('/job:worker/replica:0/task:0'):
      remote_function(constant_op.constant([2.0]))

    # Runs the second function:
    # - on a remote device
    # - has side effects (mutates variable_b)
    # There should be a sync point here, so the second function executes only
    # after the first function has completed.
    with ops.device('/job:worker/replica:0/task:2'):
      self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceFunctionOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    context.context().mirroring_policy = context.MIRRORING_NONE

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    context.context().mirroring_policy = context.MIRRORING_ALL

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testMultiDeviceWhileLoopOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):

      def body(i, _):
        with ops.device('/job:worker/replica:0/task:0'):
          a = i + variable_b
        return a + 1.0, 1

      return control_flow_ops.while_loop_v2(lambda _, d: d < 1, body, [i, 0])[0]

    context.context().mirroring_policy = context.MIRRORING_NONE

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    context.context().mirroring_policy = context.MIRRORING_ALL

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  @test_util.eager_lazy_remote_copy_on_and_off
  def testSimpleParameterServer(self):
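    # Task 2 holds the variables (acting as a parameter server); tasks 0 and
    # 1 each run the update function against them.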

    with ops.device('/job:worker/task:2/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)


_GRPC_PREFIX = 'grpc://'


class MultiJobsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiJobsTest, self).setUp()

    workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)
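    # Re-register the local servers under custom job names ('my_worker',
    # 'my_ps') instead of the default 'worker'/'ps'.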
    cluster = {
        'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers],
        'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps],
    }
    self._cluster = server_lib.ClusterSpec(cluster)
    self._cluster_resolver = SimpleClusterResolver(
        cluster_spec=self._cluster, master=ps[0].target)

  def tearDown(self):
    super(MultiJobsTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  @test_util.eager_lazy_remote_copy_on_and_off
  def testSimpleParameterServer(self):
    remote.connect_to_cluster(self._cluster)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  # TODO(b/152224115): Re-enable this test.
  @test_util.eager_lazy_remote_copy_on_and_off
  def DISABLED_testSimpleParameterServerWithDeviceFilters(self):
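    # Each worker task can only see devices on the ps job, and each ps task
    # can only see devices on the worker job.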
    cluster_device_filters = server_lib.ClusterDeviceFilters()
    for i in range(2):
      cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
      cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
    remote.connect_to_cluster(
        self._cluster, cluster_device_filters=cluster_device_filters)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
    with ops.device('/job:my_ps/task:1/device:CPU:0'):
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

    # The following remote calls fail because the ps tasks cannot see each
    # other due to the device filters.
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:0/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
                  cm.exception.message)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:1/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
                  cm.exception.message)

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 7)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 6)
    # Explicitly delete the variables so that garbage-collecting them during
    # subsequent tests does not trigger errors.
    del v1, v2

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectWithClusterResolver(self):
    remote.connect_to_cluster(self._cluster_resolver)

    v1 = variables.Variable(initial_value=0)
    v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterTwiceOk(self):
    remote.connect_to_cluster(self._cluster_resolver)
    remote.connect_to_cluster(self._cluster_resolver)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterOnMismatchedDevice(self):
    remote.connect_to_cluster(self._cluster_resolver)

    # Enter another device scope; connecting to the cluster again from inside
    # it should fail.
    ops.device('/job:my_worker/task:0/device:CPU:0').__enter__()

    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterWithLocalMaster(self):
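    # A master of 'local' resolves to the in-process runtime, so no remote
    # servers are needed to connect.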
    local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local')
    remote.connect_to_cluster(local_resolver)

  @test_util.eager_lazy_remote_copy_on_and_off
  def testConnectToClusterInGraphModeWillFail(self):
    ops.disable_eager_execution()
    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)
    ops.enable_eager_execution()


def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s


if __name__ == '__main__':
  test.main()