From 027c55b6676960b4c34020d3c130d73b377b11c5 Mon Sep 17 00:00:00 2001
From: Allen Lavoie <allenl@google.com>
Date: Mon, 31 Aug 2020 09:08:26 -0700
Subject: [PATCH] Parallel device: Remove the DeviceID op usage for now

It may be useful to have later, but for now having an op registered is unnecessary API surface. There's not a strong case for the op yet, and requiring the op as a dependency without making it a core op makes the parallel device hard to depend on.

Keeps a "device_ids" property on the ParallelDevice in Python, which behaves identically to the previous version.
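
A minimal usage sketch of the kept property (the component device names are
illustrative, and the unpack call assumes the existing pack/unpack API on
ParallelDevice; neither is introduced by this change):

    from tensorflow.python.distribute.parallel_device import parallel_device

    # Hypothetical two-component device backed by local CPUs.
    device = parallel_device.ParallelDevice(
        components=["/job:localhost/replica:0/task:0/device:CPU:0",
                    "/job:localhost/replica:0/task:0/device:CPU:1"])
    ids = device.device_ids          # parallel tensor holding 0 and 1
    zero, one = device.unpack(ids)   # one int32 scalar per component device

The property is now built lazily from per-component constants and pack()
rather than by executing a DeviceID op, so no op registration is needed.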

PiperOrigin-RevId: 329313676
Change-Id: Iaf024d0d762061ba90e3568199334b53432001ce
---
 tensorflow/c/eager/parallel_device/BUILD      | 19 ----------
 .../eager/parallel_device/parallel_device.cc  |  7 ----
 .../parallel_device/parallel_device_ops.cc    | 26 --------------
 .../parallel_device_testlib.cc                | 26 --------------
 .../python/distribute/parallel_device/BUILD   | 23 ------------
 .../parallel_device/parallel_device.py        | 36 ++++++++++++++-----
 6 files changed, 28 insertions(+), 109 deletions(-)
 delete mode 100644 tensorflow/c/eager/parallel_device/parallel_device_ops.cc

diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD
index df5504adce2..3eec95294b3 100644
--- a/tensorflow/c/eager/parallel_device/BUILD
+++ b/tensorflow/c/eager/parallel_device/BUILD
@@ -103,7 +103,6 @@ cc_library(
     hdrs = ["parallel_device_testlib.h"],
     deps = [
         ":parallel_device",
-        ":parallel_device_ops",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_experimental",
         "//tensorflow/c/eager:c_api",
@@ -118,7 +117,6 @@ tf_cc_test(
     srcs = ["parallel_device_test.cc"],
     deps = [
         ":parallel_device",
-        ":parallel_device_ops",
         ":parallel_device_testlib",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_experimental",
@@ -138,7 +136,6 @@ tf_cc_test(
     args = ["--heap_check=local"],
     deps = [
         ":parallel_device",
-        ":parallel_device_ops",
         ":parallel_device_testlib",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_experimental",
@@ -150,19 +147,3 @@ tf_cc_test(
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
     ],
 )
-
-# Note: ParallelDevice-specific ops are experimental and not currently linked in
-# to TensorFlow by default, just used in a few tests.
-filegroup(
-    name = "parallel_device_ops_srcs",
-    srcs = ["parallel_device_ops.cc"],
-    visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
-)
-
-cc_library(
-    name = "parallel_device_ops",
-    srcs = [":parallel_device_ops_srcs"],
-    visibility = ["//tensorflow:internal"],
-    deps = ["//tensorflow/core:framework"],
-    alwayslink = 1,
-)
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc
index b9d7be7f8ea..41bde23448b 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device.cc
@@ -136,13 +136,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
     }
     result.emplace(std::move(outputs));
     return result;
-  } else if (operation_name == std::string("DeviceID")) {
-    std::vector<MaybeParallelTensorOwned> result_content;
-    result_content.reserve(1);
-    result_content.push_back(parallel_device.DeviceIDs(context, status));
-    if (TF_GetCode(status) != TF_OK) return result;
-    result.emplace(std::move(result_content));
-    return result;
   }
   std::vector<ParallelTensor*> parallel_inputs;
   std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_ops.cc b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc
deleted file mode 100644
index 1decffca047..00000000000
--- a/tensorflow/c/eager/parallel_device/parallel_device_ops.cc
+++ /dev/null
@@ -1,26 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-// TODO(allenl): Figure out if we need this op, and if so whether we should move
-// it to core TF. Right now the eager C API does some checking of op
-// registrations before calling into custom devices, but we may be able to avoid
-// that.
-REGISTER_OP("DeviceID")
-    .Output("device_id: int64")
-    .SetIsStateful()
-    .SetShapeFn(tensorflow::shape_inference::ScalarShape);
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc
index 828dcbae093..67bc596b180 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc
@@ -279,30 +279,4 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
         TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
     ASSERT_EQ(underlying_devices[1], second_device);
   }
-  // Compute the device ID twice and verify the result
-  for (int i = 0; i < 2; ++i) {
-    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
-        TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-    TFE_OpSetDevice(op.get(), device_name, status.get());
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-
-    TFE_TensorHandle* result_handle;
-    int num_retvals = 1;
-    TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-    std::array<TensorHandlePtr, 2> components;
-    ExtractPerDeviceValues(context, result_handle, &components, status.get());
-    TFE_DeleteTensorHandle(result_handle);
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-
-    ExpectScalarEq<int32_t>(components[0].get(), 0);
-    ExpectScalarEq<int32_t>(components[1].get(), 1);
-    std::string first_device =
-        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
-    ASSERT_EQ(underlying_devices[0], first_device);
-    std::string second_device =
-        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
-    ASSERT_EQ(underlying_devices[1], second_device);
-  }
 }
diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD
index e0f76fa6652..3d08f5b90e3 100644
--- a/tensorflow/python/distribute/parallel_device/BUILD
+++ b/tensorflow/python/distribute/parallel_device/BUILD
@@ -1,6 +1,3 @@
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
-
 package(
     default_visibility = ["//tensorflow:internal"],
     licenses = ["notice"],  # Apache 2.0
@@ -17,7 +14,6 @@ py_library(
     srcs = ["parallel_device.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":parallel_device_ops",
         ":saving",
         "//tensorflow/python:_pywrap_parallel_device",
         "//tensorflow/python/distribute:device_util",
@@ -31,25 +27,6 @@ py_library(
     deps = ["//tensorflow/python:framework_ops"],
 )
 
-tf_gen_op_wrapper_py(
-    name = "parallel_device_ops_py",
-    out = "gen_parallel_device_ops.py",
-    deps = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
-)
-
-tf_custom_op_library(
-    name = "_parallel_device_ops.so",
-    srcs = ["//tensorflow/c/eager/parallel_device:parallel_device_ops_srcs"],
-)
-
-tf_custom_op_py_library(
-    name = "parallel_device_ops",
-    dso = [":_parallel_device_ops.so"],
-    kernels = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
-    visibility = ["//tensorflow:internal"],
-    deps = [":parallel_device_ops_py"],
-)
-
 py_test(
     name = "parallel_device_test",
     srcs = ["parallel_device_test.py"],
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device.py b/tensorflow/python/distribute/parallel_device/parallel_device.py
index 94d43561a30..30381e2a95d 100644
--- a/tensorflow/python/distribute/parallel_device/parallel_device.py
+++ b/tensorflow/python/distribute/parallel_device/parallel_device.py
@@ -22,23 +22,23 @@ import threading
 
 from tensorflow.python import _pywrap_parallel_device
 from tensorflow.python.distribute import device_util
-from tensorflow.python.distribute.parallel_device import gen_parallel_device_ops
 from tensorflow.python.distribute.parallel_device import saving
 from tensorflow.python.eager import context
-from tensorflow.python.framework import load_library
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
-from tensorflow.python.platform import resource_loader
+from tensorflow.python.ops import array_ops
 from tensorflow.python.tpu.ops import tpu_ops
 
-load_library.load_op_library(
-    resource_loader.get_path_to_datafile("_parallel_device_ops.so"))
-
 _next_device_number = 0
 _next_device_number_lock = threading.Lock()
 
 
 # TODO(allenl): Expand this docstring once things like getting components on and
 # off the device are stable.
+#
+# TODO(allenl): Make multi-client work; we need an offset for device IDs, and an
+# indication of how many other devices there are total for collectives which
+# don't have a number of participants hard-coded in their attributes.
 class ParallelDevice(object):
   """A device which executes operations in parallel."""
 
@@ -64,8 +64,7 @@ class ParallelDevice(object):
     device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
         self._name, self.components)
     context.register_custom_device(device, self._name, device_info)
-    with ops.device(self._name):
-      self._device_ids = gen_parallel_device_ops.device_id()
+    self._device_ids = None
     self._device_scope = None
     self._saving_scope = None
 
@@ -106,6 +105,27 @@ class ParallelDevice(object):
     Returns:
       A parallel tensor containing 0 on the first device, 1 on the second, etc.
     """
+    if self._device_ids is None:
+      # device_ids may be called from inside a tf.function, in which case the
+      # function captures the eager tensor. We can't pack tensors in a function
+      # at the moment, and even if we could we don't want to hold on to a
+      # symbolic tensor, so we need to init_scope out of the function
+      # temporarily.
+      with ops.init_scope():
+        # TODO(allenl): Functions which capture eager device ID tensors won't be
+        # saveable in SavedModels. Ideally we'd run a DeviceID op every time
+        # device IDs are required, with functions using the op in their bodies
+        # but not hard-coding a fixed number of devices (so they can be re-used
+        # with a different replica count).
+        device_ids_list = []
+        for index, device in enumerate(self.components):
+          with ops.device(device):
+            # The identity op ensures each device ID tensor is placed on its
+            # device.
+            device_ids_list.append(
+                array_ops.identity(constant_op.constant(index)))
+        self._device_ids = self.pack(device_ids_list)
+
     return self._device_ids
 
   def _assert_eager(self):