diff --git a/README.md b/README.md
index 05ddb90fabc..27032043e07 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,10 @@
+[](https://badge.fury.io/py/tensorflow)
+[](https://badge.fury.io/py/tensorflow)
+
+
**`Documentation`** |
------------------- |
[](https://www.tensorflow.org/api_docs/) |
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 4326b723f74..d49f679083e 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -223,6 +223,7 @@ tf_cuda_cc_test(
":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@@ -371,6 +372,22 @@ tf_cuda_cc_test(
],
)
+cc_library(
+ name = "custom_device_testutil",
+ testonly = True,
+ srcs = ["custom_device_testutil.cc"],
+ hdrs = ["custom_device_testutil.h"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":c_api",
+ ":c_api_experimental",
+ ":c_api_test_util",
+ "//tensorflow/c:c_api",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cc_test(
name = "custom_device_test",
size = "small",
@@ -381,6 +398,7 @@ tf_cc_test(
":c_api",
":c_api_experimental",
":c_api_test_util",
+ ":custom_device_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
@@ -448,6 +466,7 @@ filegroup(
],
exclude = [
"c_api_experimental.cc",
+ "*c_api_tfrt*",
"*test*",
"*dlpack*",
],
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 51ee82e55aa..b34d1026e08 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -21,6 +21,10 @@ limitations under the License.
#include
#include
+// clang-format off
+#include "tensorflow/core/platform/platform.h"
+// clang-format on
+
#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
@@ -31,12 +35,14 @@ limitations under the License.
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
+#ifdef PLATFORM_GOOGLE
+#include "tensorflow/c/eager/c_api_tfrt.h"
+#endif
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
@@ -676,6 +682,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
+ if (opts->use_tfrt) {
+#ifdef PLATFORM_GOOGLE
+ status->status = tensorflow::Status::OK();
+ return new TFE_Context{new tfrt::ContextInterface()};
+#else
+ status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
+ return nullptr;
+#endif
+ }
std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
@@ -1669,6 +1684,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
};
} // namespace
+extern "C" {
+
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status) {
@@ -1679,3 +1696,5 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device));
}
+
+} // extern "C"
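
Note on the TFE_NewContext change above: the context is TFRT-backed only when `opts->use_tfrt` is set and the build defines `PLATFORM_GOOGLE`; open-source builds report `Unimplemented`. A minimal caller-side sketch, assuming only the C API touched by this patch (the helper name is illustrative, not part of the change):

```cpp
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Creates an eager context, preferring TFRT when the build supports it and
// falling back to the conventional runtime otherwise. Illustrative only.
TFE_Context* CreateContextPreferTfrt(TF_Status* status) {
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, true);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  if (TF_GetCode(status) == TF_UNIMPLEMENTED) {
    // TFRT is not compiled in (e.g. open-source builds); retry without it.
    TFE_ContextOptionsSetTfrt(opts, false);
    ctx = TFE_NewContext(opts, status);
  }
  TFE_DeleteContextOptions(opts);
  return ctx;  // nullptr if both attempts failed; check `status`.
}
```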
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 0037f2e81c8..dc1f9eaade3 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -515,9 +515,11 @@ typedef struct TFE_CustomDevice {
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
-void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
- const char* device_name, void* device_info,
- TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
+ TFE_CustomDevice device,
+ const char* device_name,
+ void* device_info,
+ TF_Status* status);
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 55f1941ce89..e61cf7ef040 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -19,6 +19,10 @@ limitations under the License.
#include
+// clang-format off
+#include "tensorflow/core/platform/platform.h"
+// clang-format on
+
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -584,9 +588,10 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContext(ctx);
}
-void ExecuteAdd(bool async, bool forward_input) {
+void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetTfrt(opts, tfrt);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -651,7 +656,6 @@ void ExecuteAdd(bool async, bool forward_input) {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval);
- TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float result[100 * 100] = {0};
@@ -661,12 +665,42 @@ void ExecuteAdd(bool async, bool forward_input) {
for (int i = 0; i < 100 * 100; ++i) {
EXPECT_EQ(2.0f, result[i]);
}
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
-TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
-TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
-TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
-TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
+TEST(CAPI, ExecuteAdd) {
+ ExecuteAdd(
+ /*async=*/false,
+ /*forward_input*/ false,
+ /*tfrt*/ false);
+}
+TEST(CAPI, ExecuteAddAsync) {
+ ExecuteAdd(
+ /*async=*/true,
+ /*forward_input*/ false,
+ /*tfrt*/ false);
+}
+TEST(CAPI, ExecuteAddForward) {
+ ExecuteAdd(
+ /*async=*/false,
+ /*forward_input*/ true,
+ /*tfrt*/ false);
+}
+TEST(CAPI, ExecuteAddForwardAsync) {
+ ExecuteAdd(
+ /*async=*/true,
+ /*forward_input*/ true,
+ /*tfrt*/ false);
+}
+#ifdef PLATFORM_GOOGLE
+// TODO(b/153349425): Add add forwarding tests for TFRT
+TEST(CAPI, ExecuteAddTfrt) {
+ ExecuteAdd(
+ /*async=*/false,
+ /*forward_input*/ false,
+ /*tfrt*/ true);
+}
+#endif
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc
index 8f13c1e5151..1ec9e9bd99a 100644
--- a/tensorflow/c/eager/custom_device_test.cc
+++ b/tensorflow/c/eager/custom_device_test.cc
@@ -20,129 +20,11 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
-namespace {
-
-struct LoggingDevice {
- tensorflow::string device_name;
- tensorflow::string underlying_device;
- // Set to true whenever a TensorHandle is copied onto the device
- bool* arrived_flag;
- // Set to true whenever an operation is executed
- bool* executed_flag;
-};
-
-struct LoggedTensor {
- TFE_TensorHandle* tensor;
- LoggedTensor() = delete;
- explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
- ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
-};
-
-void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
- delete reinterpret_cast<LoggedTensor*>(data);
-}
-
-TFE_TensorHandle* MakeLoggedTensorHandle(
- TFE_Context* context, const tensorflow::string& logging_device_name,
- std::unique_ptr<LoggedTensor> t, TF_Status* status) {
- std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
- if (TF_GetCode(status) != TF_OK) return nullptr;
- for (int i = 0; i < shape.size(); ++i) {
- shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- }
- auto dtype = TFE_TensorHandleDataType(t->tensor);
- return TFE_NewTensorHandleFromDeviceMemory(
- context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
- t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
-}
-
-TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
- TFE_TensorHandle* tensor,
- TF_Status* status, void* device_info) {
- LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
- TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
- tensor, context, dev->underlying_device.c_str(), status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- auto dst = std::make_unique<LoggedTensor>(t);
- *(dev->arrived_flag) = true;
- return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
- status);
-}
-
-TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
- TFE_TensorHandle* tensor,
- const char* target_device_name,
- TF_Status* status,
- void* device_info) {
- TF_SetStatus(status, TF_INTERNAL,
- "Trying to copy a tensor out of a logging device.");
- return nullptr;
-}
-
-void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
- TFE_TensorHandle** inputs, const char* operation_name,
- const TFE_OpAttrs* attributes, int* num_outputs,
- TFE_TensorHandle** outputs, TF_Status* s,
- void* device_info) {
- LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
- TFE_Op* op(TFE_NewOp(context, operation_name, s));
- if (TF_GetCode(s) != TF_OK) return;
- TFE_OpAddAttrs(op, attributes);
- TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
- for (int j = 0; j < num_inputs; ++j) {
- TFE_TensorHandle* input = inputs[j];
- const char* input_device = TFE_TensorHandleDeviceName(input, s);
- if (TF_GetCode(s) != TF_OK) return;
- if (dev->device_name == input_device) {
- LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
- TFE_TensorHandleDevicePointer(input, s));
- if (TF_GetCode(s) != TF_OK) return;
- TFE_OpAddInput(op, t->tensor, s);
- } else {
- TFE_OpAddInput(op, input, s);
- }
- if (TF_GetCode(s) != TF_OK) return;
- }
- std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
- TFE_Execute(op, op_outputs.data(), num_outputs, s);
- TFE_DeleteOp(op);
- if (TF_GetCode(s) != TF_OK) return;
- std::vector<TFE_TensorHandle*> unwrapped_outputs;
- for (auto* handle : op_outputs) {
- unwrapped_outputs.push_back(handle);
- }
- for (int i = 0; i < *num_outputs; ++i) {
- auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
- outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
- std::move(logged_tensor), s);
- }
- *(dev->executed_flag) = true;
-}
-
-void DeleteLoggingDevice(void* device_info) {
- delete reinterpret_cast<LoggingDevice*>(device_info);
-}
-
-void RegisterLoggingDevice(TFE_Context* context, const char* name,
- bool* arrived_flag, bool* executed_flag,
- TF_Status* status) {
- TFE_CustomDevice custom_device;
- custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
- custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
- custom_device.delete_device = &DeleteLoggingDevice;
- custom_device.execute = &LoggingDeviceExecute;
- LoggingDevice* device = new LoggingDevice;
- device->arrived_flag = arrived_flag;
- device->executed_flag = executed_flag;
- device->device_name = name;
- device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
- TFE_RegisterCustomDevice(context, custom_device, name, device, status);
-}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
- reinterpret_cast<LoggedTensor*>(
- TFE_TensorHandleDevicePointer(var_value, status.get()))
- ->tensor;
+ UnpackTensorHandle(var_value, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
@@ -394,5 +274,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}
-
-} // namespace
diff --git a/tensorflow/c/eager/custom_device_testutil.cc b/tensorflow/c/eager/custom_device_testutil.cc
new file mode 100644
index 00000000000..28de3665653
--- /dev/null
+++ b/tensorflow/c/eager/custom_device_testutil.cc
@@ -0,0 +1,172 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A simple logging device to test custom device registration.
+#include
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace {
+
+struct LoggingDevice {
+ tensorflow::string device_name;
+ tensorflow::string underlying_device;
+ // Set to true whenever a TensorHandle is copied onto the device
+ bool* arrived_flag;
+ // Set to true whenever an operation is executed
+ bool* executed_flag;
+};
+
+struct LoggedTensor {
+ TFE_TensorHandle* tensor;
+ LoggedTensor() = delete;
+ explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
+ ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
+};
+
+void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
+ delete reinterpret_cast<LoggedTensor*>(data);
+}
+
+TFE_TensorHandle* MakeLoggedTensorHandle(
+ TFE_Context* context, const tensorflow::string& logging_device_name,
+ std::unique_ptr<LoggedTensor> t, TF_Status* status) {
+ std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ for (int i = 0; i < shape.size(); ++i) {
+ shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ }
+ auto dtype = TFE_TensorHandleDataType(t->tensor);
+ return TFE_NewTensorHandleFromDeviceMemory(
+ context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
+ t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
+}
+
+TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
+ TFE_TensorHandle* tensor,
+ TF_Status* status, void* device_info) {
+ LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
+ TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
+ tensor, context, dev->underlying_device.c_str(), status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ auto dst = std::make_unique<LoggedTensor>(t);
+ *(dev->arrived_flag) = true;
+ return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
+ status);
+}
+
+TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
+ TFE_TensorHandle* tensor,
+ const char* target_device_name,
+ TF_Status* status,
+ void* device_info) {
+ TF_SetStatus(status, TF_INTERNAL,
+ "Trying to copy a tensor out of a logging device.");
+ return nullptr;
+}
+
+void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
+ TFE_TensorHandle** inputs, const char* operation_name,
+ const TFE_OpAttrs* attributes, int* num_outputs,
+ TFE_TensorHandle** outputs, TF_Status* s,
+ void* device_info) {
+ LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
+ TFE_Op* op(TFE_NewOp(context, operation_name, s));
+ if (TF_GetCode(s) != TF_OK) return;
+ TFE_OpAddAttrs(op, attributes);
+ TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
+ for (int j = 0; j < num_inputs; ++j) {
+ TFE_TensorHandle* input = inputs[j];
+ const char* input_device = TFE_TensorHandleDeviceName(input, s);
+ if (TF_GetCode(s) != TF_OK) return;
+ if (dev->device_name == input_device) {
+ LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
+ TFE_TensorHandleDevicePointer(input, s));
+ if (TF_GetCode(s) != TF_OK) return;
+ TFE_OpAddInput(op, t->tensor, s);
+ } else {
+ TFE_OpAddInput(op, input, s);
+ }
+ if (TF_GetCode(s) != TF_OK) return;
+ }
+ std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
+ TFE_Execute(op, op_outputs.data(), num_outputs, s);
+ TFE_DeleteOp(op);
+ if (TF_GetCode(s) != TF_OK) return;
+ std::vector<TFE_TensorHandle*> unwrapped_outputs;
+ for (auto* handle : op_outputs) {
+ unwrapped_outputs.push_back(handle);
+ }
+ for (int i = 0; i < *num_outputs; ++i) {
+ auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
+ outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
+ std::move(logged_tensor), s);
+ }
+ *(dev->executed_flag) = true;
+}
+
+void DeleteLoggingDevice(void* device_info) {
+ delete reinterpret_cast<LoggingDevice*>(device_info);
+}
+
+} // namespace
+
+void RegisterLoggingDevice(TFE_Context* context, const char* name,
+ bool* arrived_flag, bool* executed_flag,
+ TF_Status* status) {
+ TFE_CustomDevice custom_device;
+ custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
+ custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
+ custom_device.delete_device = &DeleteLoggingDevice;
+ custom_device.execute = &LoggingDeviceExecute;
+ LoggingDevice* device = new LoggingDevice;
+ device->arrived_flag = arrived_flag;
+ device->executed_flag = executed_flag;
+ device->device_name = name;
+ device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+ TFE_RegisterCustomDevice(context, custom_device, name, device, status);
+}
+
+TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
+ TF_Status* status) {
+ return reinterpret_cast<LoggedTensor*>(
+ TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
+ ->tensor;
+}
+
+void AllocateLoggingDevice(const char* name, bool* arrived_flag,
+ bool* executed_flag, TFE_CustomDevice** device,
+ void** device_info) {
+ TFE_CustomDevice* custom_device = new TFE_CustomDevice;
+ custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
+ custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
+ custom_device->delete_device = &DeleteLoggingDevice;
+ custom_device->execute = &LoggingDeviceExecute;
+ *device = custom_device;
+ LoggingDevice* logging_device = new LoggingDevice;
+ logging_device->arrived_flag = arrived_flag;
+ logging_device->executed_flag = executed_flag;
+ logging_device->device_name = name;
+ logging_device->underlying_device =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+ *device_info = reinterpret_cast<void*>(logging_device);
+}
diff --git a/tensorflow/c/eager/custom_device_testutil.h b/tensorflow/c/eager/custom_device_testutil.h
new file mode 100644
index 00000000000..509df7d3e3e
--- /dev/null
+++ b/tensorflow/c/eager/custom_device_testutil.h
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
+#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
+
+// A simple logging device to test custom device registration.
+#include
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/tf_status.h"
+
+void RegisterLoggingDevice(TFE_Context* context, const char* name,
+ bool* arrived_flag, bool* executed_flag,
+ TF_Status* status);
+void AllocateLoggingDevice(const char* name, bool* arrived_flag,
+ bool* executed_flag, TFE_CustomDevice** device,
+ void** device_info);
+TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
+ TF_Status* status);
+
+#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
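
The new header lets other test targets reuse the logging device without duplicating its implementation. A hedged sketch of a dependent test; the device name and assertions are illustrative, not part of this patch:

```cpp
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/test.h"

TEST(CUSTOM_DEVICE, ReusesLoggingDeviceHelpers) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* context = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context, name, &arrived, &executed, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // Ops placed on `name` now log through the custom device; handles wrapped
  // by the logging device can be unwrapped with UnpackTensorHandle(...).

  TFE_DeleteContext(context);
  TF_DeleteStatus(status);
}
```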
diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD
new file mode 100644
index 00000000000..9d787d26433
--- /dev/null
+++ b/tensorflow/c/eager/parallel_device/BUILD
@@ -0,0 +1,38 @@
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+package(
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "parallel_device",
+ srcs = ["parallel_device.cc"],
+ hdrs = ["parallel_device.h"],
+ deps = [
+ "//tensorflow/c:c_api",
+ "//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_experimental",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:variant",
+ ],
+)
+
+tf_cc_test(
+ name = "parallel_device_test",
+ srcs = ["parallel_device_test.cc"],
+ deps = [
+ ":parallel_device",
+ "//tensorflow/c:c_api",
+ "//tensorflow/c:c_api_experimental",
+ "//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_experimental",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc
new file mode 100644
index 00000000000..bd5d8e777f2
--- /dev/null
+++ b/tensorflow/c/eager/parallel_device/parallel_device.cc
@@ -0,0 +1,597 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/parallel_device/parallel_device.h"
+
+#include
+
+#include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
+#include "absl/types/variant.h"
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
+namespace tensorflow {
+namespace eager {
+namespace {
+
+// Functor for making unique_ptrs slightly more ergonomic. Using
+// decltype(delete_fn) in the unique_ptr's second template argument requires
+// passing a function pointer to delete_fn when constructing the unique_ptr.
+class TensorHandleDeleter {
+ public:
+ void operator()(TFE_TensorHandle* to_delete) const {
+ TFE_DeleteTensorHandle(to_delete);
+ }
+};
+
+using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
+
+class OpDeleter {
+ public:
+ void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
+};
+
+using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
+
+class ExecutorDeleter {
+ public:
+ void operator()(TFE_Executor* to_delete) const {
+ TFE_DeleteExecutor(to_delete);
+ }
+};
+
+using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
+
+class ParallelTensor;
+
+using MaybeParallelTensorOwned =
+ absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
+using MaybeParallelTensorUnowned =
+ absl::variant<ParallelTensor*, TFE_TensorHandle*>;
+
+// Creates a vector of `count` new executors (threads).
+std::vector<ExecutorPtr> MakeExecutors(size_t count) {
+ std::vector<ExecutorPtr> executors;
+ executors.reserve(count);
+ for (int i = 0; i < count; ++i) {
+ executors.emplace_back(TFE_NewExecutor(true /* is_async */));
+ }
+ return executors;
+}
+
+// A representation of the custom device passed in and out of the TFE custom
+// device APIs, providing context about the parallel device to
+// ParallelDeviceExecute.
+class ParallelDevice {
+ public:
+ ParallelDevice(const std::string& name,
+ const std::vector<std::string>& devices);
+
+ // Helper to copy a tensor handle from another device once for each component
+ // of the ParallelDevice.
+ //
+ // Sets a bad status and returns a nullptr if `tensor` is already on the
+ // ParallelDevice, or if the individual copies fail.
+ std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
+ TFE_TensorHandle* tensor,
+ TF_Status* status) const;
+
+ // Takes a description of a single operation being executed on the
+ // ParallelDevice, and in turn runs one operation per component device with
+ // its corresponding inputs from the input ParallelTensors (or
+ // implicitly-mirrored tensors on other devices). Wraps the resulting
+ // per-device and per-output TFE_TensorHandles into one ParallelTensor per
+ // output of the original operation.
+ //
+ // `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
+ // un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
+ // requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
+ // tensor, but other operations will implicitly broadcast non-parallel input
+ // tensors across the ParallelDevice's component devices.
+ //
+ // Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
+ // pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
+ // causes `Execute` to return non-parallel tensors.
+ //
+ // Attributes are forwarded to executed operations unmodified.
+ //
+ // The returned optional has a value if and only if `status` evaluates to
+ // TF_OK.
+ absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
+ TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
+ const char* operation_name, const TFE_OpAttrs* attributes,
+ int expected_max_outputs, TF_Status* status) const;
+
+ // Implements the parallel case for `Execute`, where all of the outputs of the
+ // operation are ParallelTensors, and all inputs are either ParallelTensors or
+ // should be implicitly broadcast. This means the operation is not
+ // TPUReplicatedInput or TPUReplicatedOutput.
+ //
+ // The returned optional has a value if and only if `status` evaluates to
+ // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
+ // if sanity checks on dtypes/metadata fail.
+ absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
+ ExecuteParallelOperation(TFE_Context* context,
+ std::vector<MaybeParallelTensorUnowned> inputs,
+ const char* operation_name,
+ const TFE_OpAttrs* attributes,
+ int expected_max_outputs, TF_Status* status) const;
+
+ const std::string& device_name() const { return device_name_; }
+
+ private:
+ // The name of the parallel device
+ // (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
+ const std::string device_name_;
+ // A sequence of device names, indicating which devices replicated operations
+ // are forwarded to.
+ const std::vector<std::string> underlying_devices_;
+ // A sequence of TFE_Executors, one per device, for executing operations in
+ // parallel.
+ const std::vector<ExecutorPtr> executors_;
+};
+
+// The internal representation of a TFE_TensorHandle placed on a
+// ParallelDevice. Contains a tuple of tensors, one on each of the
+// `underlying_devices_` of the ParallelDevice.
+class ParallelTensor {
+ public:
+ // Construct a ParallelTensor from TensorHandles placed on the component
+ // devices of a ParallelDevice.
+ static std::unique_ptr<ParallelTensor> FromTensorHandles(
+ const ParallelDevice& parallel_device,
+ std::vector<TensorHandlePtr> components, TF_Status* status);
+
+ // Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
+ static TensorHandlePtr AsTensorHandle(TFE_Context* context,
+ std::unique_ptr<ParallelTensor> t,
+ TF_Status* status);
+
+ size_t num_tensors() const { return tensors_.size(); }
+ TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
+
+ private:
+ ParallelTensor(const ParallelDevice& device,
+ std::vector<TensorHandlePtr> tensors,
+ std::vector<int64_t> shape, const TF_DataType dtype)
+ : device_(device),
+ tensors_(std::move(tensors)),
+ shape_(std::move(shape)),
+ dtype_(dtype) {}
+
+ const ParallelDevice& device_;
+ const std::vector<TensorHandlePtr> tensors_;
+ const std::vector<int64_t> shape_;
+ const TF_DataType dtype_;
+};
+
+ParallelDevice::ParallelDevice(const std::string& name,
+ const std::vector<std::string>& devices)
+ : device_name_(name),
+ underlying_devices_(devices),
+ executors_(MakeExecutors(underlying_devices_.size())) {}
+
+std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
+ TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
+ const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
+ if (device_name_ == current_device) {
+ std::string message(absl::StrCat(
+ "Tried to copy a TensorHandle to its existing device: ", device_name_));
+ TF_SetStatus(status, TF_INTERNAL, message.c_str());
+ return nullptr;
+ }
+ std::vector<TensorHandlePtr> components;
+ components.reserve(underlying_devices_.size());
+ for (const std::string& underlying_device_name : underlying_devices_) {
+ TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
+ tensor, context, underlying_device_name.c_str(), status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ components.emplace_back(t);
+ }
+ return ParallelTensor::FromTensorHandles(*this, std::move(components),
+ status);
+}
+
+absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
+ TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
+ const char* operation_name, const TFE_OpAttrs* attributes,
+ int expected_max_outputs, TF_Status* status) const {
+ absl::optional<std::vector<MaybeParallelTensorOwned>> result;
+ // TODO(allenl): We should remove "TPU" from these op names at the very least,
+ // or consider other ways of packing/unpacking parallel tensors.
+ if (operation_name == std::string("TPUReplicatedInput")) {
+ // Special-cased operation for packing per-device tensors into one parallel
+ // tensor.
+ if (inputs.size() != underlying_devices_.size()) {
+ std::string message(absl::StrCat(
+ "The parallel device ", device_name_, " expected ",
+ underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
+ inputs.size()));
+ TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
+ return result;
+ }
+ std::vector<TensorHandlePtr> components;
+ components.reserve(inputs.size());
+ for (int i = 0; i < inputs.size(); ++i) {
+ if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
+ std::string message(absl::StrCat(
+ "Expected all inputs to TPUReplicatedInput to be non-parallel "
+ "TensorHandles. The input ",
+ i,
+ " was a parallel tensor (already "
+ "placed on the parallel device)."));
+ TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
+ return result;
+ }
+ components.emplace_back(TFE_TensorHandleCopySharingTensor(
+ absl::get<TFE_TensorHandle*>(inputs[i]), status));
+ }
+ std::vector<MaybeParallelTensorOwned> result_content;
+ result_content.reserve(1);
+ result_content.push_back(ParallelTensor::FromTensorHandles(
+ *this, std::move(components), status));
+ if (TF_GetCode(status) != TF_OK) return result;
+ result.emplace(std::move(result_content));
+ return result;
+ } else if (operation_name == std::string("TPUReplicatedOutput")) {
+ // Special-cased operation for un-packing one parallel tensor into
+ // per-device tensors.
+ OpPtr op(TFE_NewOp(context, operation_name, status));
+ TFE_OpAddAttrs(op.get(), attributes);
+ int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
+ if (TF_GetCode(status) != TF_OK) return result;
+ if (expected_outputs != underlying_devices_.size()) {
+ std::string message(absl::StrCat(
+ "The parallel device ", device_name_, " expected ",
+ underlying_devices_.size(),
+ " outputs for TPUReplicatedOutput, but got ", expected_outputs));
+ TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
+ return result;
+ }
+ if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
+ TF_SetStatus(status, TF_INVALID_ARGUMENT,
+ "Expected the input to "
+ "TPUReplicatedOutput to be a parallel tensor (placed on the "
+ "parallel device).");
+ return result;
+ }
+ ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
+ std::vector<TensorHandlePtr> outputs;
+ outputs.reserve(t->num_tensors());
+ for (int i = 0; i < t->num_tensors(); ++i) {
+ TensorHandlePtr this_output(
+ TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
+ outputs.emplace_back(std::move(this_output));
+ if (TF_GetCode(status) != TF_OK) return result;
+ }
+ result.emplace(std::move(outputs));
+ return result;
+ }
+ absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
+ maybe_parallel_results(
+ ExecuteParallelOperation(context, std::move(inputs), operation_name,
+ attributes, expected_max_outputs, status));
+ if (!maybe_parallel_results.has_value()) return result;
+ std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
+ std::move(maybe_parallel_results.value()));
+ std::vector<MaybeParallelTensorOwned> result_content;
+ result_content.reserve(parallel_results.size());
+ for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
+ result_content.push_back(
+ MaybeParallelTensorOwned(std::move(parallel_result)));
+ }
+ result.emplace(std::move(result_content));
+ return result;
+}
+
+absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
+ParallelDevice::ExecuteParallelOperation(
+ TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
+ const char* operation_name, const TFE_OpAttrs* attributes,
+ int expected_max_outputs, TF_Status* status) const {
+ absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
+ // Compute per-device per-output tensors
+ std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
+ per_device_output_tensors.reserve(underlying_devices_.size());
+ // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
+ // setting the thread-local executor like this.
+ TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
+ auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
+ TFE_ContextSetExecutorForThread(context, previous_executor);
+ TFE_DeleteExecutor(previous_executor);
+ });
+ int first_op_output_count;
+ for (int device_index = 0; device_index < underlying_devices_.size();
+ ++device_index) {
+ TFE_Executor* executor = executors_[device_index].get();
+ // Note that the `reset_executor` cleanup sets the thread's executor back to
+ // the value before this function ran.
+ TFE_ContextSetExecutorForThread(context, executor);
+ OpPtr op(TFE_NewOp(context, operation_name, status));
+ if (TF_GetCode(status) != TF_OK) return result;
+ TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
+ status);
+ TFE_OpAddAttrs(op.get(), attributes);
+ for (int input_index = 0; input_index < inputs.size(); ++input_index) {
+ if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
+ // Non-parallel tensors are implicitly broadcast, i.e. set as the input
+ // to each parallel operation.
+ //
+ // TODO(allenl): There may be smarter ways to do this copy in some
+ // cases, i.e. with a collective broadcast. We'll need to be careful
+ // about things that are taken as inputs on the host or on their
+ // existing device (for multi-device functions).
+ TFE_OpAddInput(op.get(),
+ absl::get<TFE_TensorHandle*>(inputs[input_index]),
+ status);
+ if (TF_GetCode(status) != TF_OK) return result;
+ } else {
+ // Parallel tensors are divided between operations by device.
+ TFE_OpAddInput(op.get(),
+ absl::get<ParallelTensor*>(inputs[input_index])
+ ->tensor(device_index),
+ status);
+ if (TF_GetCode(status) != TF_OK) return result;
+ }
+ }
+ std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
+ int real_num_outputs = expected_max_outputs;
+ // For nested devices, the inner device sees the async executor we've
+ // set. Inner parallel devices will just overwrite this with their own and
+ // then set it back to ours before returning. This means parallel devices
+ // which consist of several aliased parallel devices would hypothetically
+ // deadlock if the outer parallel device ran one collective with a group
+ // size equal to the total number of aliased physical devices. Currently
+ // physical devices cannot participate in a single collective reduction
+ // multiple times, so this would fail earlier.
+ //
+ // TODO(allenl): Keep a map from outer executor to list of inner executors
+ // rather than a single list of executors so aliased nested parallel devices
+ // don't re-use an executor.
+ TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
+ if (device_index == 0) {
+ first_op_output_count = real_num_outputs;
+ } else {
+ if (real_num_outputs != first_op_output_count) {
+ TF_SetStatus(status, TF_INTERNAL,
+ "Parallel ops produced different numbers of tensors.");
+ return result;
+ }
+ }
+ if (TF_GetCode(status) != TF_OK) return result;
+ std::vector<TensorHandlePtr> this_outputs;
+ this_outputs.reserve(real_num_outputs);
+ for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
+ this_outputs.emplace_back(op_outputs[output_num]);
+ }
+ per_device_output_tensors.push_back(std::move(this_outputs));
+ }
+ // For each output of the original operation, pack the per-device
+ // TensorHandles we've computed into a single parallel TensorHandle.
+ std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
+ per_device_outputs.reserve(first_op_output_count);
+ for (int i = 0; i < first_op_output_count; ++i) {
+ std::vector<TensorHandlePtr> components;
+ components.reserve(underlying_devices_.size());
+ for (int j = 0; j < underlying_devices_.size(); ++j) {
+ components.push_back(std::move(per_device_output_tensors[j][i]));
+ }
+ per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
+ *this, std::move(components), status));
+ if (TF_GetCode(status) != TF_OK) return result;
+ }
+ result.emplace(std::move(per_device_outputs));
+ return result;
+}
+
+std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
+ const ParallelDevice& parallel_device,
+ std::vector<TensorHandlePtr> components, TF_Status* status) {
+ TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
+ std::vector<int64_t> shape(
+ TFE_TensorHandleNumDims(components[0].get(), status));
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ for (int i = 0; i < shape.size(); ++i) {
+ shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ }
+
+ // Verify that the TensorHandle's shape and dtype match all of the component
+ // shapes and dtypes.
+ for (TensorHandlePtr& component : components) {
+ for (int i = 0; i < shape.size(); ++i) {
+ int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ if (tensor_dim != shape[i]) {
+ // TODO(allenl): Allow shapes to differ.
+ TF_SetStatus(status, TF_UNIMPLEMENTED,
+ "Components of a ParallelTensor must currently all have "
+ "the same shape");
+ return nullptr;
+ }
+ if (TFE_TensorHandleDataType(component.get()) != dtype) {
+ TF_SetStatus(status, TF_INTERNAL,
+ "Components of a ParallelTensor must all have "
+ "the same dtype");
+ return nullptr;
+ }
+ }
+ }
+
+ return std::unique_ptr<ParallelTensor>(new ParallelTensor(
+ parallel_device, std::move(components), std::move(shape), dtype));
+}
+
+// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
+// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
+// reference counts drop to zero.
+void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
+ delete reinterpret_cast<ParallelTensor*>(data);
+}
+
+TensorHandlePtr ParallelTensor::AsTensorHandle(
+ TFE_Context* context, std::unique_ptr<ParallelTensor> t,
+ TF_Status* status) {
+ // The resulting TensorHandle owns an opaque pointer to "device memory", which
+ // for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
+ // deleted, it will call ParallelTensorDeallocator to free the struct.
+ ParallelTensor* t_released = t.release();
+ return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
+ context, t_released->device_.device_name().c_str(), t_released->dtype_,
+ t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
+ &ParallelTensorDeallocator, nullptr, status));
+}
+
+// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
+// registration.
+//
+// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle containing
+// a ParallelTensor with one copy of `tensor` for each device in the
+// ParallelDevice.
+//
+// Since this function is used to satisfy the TFE_CustomDevice C API,
+// device_info is passed in using a C-style generic. It must always be a
+// ParallelDevice.
+TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
+ TFE_TensorHandle* tensor,
+ TF_Status* status, void* device_info) {
+ ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+ std::unique_ptr<ParallelTensor> parallel_tensor(
+ dev->CopyToParallelDevice(context, tensor, status));
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
+ status)
+ .release();
+}
+
+// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
+// registration.
+//
+// Currently this is an error, and un-packing ParallelTensors must be performed
+// explicitly by running a TPUReplicatedOutput operation on the parallel device.
+//
+// TODO(allenl): There are some use-cases that are only supported by copying to
+// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
+// need to return something here or address these use-cases one by one.
+TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
+ TFE_TensorHandle* tensor,
+ const char* target_device_name,
+ TF_Status* status,
+ void* device_info) {
+ TF_SetStatus(status, TF_INTERNAL,
+ "Trying to copy a tensor out of a parallel device.");
+ return nullptr;
+}
+
+// For TFE_CustomDevice::execute in the parallel device registration.
+//
+// Since this function is used to satisfy the TFE_CustomDevice C API,
+// device_info is passed in using a C-style generic. It must always be a
+// ParallelDevice.
+void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
+ TFE_TensorHandle** inputs,
+ const char* operation_name,
+ const TFE_OpAttrs* attributes, int* num_outputs,
+ TFE_TensorHandle** outputs, TF_Status* status,
+ void* device_info) {
+ ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+ std::vector<MaybeParallelTensorUnowned> typed_inputs;
+ typed_inputs.reserve(num_inputs);
+ for (int i = 0; i < num_inputs; ++i) {
+ const char* tensor_handle_device =
+ TFE_TensorHandleDeviceName(inputs[i], status);
+ if (TF_GetCode(status) != TF_OK) return;
+ if (dev->device_name() == tensor_handle_device) {
+ // We assume that any tensors already placed on this device are
+ // ParallelTensors.
+ typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
+ TFE_TensorHandleDevicePointer(inputs[i], status)));
+ if (TF_GetCode(status) != TF_OK) return;
+ } else {
+ typed_inputs.emplace_back(inputs[i]);
+ }
+ }
+
+ absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
+ dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
+ *num_outputs, status));
+ if (TF_GetCode(status) != TF_OK) return;
+ if (!maybe_typed_outputs.has_value()) {
+ TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
+ return;
+ }
+
+ std::vector<MaybeParallelTensorOwned> typed_outputs(
+ std::move(maybe_typed_outputs.value()));
+
+ if (typed_outputs.size() > *num_outputs) {
+ TF_SetStatus(status, TF_INTERNAL,
+ "The allocated output buffer was too small.");
+ return;
+ }
+
+ for (int i = 0; i < typed_outputs.size(); ++i) {
+ MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
+ if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
+ outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
+ } else {
+ outputs[i] = ParallelTensor::AsTensorHandle(
+ context,
+ std::move(absl::get<std::unique_ptr<ParallelTensor>>(
+ typed_output)),
+ status)
+ .release();
+ if (TF_GetCode(status) != TF_OK) return;
+ }
+ }
+ *num_outputs = typed_outputs.size();
+}
+
+// For TFE_CustomDevice::delete_device in the parallel device registration.
+//
+// Since this function is used to satisfy the TFE_CustomDevice C API,
+// device_info is passed in using a C-style generic. It must always be a
+// ParallelDevice.
+void DeleteParallelDevice(void* device_info) {
+ delete reinterpret_cast<ParallelDevice*>(device_info);
+}
+
+} // namespace
+
+void RegisterParallelDevice(TFE_Context* context, const char* device_name,
+ const char** underlying_devices,
+ int num_underlying_devices, TF_Status* status) {
+ TFE_CustomDevice custom_device;
+ custom_device.copy_tensor_to_device = &CopyToParallelDevice;
+ custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
+ custom_device.delete_device = &DeleteParallelDevice;
+ custom_device.execute = &ParallelDeviceExecute;
+ std::vector<std::string> underlying_devices_vector;
+ underlying_devices_vector.reserve(num_underlying_devices);
+ for (int device_index = 0; device_index < num_underlying_devices;
+ ++device_index) {
+ underlying_devices_vector.push_back(underlying_devices[device_index]);
+ }
+ ParallelDevice* d =
+ new ParallelDevice(device_name, underlying_devices_vector);
+ TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
+}
+
+} // namespace eager
+} // namespace tensorflow
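
The parallelism above comes from giving each underlying device its own async `TFE_Executor` and swapping it onto the calling thread around each component op. A standalone sketch of that pattern under the assumption that the caller owns `op`; the function name is illustrative, not part of this patch:

```cpp
// Runs `op` on a freshly created async executor, then restores the thread's
// previous executor; mirrors the cleanup logic in ExecuteParallelOperation.
void RunOnPrivateAsyncExecutor(TFE_Context* context, TFE_Op* op,
                               TFE_TensorHandle** retvals, int* num_retvals,
                               TF_Status* status) {
  TFE_Executor* previous = TFE_ContextGetExecutorForThread(context);
  TFE_Executor* async_executor = TFE_NewExecutor(/*is_async=*/true);
  TFE_ContextSetExecutorForThread(context, async_executor);

  TFE_Execute(op, retvals, num_retvals, status);  // enqueues asynchronously

  // Block until the enqueued work finishes before tearing the executor down.
  TFE_ExecutorWaitForAllPendingNodes(async_executor, status);
  TFE_ContextSetExecutorForThread(context, previous);
  TFE_DeleteExecutor(previous);
  TFE_DeleteExecutor(async_executor);
}
```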
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.h b/tensorflow/c/eager/parallel_device/parallel_device.h
new file mode 100644
index 00000000000..b106524401f
--- /dev/null
+++ b/tensorflow/c/eager/parallel_device/parallel_device.h
@@ -0,0 +1,62 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
+#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
+
+#include "tensorflow/c/eager/c_api.h"
+
+namespace tensorflow {
+namespace eager {
+
+// Register a parallel device named `device_name` which forwards operations to
+// `underlying_devices`, maintaining "parallel tensors" with components placed
+// on each underlying device.
+//
+// For example if `device_name` is
+// "/job:localhost/replica:0/task:0/device:CUSTOM:0"
+// and `underlying_devices` is
+// {"/job:localhost/replica:0/task:0/device:GPU:0",
+// "/job:localhost/replica:0/task:0/device:GPU:1"}
+// Then executing an operation on CUSTOM:0 will execute it on GPU:0 and GPU:1.
+//
+// Implicit copies onto `device_name` are allowed, replicating the value once
+// per device in `underlying_devices`. Implicit copies off of the device throw
+// an error.
+//
+// All component tensors must have the same dtype. Currently they must also have
+// the same shape, although this requirement may be relaxed in the future.
+//
+// `device_name` must not name an existing physical or custom device (see
+// the documentation for TFE_RegisterCustomDevice for more information).
+//
+// Tensors may be copied on or off the device explicitly using
+// TPUReplicatedInput and TPUReplicatedOutput respectively. For example, with
+// two component devices, running `x = TPUReplicatedInput(inputs=[a, b])` on the
+// parallel device creates a parallel tensor `x` with `a` on the first of
+// `underlying_devices` and `b` on the second. Running `a_unpacked, b_unpacked =
+// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
+// into its components.
+//
+// `context` owns the parallel device. `underlying_devices` must stay valid
+// while the parallel device is in use.
+void RegisterParallelDevice(TFE_Context* context, const char* device_name,
+ const char** underlying_devices,
+ int num_underlying_devices, TF_Status* status);
+
+} // namespace eager
+} // namespace tensorflow
+
+#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
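
A usage sketch of the API documented above, assuming a context that already exposes two local CPU devices; the device names and helper function are illustrative, not part of this patch:

```cpp
#include "tensorflow/c/eager/parallel_device/parallel_device.h"

// Registers a two-way parallel device over two local CPUs. Illustrative only.
void RegisterTwoWayParallelDevice(TFE_Context* context, TF_Status* status) {
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* underlying_devices[] = {
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  tensorflow::eager::RegisterParallelDevice(
      context, device_name, underlying_devices,
      /*num_underlying_devices=*/2, status);
  // Ops placed on CUSTOM:0 now run once per CPU; TPUReplicatedInput packs two
  // per-device handles into one parallel tensor and TPUReplicatedOutput
  // un-packs it, as described in the header comment.
}
```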
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
new file mode 100644
index 00000000000..41c7d64e231
--- /dev/null
+++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -0,0 +1,917 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/parallel_device/parallel_device.h"
+
+#include
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/core/platform/test.h"
+
+// NOTE(allenl): These tests currently go through TFE_Execute and so are
+// integration testing rather than purely testing the parallel device. They
+// correspond fairly well to the implementation, but testing the C++ directly is
+// another option.
+
+// Functor for making unique_ptr to TFE_TensorHandle slightly more
+// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
+// template argument requires passing a function pointer to
+// TFE_DeleteTensorHandle when constructing the unique_ptr.
+class TensorHandleDeleter {
+ public:
+ void operator()(TFE_TensorHandle* to_delete) {
+ TFE_DeleteTensorHandle(to_delete);
+ }
+};
+
+using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
+
+// A helper for performing common operations on variables. A much more
+// restricted stand-in for tf.Variable in Python.
+class Variable {
+ public:
+ // Construct a Variable from a resource-dtype TFE_TensorHandle and an
+ // indication of the dtype of the variable's value.
+ //
+ // Note that creating this resource-dtype handle can fail, so `Create` is a
+ // separate static method which returns a status.
+ Variable(TFE_TensorHandle* handle, TF_DataType type)
+ : handle_(handle), type_(type) {}
+
+ // Helper for constructing a resource handle and wrapping it in a `Variable`
+ // object.
+ static Variable* Create(TFE_Context* context, TF_DataType type,
+ const int64_t* dims, const int num_dims,
+ const char* device, TF_Status* status);
+ // Dereferences the backing buffer for the variable. Note that since this can
+ // fail (it runs operations), it must be called explicitly and the resulting
+ // `status` checked.
+ void Destroy(TFE_Context* context, TF_Status* status);
+
+ // Reads from the variable.
+ TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
+ // Assigns a new value to the variable.
+ void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
+ // Adds `value` to the existing value of the variable.
+ void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
+ TF_Status* status);
+
+ private:
+ // Helper for running any single-argument assignment ops (Assign, AssignAdd,
+ // AssignSub, ...).
+ void GeneralAssignment(const char* op_name, TFE_Context* context,
+ TFE_TensorHandle* value, TF_Status* status);
+
+ // A handle for the resource-dtype tensor pointing to the variable's
+ // buffer.
+ TFE_TensorHandle* handle_;
+ // The dtype of the variable's buffer (input dtype for assignments, output
+ // dtype of read operations).
+ TF_DataType type_;
+};
+
+Variable* Variable::Create(TFE_Context* context, TF_DataType type,
+ const int64_t* dims, const int num_dims,
+ const char* device, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrType(op.get(), "dtype", type);
+ TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
+ TFE_OpSetAttrString(op.get(), "container", "", 0);
+ // Use the special GUID for no buffer sharing
+ //
+ // TODO(allenl): Should we provide a better API for this? AFAIK this is the
+ // only reasonable way to make variables with no aliasing using the eager C
+ // API.
+ std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
+ TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
+ no_sharing.length());
+ TFE_OpSetDevice(op.get(), device, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_TensorHandle* var_handle = nullptr;
+ int num_retvals = 1;
+ TFE_Execute(op.get(), &var_handle, &num_retvals, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ return new Variable(var_handle, type);
+}
+
+void Variable::Destroy(TFE_Context* context, TF_Status* status) {
+ // Free the backing buffer for the variable.
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
+ if (TF_GetCode(status) != TF_OK) return;
+ TFE_OpAddInput(op.get(), handle_, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ const char* device = TFE_TensorHandleDeviceName(handle_, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ TFE_OpSetDevice(op.get(), device, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ int num_retvals = 0;
+ TFE_Execute(op.get(), nullptr, &num_retvals, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ // Delete the variable handle itself.
+ TFE_DeleteTensorHandle(handle_);
+}
+
+TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpAddInput(op.get(), handle_, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ const char* device = TFE_TensorHandleDeviceName(handle_, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetDevice(op.get(), device, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrType(op.get(), "dtype", type_);
+ int num_retvals = 1;
+ TFE_TensorHandle* var_value = nullptr;
+ TFE_Execute(op.get(), &var_value, &num_retvals, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ return TensorHandlePtr(var_value);
+}
+
+void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
+ TFE_TensorHandle* value, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
+ if (TF_GetCode(status) != TF_OK) return;
+ TFE_OpSetAttrType(op.get(), "dtype", type_);
+ TFE_OpAddInput(op.get(), handle_, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ TFE_OpAddInput(op.get(), value, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ const char* device = TFE_TensorHandleDeviceName(handle_, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ TFE_OpSetDevice(op.get(), device, status);
+
+ int num_retvals = 0;
+ TFE_Execute(op.get(), nullptr, &num_retvals, status);
+ if (TF_GetCode(status) != TF_OK) return;
+}
+
+void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
+ TF_Status* status) {
+ GeneralAssignment("AssignAddVariableOp", context, value, status);
+}
+
+void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
+ TF_Status* status) {
+ GeneralAssignment("AssignVariableOp", context, value, status);
+}
+
+// Passed to `TF_NewTensor` to indicate how an array of floats should be
+// deleted.
+static void FloatDeallocator(void* data, size_t, void* arg) {
+ delete[] static_cast<float*>(data);
+}
+
+// Creates a TFE_TensorHandle with value `v`.
+TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
+ const int num_bytes = sizeof(float);
+ float* values = new float[1];
+ values[0] = v;
+ std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+ TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
+ nullptr),
+ TF_DeleteTensor);
+ return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
+}
+
+// Creates a rank-one TFE_TensorHandle with value `v`.
+TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
+ TF_Status* status) {
+ const int num_bytes = v.size() * sizeof(float);
+ float* values = new float[v.size()];
+ memcpy(values, v.data(), num_bytes);
+ int64_t dims = v.size();
+ std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+ TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
+ &FloatDeallocator, nullptr),
+ TF_DeleteTensor);
+ return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
+}
+
+// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
+template <int num_replicas>
+void ExtractPerDeviceValues(
+ TFE_Context* context, TFE_TensorHandle* input,
+ std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
+ if (TF_GetCode(status) != TF_OK) return;
+ TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
+ TFE_OpAddInput(op.get(), input, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ const char* device = TFE_TensorHandleDeviceName(input, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ TFE_OpSetDevice(op.get(), device, status);
+ if (TF_GetCode(status) != TF_OK) return;
+
+ TFE_TensorHandle* result_handles[num_replicas];
+ int num_retvals = num_replicas;
+ TFE_Execute(op.get(), result_handles, &num_retvals, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ for (int i = 0; i < num_replicas; ++i) {
+ (*components)[i].reset(result_handles[i]);
+ }
+}
+
+// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
+template <int num_replicas>
+TensorHandlePtr CreatePerDeviceValues(
+ TFE_Context* context,
+ const std::array<TFE_TensorHandle*, num_replicas>& components,
+ const char* device, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrInt(op.get(), "N", num_replicas);
+ for (int i = 0; i < num_replicas; ++i) {
+ TFE_OpAddInput(op.get(), components[i], status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ }
+ TFE_OpSetDevice(op.get(), device, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+
+ TFE_TensorHandle* result_handle;
+ int num_retvals = 1;
+ TFE_Execute(op.get(), &result_handle, &num_retvals, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ return TensorHandlePtr(result_handle);
+}
+
+TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
+ TFE_TensorHandle* second, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpAddInput(op.get(), first, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpAddInput(op.get(), second, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ const char* first_device = TFE_TensorHandleDeviceName(first, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetDevice(op.get(), first_device, status);
+
+ TFE_TensorHandle* result_handle;
+ int num_retvals = 1;
+ TFE_Execute(op.get(), &result_handle, &num_retvals, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ return TensorHandlePtr(result_handle);
+}
+
+// Assert that `handle` is equal to `expected_value`.
+void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
+ TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ ASSERT_EQ(expected_value,
+ *static_cast<float*>(TF_TensorData(value_zero.get())));
+}
+
+// Create and modify a variable placed on a parallel device which composes
+// `first_device` and `second_device`.
+void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
+ const char* second_device) {
+ // Register the custom device
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+ std::array<const char*, 2> underlying_devices{first_device, second_device};
+ tensorflow::eager::RegisterParallelDevice(
+ context, device_name, underlying_devices.data(),
+ underlying_devices.size(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Create a variable handle (uninitialized to start) placed on the parallel
+ // device.
+ std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
+ to_delete->Destroy(context, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ delete to_delete;
+ };
+ std::unique_ptr<Variable, decltype(variable_deleter)> variable(
+ Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
+ status.get()),
+ variable_deleter);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Assign an initial value to the variable, implicitly mirroring it to each
+ // component device.
+ {
+ TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ variable->Assign(context, initial_value.get(), status.get());
+ }
+
+ // Read from the variable and verify that we have a parallel tensor.
+ {
+ TensorHandlePtr read = variable->Read(context, status.get());
+ std::array<TensorHandlePtr, 2> components;
+ ExtractPerDeviceValues(context, read.get(), &components, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ AssertScalarFloatEq(components[0].get(), 20.);
+ AssertScalarFloatEq(components[1].get(), 20.);
+
+ std::string first_device =
+ TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
+ ASSERT_EQ(underlying_devices[0], first_device);
+ std::string second_device =
+ TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
+ ASSERT_EQ(underlying_devices[1], second_device);
+ }
+
+ // Add a parallel tensor with different values on each device to the variable.
+ {
+ TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
+ TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
+ std::array<TFE_TensorHandle*, 2> components{value_one.get(),
+ value_two.get()};
+ TensorHandlePtr combined_value =
+ CreatePerDeviceValues(context, components, device_name, status.get());
+ variable->AssignAdd(context, combined_value.get(), status.get());
+ }
+
+ // Read the variable and verify that each component has the right modified
+ // value.
+ {
+ TensorHandlePtr read = variable->Read(context, status.get());
+ std::array<TensorHandlePtr, 2> components;
+ ExtractPerDeviceValues(context, read.get(), &components, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ AssertScalarFloatEq(components[0].get(), 23.);
+ AssertScalarFloatEq(components[1].get(), 18.);
+
+ std::string first_device =
+ TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
+ ASSERT_EQ(underlying_devices[0], first_device);
+ std::string second_device =
+ TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
+ ASSERT_EQ(underlying_devices[1], second_device);
+ }
+}
+
+TEST(PARALLEL_DEVICE, TestBasicCPU) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+ TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
+ TF_CreateConfig(
+ /*xla*/ false,
+ /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
+ 2),
+ TF_DeleteBuffer);
+ TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
+ status.get());
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+ TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ BasicTestsForTwoDevices(context.get(),
+ "/job:localhost/replica:0/task:0/device:CPU:0",
+ "/job:localhost/replica:0/task:0/device:CPU:1");
+}
+
+TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+ TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+ TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ BasicTestsForTwoDevices(context.get(),
+ "/job:localhost/replica:0/task:0/device:CPU:0",
+ "/job:localhost/replica:0/task:0/device:CPU:0");
+}
+
+TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+ TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+ TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Skip the test if no TPU is available.
+ std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
+ TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ bool has_tpu = false;
+ for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
+ ++device_index) {
+ std::string device_type =
+ TF_DeviceListType(devices.get(), device_index, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ if (device_type == "TPU") {
+ has_tpu = true;
+ break;
+ }
+ }
+ if (has_tpu) {
+ BasicTestsForTwoDevices(context.get(),
+ "/job:localhost/replica:0/task:0/device:TPU:0",
+ "/job:localhost/replica:0/task:0/device:TPU:0");
+ }
+}
+
+TEST(PARALLEL_DEVICE, TestExplicitCopies) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+ TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
+ TF_CreateConfig(
+ /*xla*/ false,
+ /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
+ 2),
+ TF_DeleteBuffer);
+ TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
+ status.get());
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+ TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+ std::vector<const char*> underlying_devices;
+ const char* first_device_name =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+ underlying_devices.push_back(first_device_name);
+ const char* second_device_name =
+ "/job:localhost/replica:0/task:0/device:CPU:1";
+ underlying_devices.push_back(second_device_name);
+ tensorflow::eager::RegisterParallelDevice(
+ context.get(), device_name, underlying_devices.data(),
+ underlying_devices.size(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Copying on to a parallel device is OK.
+ TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice(
+ cpu_value.get(), context.get(), device_name, status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ const char* backing_device =
+ TFE_TensorHandleBackingDeviceName(device_value.get(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ ASSERT_EQ(std::string(device_name), backing_device);
+
+ // Un-pack the parallel tensor to verify that the copy was successful.
+ {
+ std::array<TensorHandlePtr, 2> components;
+ ExtractPerDeviceValues(context.get(), device_value.get(), &components,
+ status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // The value of the original tensor is replicated on each device.
+ AssertScalarFloatEq(components[0].get(), 3.);
+ AssertScalarFloatEq(components[1].get(), 3.);
+
+ // Verify that the mirrors are placed on the component devices.
+ std::string first_device =
+ TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
+ ASSERT_EQ(underlying_devices[0], first_device);
+ std::string second_device =
+ TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
+ ASSERT_EQ(underlying_devices[1], second_device);
+ }
+
+ // Copies off of parallel devices must be explicit.
+ TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
+ device_value.get(), context.get(), first_device_name, status.get()));
+ ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
+}
+
+TEST(PARALLEL_DEVICE, TestDifferentShapes) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+ TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
+ TF_CreateConfig(
+ /*xla*/ false,
+ /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
+ 2),
+ TF_DeleteBuffer);
+ TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
+ status.get());
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+ TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+ std::vector<const char*> underlying_devices;
+ underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
+ underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
+ tensorflow::eager::RegisterParallelDevice(
+ context.get(), device_name, underlying_devices.data(),
+ underlying_devices.size(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Create two vectors with different lengths
+ std::vector<float> size_two_value{1., 2.};
+ std::vector<float> size_three_value{1., 2., 3.};
+ TensorHandlePtr size_two(
+ VectorFloatTensorHandle(size_two_value, status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TensorHandlePtr size_three(
+ VectorFloatTensorHandle(size_three_value, status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Try to combine these values into a single parallel tensor.
+ std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
+ TensorHandlePtr combined_value = CreatePerDeviceValues(
+ context.get(), components, device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED)
+ << TF_Message(status.get());
+}
+
+TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+ TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
+ TF_CreateConfig(
+ /*xla*/ false,
+ /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
+ 3),
+ TF_DeleteBuffer);
+ TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
+ status.get());
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+ TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Create a parallel device with two CPUs
+ const char* first_device_name =
+ "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+ std::vector<const char*> first_underlying_devices{
+ "/job:localhost/replica:0/task:0/device:CPU:0",
+ "/job:localhost/replica:0/task:0/device:CPU:1"};
+ tensorflow::eager::RegisterParallelDevice(
+ context.get(), first_device_name, first_underlying_devices.data(),
+ first_underlying_devices.size(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Create a second parallel device with the first parallel device and one
+ // additional CPU.
+ const char* second_device_name =
+ "/job:localhost/replica:0/task:0/device:CUSTOM:1";
+ std::vector<const char*> second_underlying_devices{
+ "/job:localhost/replica:0/task:0/device:CUSTOM:0",
+ "/job:localhost/replica:0/task:0/device:CPU:2"};
+ tensorflow::eager::RegisterParallelDevice(
+ context.get(), second_device_name, second_underlying_devices.data(),
+ second_underlying_devices.size(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Create a tensor on the first parallel device
+ TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
+ TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
+ TensorHandlePtr first_combined_value = CreatePerDeviceValues(
+ context.get(), components, first_device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Nest the first parallel tensor into a second
+ TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ components[0] = first_combined_value.get();
+ components[1] = value_three.get();
+ TensorHandlePtr second_combined_value = CreatePerDeviceValues(
+ context.get(), components, second_device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TensorHandlePtr negative_one(FloatTensorHandle(3., status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TensorHandlePtr multiply_result(Multiply(context.get(),
+ second_combined_value.get(),
+ negative_one.get(), status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Un-pack the parallel tensor to verify that the operation was
+ // successful. The resulting structure should be:
+ // second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
+ std::array<TensorHandlePtr, 2> second_components;
+ ExtractPerDeviceValues(context.get(), multiply_result.get(),
+ &second_components, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ AssertScalarFloatEq(second_components[1].get(), 9.);
+
+ // Verify that the mirrors are placed on the component devices.
+ std::string first_device = TFE_TensorHandleBackingDeviceName(
+ second_components[0].get(), status.get());
+ ASSERT_EQ(second_underlying_devices[0], first_device);
+ std::string second_device = TFE_TensorHandleBackingDeviceName(
+ second_components[1].get(), status.get());
+ ASSERT_EQ(second_underlying_devices[1], second_device);
+
+ // Un-pack the first parallel device's tensor too
+ std::array<TensorHandlePtr, 2> first_components;
+ ExtractPerDeviceValues(context.get(), second_components[0].get(),
+ &first_components, status.get());
+ AssertScalarFloatEq(first_components[0].get(), 3.);
+ AssertScalarFloatEq(first_components[1].get(), 6.);
+
+ first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
+ status.get());
+ ASSERT_EQ(first_underlying_devices[0], first_device);
+ second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
+ status.get());
+ ASSERT_EQ(first_underlying_devices[1], second_device);
+}
+
+TEST(PARALLEL_DEVICE, TestInvalidPacking) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+ TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+ TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+ const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+ std::vector<const char*> underlying_devices;
+ underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
+ tensorflow::eager::RegisterParallelDevice(
+ context.get(), device_name, underlying_devices.data(),
+ underlying_devices.size(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
+ TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
+ {
+ // Try to pack two TensorHandles onto a parallel device with a single
+ // component.
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ std::array<TFE_TensorHandle*, 2> components{value_one.get(),
+ value_two.get()};
+ TensorHandlePtr combined_value = CreatePerDeviceValues(
+ context.get(), components, device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
+ << TF_Message(status.get());
+ }
+
+ {
+ // Try to extract the wrong number of components from a parallel tensor
+ std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
+ TensorHandlePtr combined_value = CreatePerDeviceValues(
+ context.get(), correct_components, device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ std::array<TensorHandlePtr, 2> incorrect_components;
+ ExtractPerDeviceValues(context.get(), combined_value.get(),
+ &incorrect_components, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
+ << TF_Message(status.get());
+ }
+
+ {
+ // Try to pass a ParallelTensor to TPUReplicatedInput
+ std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
+ TensorHandlePtr combined_value = CreatePerDeviceValues(
+ context.get(), correct_components, device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ std::array<TFE_TensorHandle*, 1> incorrect_components{combined_value.get()};
+ TensorHandlePtr recombined_value = CreatePerDeviceValues(
+ context.get(), incorrect_components, device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
+ << TF_Message(status.get());
+ }
+
+ {
+ // Try to pass a non-parallel tensor to TPUReplicatedOutput
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
+ TFE_DeleteOp);
+ if (TF_GetCode(status.get()) != TF_OK) return;
+ TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
+ TFE_OpAddInput(op.get(), value_one.get(), status.get());
+ if (TF_GetCode(status.get()) != TF_OK) return;
+ TFE_OpSetDevice(op.get(), device_name, status.get());
+ if (TF_GetCode(status.get()) != TF_OK) return;
+
+ TFE_TensorHandle* result_handles;
+ int num_retvals = 1;
+ TFE_Execute(op.get(), &result_handles, &num_retvals, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
+ << TF_Message(status.get());
+ }
+}
+
+TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
+ int group_size, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+
+ const char* device = TFE_TensorHandleDeviceName(input, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetDevice(op.get(), device, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
+ TFE_OpSetAttrInt(op.get(), "group_size", group_size);
+ TFE_OpSetAttrInt(op.get(), "group_key", 0);
+ TFE_OpSetAttrInt(op.get(), "instance_key", 0);
+ const std::string merge_op("Add");
+ TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
+ merge_op.length());
+ const std::string final_op("Id");
+ TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
+ final_op.length());
+ TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);
+
+ TFE_OpAddInput(op.get(), input, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+
+ TFE_TensorHandle* result_handle;
+ int num_retvals = 1;
+ TFE_Execute(op.get(), &result_handle, &num_retvals, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ return TensorHandlePtr(result_handle);
+}
+
+TEST(PARALLEL_DEVICE, TestCollective) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+ TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
+ TF_CreateConfig(
+ /*xla*/ false,
+ /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
+ 2),
+ TF_DeleteBuffer);
+ TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
+ status.get());
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+ TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+ std::vector<const char*> underlying_devices;
+ underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
+ underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
+ tensorflow::eager::RegisterParallelDevice(
+ context.get(), device_name, underlying_devices.data(),
+ underlying_devices.size(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Create a tensor on the parallel device
+ TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
+ TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
+ TensorHandlePtr parallel_value = CreatePerDeviceValues(
+ context.get(), components, device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // Run a collective sum, so each component should now be the same.
+ TensorHandlePtr reduced(
+ CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ std::array<TensorHandlePtr, 2> result_components;
+ ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
+ status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ AssertScalarFloatEq(result_components[0].get(), 3.);
+ AssertScalarFloatEq(result_components[1].get(), 3.);
+}
+
+void RegisterCollectiveMulFunction(TFE_Context* context,
+ const char* function_name, int group_size,
+ TF_Status* status) {
+ std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
+ TF_DeleteGraph);
+ TF_OperationDescription* placeholder_desc =
+ TF_NewOperation(body.get(), "Placeholder", "Placeholder");
+ TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
+ TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ TF_Output x{placeholder_op, 0};
+
+ TF_OperationDescription* reduce_desc =
+ TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
+ TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
+ TF_SetAttrInt(reduce_desc, "group_size", group_size);
+ TF_SetAttrInt(reduce_desc, "group_key", 0);
+ TF_SetAttrInt(reduce_desc, "instance_key", 0);
+
+ const std::string merge_op("Mul");
+ TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
+ merge_op.length());
+ const std::string final_op("Id");
+ TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
+ final_op.length());
+ TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
+ TF_AddInput(reduce_desc, x);
+ TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ TF_Operation* operations[]{placeholder_op, reduce_op};
+ TF_Output y{reduce_op, 0};
+ const char* output_name = "y";
+ std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
+ TF_GraphToFunction(
+ /* fn_body */ body.get(), /* fn_name */ function_name,
+ /* append_hash_to_fn_name */ 0, /* num_opers */ 2,
+ /* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
+ /* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
+ /* opts */ nullptr, /* description */ "", /* status */ status),
+ TF_DeleteFunction);
+ if (TF_GetCode(status) != TF_OK) return;
+ TFE_ContextAddFunction(context, function.get(), status);
+}
+
+TEST(PARALLEL_DEVICE, TestFunction) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
+ TFE_NewContextOptions(), TFE_DeleteContextOptions);
+ std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
+ TF_CreateConfig(
+ /*xla*/ false,
+ /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
+ 2),
+ TF_DeleteBuffer);
+ TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
+ status.get());
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
+ TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+ std::vector<const char*> underlying_devices;
+ underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
+ underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
+ tensorflow::eager::RegisterParallelDevice(
+ context.get(), device_name, underlying_devices.data(),
+ underlying_devices.size(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ const char* function_name = "test_reduce_mul";
+ RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
+ TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
+ TensorHandlePtr parallel_value = CreatePerDeviceValues(
+ context.get(), components, device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_OpSetDevice(op.get(), device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TFE_TensorHandle* raw_result_handle;
+ int num_retvals = 1;
+ TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TensorHandlePtr reduced(raw_result_handle);
+
+ std::array<TensorHandlePtr, 2> result_components;
+ ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
+ status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
+ AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
+
+ std::string first_device = TFE_TensorHandleBackingDeviceName(
+ result_components[0].get(), status.get());
+ ASSERT_EQ(underlying_devices[0], first_device);
+ std::string second_device = TFE_TensorHandleBackingDeviceName(
+ result_components[1].get(), status.get());
+ ASSERT_EQ(underlying_devices[1], second_device);
+}
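The tests above lean on a `TensorHandlePtr` smart-pointer alias that is declared earlier in the test file and therefore does not appear in this hunk. As a sketch only, it is essentially a `std::unique_ptr` whose deleter calls `TFE_DeleteTensorHandle`; the exact definition in the patch may differ:

// Sketch: a RAII wrapper matching how the tests use TensorHandlePtr
// (construct from a raw handle, .get(), .reset()). Names are illustrative.
struct TensorHandleDeleter {
  void operator()(TFE_TensorHandle* handle) const {
    if (handle != nullptr) TFE_DeleteTensorHandle(handle);
  }
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

With a stateless deleter like this the wrapper stays the size of a raw pointer, which is why the helpers can cheaply return it by value.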
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 7f1590ff75d..fd4ae10595b 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -16,6 +16,12 @@ cc_library(
deps = ["//tensorflow/core:test_main"],
)
+filegroup(
+ name = "quantize_header",
+ srcs = ["quantize.h"],
+ visibility = ["//visibility:public"],
+)
+
cc_library(
name = "tfcompile_lib",
srcs = [
@@ -27,6 +33,7 @@ cc_library(
"codegen.h",
"compile.h",
"flags.h",
+ "quantize.h",
],
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
visibility = ["//tensorflow/python:__pkg__"],
@@ -37,7 +44,6 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
- "//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index f83cd45f9f3..a2cba5cdf9e 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
-#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
+#include "tensorflow/compiler/aot/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@@ -46,6 +46,14 @@ limitations under the License.
namespace tensorflow {
namespace tfcompile {
+static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;
+
+bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
+ if (*quantize_xla) return false;
+ *quantize_xla = fn;
+ return true;
+}
+
namespace {
// Compiles the XLA computation into executable code.
@@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
} else {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
- if (flags.experimental_quantize) {
- TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
+
+ if (flags.experimental_quantize && *quantize_xla) {
+ TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
}
+
if (!flags.out_session_module.empty()) {
 TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
computation.Snapshot());
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h b/tensorflow/compiler/aot/quantize.h
similarity index 55%
rename from tensorflow/compiler/mlir/lite/quantization/xla/quantize.h
rename to tensorflow/compiler/aot/quantize.h
index 2ec5dbb02ce..add05bd0422 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h
+++ b/tensorflow/compiler/aot/quantize.h
@@ -13,21 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
-#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
+#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
+#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
+
+#include <functional>
+#include <memory>
+#include <vector>
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
-namespace mlir {
-namespace xla_hlo {
+namespace tensorflow {
+namespace tfcompile {
-// Quantizes the model in the computation.
-tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
- xla::XlaComputation* computation);
+using QuantizeXlaFn = std::function<Status(const tf2xla::Config& config,
+ xla::XlaComputation* computation)>;
-} // namespace xla_hlo
-} // namespace mlir
+// Set the static quantization function to the `fn` if it hasn't been set.
+// Return false if the static function has been set.
+bool RegisterQuantizeFn(const QuantizeXlaFn& fn);
-#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
+} // namespace tfcompile
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
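To illustrate how the new hook is meant to be consumed (this is a hypothetical backend, not part of the patch): a quantization library links against tfcompile, installs its rewrite through RegisterQuantizeFn at static-initialization time, and CompileGraph then invokes it whenever the experimental_quantize flag is set.

// Hypothetical registration translation unit for a quantization backend.
// MyXlaQuantize and kRegistered are illustrative names, not part of the patch.
#include "tensorflow/compiler/aot/quantize.h"

namespace tensorflow {
namespace tfcompile {
namespace {

Status MyXlaQuantize(const tf2xla::Config& config,
                     xla::XlaComputation* computation) {
  // Rewrite `computation` in place; return a non-OK status to abort tfcompile.
  return Status::OK();
}

// RegisterQuantizeFn returns false if another function was registered first.
const bool kRegistered = RegisterQuantizeFn(MyXlaQuantize);

}  // namespace
}  // namespace tfcompile
}  // namespace tensorflow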
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 5081df28a08..b51749bc332 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -296,10 +296,10 @@ Status XlaCompilationCache::CompileSingleOp(
 arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
}
GraphDebugInfo debug_info;
- return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
- compile_options.use_tuple_arg,
- *options.flib_def, debug_info,
- options.shape_representation_fn, result);
+ return CompileGraphToXlaHlo(
+ *graph, {arg_shapes.data(), arg_shapes.size()},
+ options.device_type.type_string(), compile_options.use_tuple_arg,
+ *options.flib_def, debug_info, options.shape_representation_fn, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 63546db1eb0..046ae607438 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -58,6 +58,11 @@ cc_library(
"//tensorflow/python:__subpackages__",
],
deps = [
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:QuantOps",
+ # Link jit lib to link JIT devices required to run
+ # xla-legalize-tf-with-tf2xla pass.
+ "//tensorflow/compiler/jit",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
@@ -65,7 +70,6 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
- "//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
@@ -90,8 +94,6 @@ cc_library(
"//tensorflow/compiler/mlir/xla:xla_lower",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
"//tensorflow/compiler/mlir/xla:xla_test_passes",
- "@llvm-project//mlir:Affine",
- "@llvm-project//mlir:QuantOps",
],
)
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 789d06b8ac9..c4042aad12e 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -307,7 +307,7 @@ cc_library(
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
- "transforms/runtime_type_verify.cc",
+ "transforms/runtime_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/while_loop_outline.cc",
@@ -537,6 +537,15 @@ tf_native_cc_binary(
],
)
+tf_native_cc_binary(
+ name = "json_to_flatbuffer",
+ srcs = ["json_to_flatbuffer.cc"],
+ deps = [
+ "//tensorflow/lite/schema:schema_fbs",
+ "@flatbuffers",
+ ],
+)
+
cc_library(
name = "emit_error_reporter",
srcs = [
diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
index 82d058964cb..2ed63fcc794 100644
--- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
+++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -36,7 +36,8 @@ struct PassConfig {
form_clusters(false),
unfold_batch_matmul(true),
legalize_tf_while(true),
- shape_inference(true) {}
+ shape_inference(true),
+ runtime_verification(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@@ -65,6 +66,8 @@ struct PassConfig {
bool legalize_tf_while;
// Whether to do shape inference.
bool shape_inference;
+ // Whether to do TFLite runtime verification.
+ bool runtime_verification;
};
} // namespace TFL
diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc
index bc894d36e75..f1ed97cb8e7 100644
--- a/tensorflow/compiler/mlir/lite/converter_gen.cc
+++ b/tensorflow/compiler/mlir/lite/converter_gen.cc
@@ -441,7 +441,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
- << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
+ << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
"failure_on_operand_type_mismatch) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
@@ -466,6 +466,25 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
"operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result");
+
+ for (auto &trait : op.getTraits()) {
+ if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
+ continue;
+ }
+ if (trait.getDef().getValueAsString("trait") !=
+ "OpTrait::TFLRuntimeOpTrait") {
+ continue;
+ }
+
+ auto *val = trait.getDef().getValue("tflRuntimePredicate");
+ if (!val) continue;
+
+ mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
+ os << tgfmt(
+ " if (!($0)) {\n "
+ " return ::mlir::LogicalResult::Failure;\n }\n",
+ &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
+ }
os << " return top.verify();\n}\n";
}
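For context, the extra loop over GenInternalOpTrait records means that an op carrying a TFL runtime predicate (for example TFL_AddOp with TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5> in the tfl_ops.td hunk below) gets that predicate inlined into its generated verifier. Roughly, the emitted method reads as follows; the exact spacing and operand accessors come from tblgen, so treat this as a sketch:

// Approximate shape of the generated code for TFL::AddOp.
::mlir::LogicalResult AddOp::VerifyTflRuntimeConstraints(
    ::mlir::Operation *op, bool failure_on_operand_type_mismatch) {
  auto top = cast<AddOp>(op); (void)top;
  // ... per-operand and per-result element-type checks elided ...
  if (!(TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape(
          top.getOperand(0).getType(), top.getOperand(1).getType(), 5))) {
    return ::mlir::LogicalResult::Failure;
  }
  return top.verify();
}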
diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h
index 4c9a51c0351..ecc90106ced 100644
--- a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h
+++ b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h
@@ -16,6 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
+// tfl.abs
+template <>
+class TFLiteCostEstimator<AbsOp, hardware::GPU> {
+ public:
+ static double GetCost(mlir::Operation* op) {
+ llvm::errs() << "No defined cost function for op: "
+ << op->getName().getStringRef().str();
+ return 0.0;
+ }
+
+ static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
// tfl.add
template <>
 class TFLiteCostEstimator<AddOp, hardware::GPU> {
@@ -149,6 +162,19 @@ class TFLiteCostEstimator {
static bool IsSupported(mlir::Operation* op) { return true; }
};
+// tfl.log
+template <>
+class TFLiteCostEstimator<LogOp, hardware::GPU> {
+ public:
+ static double GetCost(mlir::Operation* op) {
+ llvm::errs() << "No defined cost function for op: "
+ << op->getName().getStringRef().str();
+ return 0.0;
+ }
+
+ static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
// tfl.logistic
template <>
 class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
@@ -240,6 +266,32 @@ class TFLiteCostEstimator {
static bool IsSupported(mlir::Operation* op) { return true; }
};
+// tfl.pow
+template <>
+class TFLiteCostEstimator<PowOp, hardware::GPU> {
+ public:
+ static double GetCost(mlir::Operation* op) {
+ llvm::errs() << "No defined cost function for op: "
+ << op->getName().getStringRef().str();
+ return 0.0;
+ }
+
+ static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
+// tfl.prelu
+template <>
+class TFLiteCostEstimator<PReluOp, hardware::GPU> {
+ public:
+ static double GetCost(mlir::Operation* op) {
+ llvm::errs() << "No defined cost function for op: "
+ << op->getName().getStringRef().str();
+ return 0.0;
+ }
+
+ static bool IsSupported(mlir::Operation* op) { return true; }
+};
+
// tfl.relu
template <>
 class TFLiteCostEstimator<ReluOp, hardware::GPU> {
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
index b20e81aefa9..ccad3cbb79e 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
@@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
- "LogicalResult", "VerifyTflRuntimeTypes",
+ "LogicalResult", "VerifyTflRuntimeConstraints",
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
>,
];
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 45efe8f72f7..47a7b32d7e3 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -46,6 +46,30 @@ namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {
+// Returns true when the given two types have the same shape or broadcastable
+// shape within the given rank. If any given shapes are non-static, this method
+// returns true.
+bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
+ int max_bcast_rank) {
+ // Ignore shape checking on the non-static shapes for model compatibility.
+ auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
+ if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
+ auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
+ if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
+
+ if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
+ return true;
+
+ SmallVector<int64_t, 4> result_shape;
+ if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
+ rhs_shaped_type.getShape(),
+ result_shape)) {
+ return false;
+ }
+ return lhs_shaped_type.getRank() <= max_bcast_rank &&
+ rhs_shaped_type.getRank() <= max_bcast_rank;
+}
+
//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//
@@ -316,7 +340,7 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
const int num_elements = result_shape_type.getNumElements();
new_values.reserve(num_elements);
- for (APFloat old_value : dense_elements.getValues<APFloat>()) {
+ for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
new_values.push_back(calculate(old_value));
}
@@ -844,7 +868,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) {
if (!shape_elements) return nullptr;
 SmallVector<int64_t, 4> shape_data;
- for (auto it : shape_elements.getValues<APInt>()) {
+ for (const auto &it : shape_elements.getValues<APInt>()) {
shape_data.push_back(it.getSExtValue());
}
result_type =
@@ -1798,7 +1822,7 @@ static LogicalResult Verify(TransposeOp op) {
int index = 0;
 llvm::SmallVector<int64_t, 4> axes;
- for (auto axis_int : perm.getValues<APInt>()) {
+ for (const auto &axis_int : perm.getValues<APInt>()) {
const int64_t axis = axis_int.getSExtValue();
if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
return op.emitOpError(
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 2bf6ca2ab89..519bd9dbfc0 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -106,6 +106,22 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
 class DerivedTFLiteTypeAttr<code body> :
DerivedAttr<"tflite::TensorType", body>;
+// TFL Runtime op trait predicate.
+class TFL_RuntimePredOpTrait<string desc, Pred pred> :
+ GenInternalOpTrait<"TFLRuntimeOpTrait"> {
+ Pred tflRuntimePredicate = pred;
+ string tflRuntimeDescription = desc;
+}
+
+class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
+ int i, int j, int max_bcast_rank> :
+ TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
+ " have the same shape or broadcastable shapes within the rank " #
+ max_bcast_rank,
+ CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
+ "$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
+ ").getType(), " # max_bcast_rank # ")">>;
+
// These additional types/type constraints here are used to decouple the ops
// from runtime support for the ops. Prefer to use these types when defining
// new TF_Ops for uniformity.
@@ -344,7 +360,10 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
// TFL op definitions.
//===----------------------------------------------------------------------===//
def TFL_AbsOp : TFL_Op<"abs", [
- NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
+ NoSideEffect,
+ SameOperandsAndResultType,
+ NoQuantizableResult,
+ TFL_GpuTargetOp]> {
let summary = "Absolute value operator";
let description = [{
@@ -360,10 +379,9 @@ an output element, this operation computes \\(y = |x|\\).
let hasFolder = 1;
}
-def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
- NoSideEffect,
- Commutative,
- TFL_GpuTargetOp]> {
+def TFL_AddOp : TFL_Op<"add", [
+ TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+ ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
let summary = "Addition operator";
let description = [{
@@ -371,11 +389,11 @@ def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
}];
let arguments = (
- ins AnyTensor:$lhs,
- AnyTensor:$rhs,
+ ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
+ TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
TFL_AFAttr:$fused_activation_function);
- let results = (outs AnyTensor:$output);
+ let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
let hasFolder = 1;
@@ -1527,7 +1545,10 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
}
def TFL_LogOp: TFL_Op<"log", [
- NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
+ NoSideEffect,
+ SameOperandsAndResultType,
+ NoQuantizableResult,
+ TFL_GpuTargetOp]> {
let summary = "Natural logarithm operator";
let description = [{
@@ -2072,7 +2093,10 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
let hasOptions = 1;
}
-def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
+def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
+ NoSideEffect,
+ NoQuantizableResult,
+ TFL_GpuTargetOp]> {
let summary = "Power operator";
let description = [{
@@ -2092,7 +2116,7 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuanti
let builders = [TFL_BroadcastableBinaryBuilder];
}
-def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect]> {
+def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Parameterized Relu operator";
let description = [{
diff --git a/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc b/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc
new file mode 100644
index 00000000000..4a4e7a65cd6
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stddef.h>
+
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "flatbuffers/idl.h" // from @flatbuffers
+#include "flatbuffers/util.h" // from @flatbuffers
+#include "tensorflow/lite/schema/schema_generated.h"
+
+int main(int argc, char** argv) {
+ // load FlatBuffer schema (.fbs) and JSON from disk
+ if (argc < 3) {
+ std::cerr << "Missing input argument. Usage:\n"
+ << argv[0] << " <fbs schema file> <json file>\n\n";
+ return 1;
+ }
+ const char* schema_path = argv[1];
+ const char* json_path = argv[2];
+ std::string schema;
+ std::string json;
+
+ const bool status =
+ flatbuffers::LoadFile(schema_path, /*binary=*/false, &schema) &&
+ flatbuffers::LoadFile(json_path, /*binary=*/false, &json);
+ if (!status) {
+ std::cerr << "couldn't load files!\n";
+ return 1;
+ }
+
+ // parse schema first, so we can use it to parse the data after
+ flatbuffers::Parser parser;
+ const bool schema_parse_result =
+ parser.Parse(schema.c_str()) && parser.Parse(json.c_str());
+ if (!schema_parse_result) {
+ std::cerr << "Parse error.\n";
+ return 1;
+ }
+ const size_t length = parser.builder_.GetSize();
+ const size_t n =
+ std::fwrite(parser.builder_.GetBufferPointer(), 1, length, stdout);
+ if (n != length) {
+ std::cerr << "print to stdout filed.\n";
+ return 1;
+ }
+ return 0;
+}
diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index 1165561cb71..a07b7b8dd1d 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -88,7 +88,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
- pass_config.shape_inference = false;
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
pass_config, result);
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index 681773a7e6b..c338b723a4a 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -16,9 +16,12 @@ limitations under the License.
#include
+#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
+#include "mlir/IR/StandardTypes.h" // from @llvm-project
+#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
@@ -41,6 +44,77 @@ limitations under the License.
namespace tensorflow {
+Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
+ mlir::OwningModuleRef* module) {
+ mlir::FuncOp entry_function = nullptr;
+ for (auto func : module->get().getOps<mlir::FuncOp>()) {
+ if (auto tf_attrs =
+ func.getAttrOfType("tf.entry_function")) {
+ // TODO(jaesung): There could be multiple entry functions. Let's handle
+ // such cases if there are any needs for that.
+ if (entry_function != nullptr) {
+ return errors::InvalidArgument(
+ "There should be only one tf.entry_function");
+ }
+ entry_function = func;
+ }
+ }
+ if (entry_function == nullptr) {
+ return errors::InvalidArgument("no tf.entry_function found");
+ }
+
+ // Get the list of input Op names from the function attribute.
+ mlir::DictionaryAttr tf_attrs =
+ entry_function.getAttrOfType("tf.entry_function");
+ llvm::SmallVector<llvm::StringRef, 4> function_input_names;
+ function_input_names.reserve(model_flags.input_arrays().size());
+ auto input_attr = tf_attrs.get("inputs");
+ if (!input_attr) {
+ return errors::InvalidArgument("no inputs attribute found");
+ }
+ auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
+ input_names.split(function_input_names, ",");
+ if (function_input_names.size() != model_flags.input_arrays().size()) {
+ return errors::InvalidArgument(
+ "input array size mismatch: got ", function_input_names.size(),
+ ", expected: ", model_flags.input_arrays().size());
+ }
+ llvm::StringSet<> function_input_names_set;
+ function_input_names_set.insert(function_input_names.begin(),
+ function_input_names.end());
+ for (const auto& input_array : model_flags.input_arrays()) {
+ if (function_input_names_set.count(input_array.name()) == 0) {
+ return errors::InvalidArgument("input array name (", input_array.name(),
+ ") does not exist in the given graph");
+ }
+ }
+
+ // Get the list of output Op names from the function attribute.
+ llvm::SmallVector<llvm::StringRef, 4> function_output_names;
+ function_output_names.reserve(model_flags.output_arrays().size());
+ auto output_attr = tf_attrs.get("outputs");
+ if (!output_attr) {
+ return errors::InvalidArgument("no outputs attribute found");
+ }
+ auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
+ output_names.split(function_output_names, ",");
+ if (function_output_names.size() != model_flags.output_arrays().size()) {
+ return errors::InvalidArgument(
+ "output array size mismatch: got ", function_output_names.size(),
+ ", expected: ", model_flags.output_arrays().size());
+ }
+ llvm::StringSet<> function_output_names_set;
+ function_output_names_set.insert(function_output_names.begin(),
+ function_output_names.end());
+ for (const auto& output_array : model_flags.output_arrays()) {
+ if (function_output_names_set.count(output_array) == 0) {
+ return errors::InvalidArgument("output array name (", output_array,
+ ") does not exist in the given graph");
+ }
+ }
+ return Status::OK();
+}
+
Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
string* result) {
@@ -77,11 +151,15 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
model_flags.saved_model_version(), tags,
exported_names, &context));
+ if (!model_flags.input_arrays().empty() ||
+ !model_flags.output_arrays().empty()) {
+ TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
+ }
+
mlir::TFL::PassConfig pass_config(quant_specs);
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
- pass_config.shape_inference = true;
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, result);
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index a17cdda2a39..6dd44e666fb 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
- pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+ pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD
index a75135cf3b5..a63a1e4b1e5 100644
--- a/tensorflow/compiler/mlir/lite/quantization/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/BUILD
@@ -14,7 +14,10 @@ package(
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
- packages = ["//tensorflow/compiler/mlir/..."],
+ packages = [
+ "//learning/brain/experimental/mlir/quantization/...",
+ "//tensorflow/compiler/mlir/...",
+ ],
)
exports_files([
diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
index 5a5012173e2..9d5aa167ff4 100644
--- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
@@ -55,7 +55,8 @@ namespace quant {
using QuantParamsEntry = QuantizationInfo::QuantParams;
namespace {
-class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
+class ImportQuantStatsPass
+ : public PassWrapper<ImportQuantStatsPass, FunctionPass> {
public:
explicit ImportQuantStatsPass(OperationToName op_to_name)
: op_to_name_(op_to_name) {}
@@ -193,7 +194,7 @@ void ImportQuantStatsPass::runOnFunction() {
}
// Creates an instance of the default quant parameters pass.
-std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
+std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
OperationToName op_to_name, const std::string &stats_str) {
auto pass = absl::make_unique(op_to_name);
if (pass->ParseQuantStats(stats_str)) return nullptr;
@@ -203,7 +204,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
auto get_name_func = [](Operation *op) {
Location loc = op->getLoc();
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h b/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h
index 2aa5f8e2d0d..bf034d49d4a 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h
@@ -27,13 +27,13 @@ using OperationToName = std::function<llvm::StringRef(Operation *op)>;
// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
-std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
+std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
OperationToName op_to_name, const std::string& stats_str);
// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
-std::unique_ptr<OpPassBase<FuncOp>>
+std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str);
} // namespace quant
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
index 0bd914aa2e7..3d50f280d0f 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
@@ -79,7 +79,7 @@ TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
SmallVector<double, 4> new_scales;
new_scales.reserve(scales.size());
auto scales_iter = scales.begin();
- for (auto f : factor_values) {
+ for (const auto& f : factor_values) {
new_scales.push_back(*(scales_iter++) *
std::fabs(FloatAttr::getValueAsDouble(f)));
}
diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h
index 178daf1b1e0..1ee85a3f4eb 100644
--- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h
+++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h
@@ -25,7 +25,7 @@ namespace mlir {
namespace TF {
// Legalize the tf ops to the quant ops, so the quantization passes can work.
-std::unique_ptr