Merge branch 'master' into tanhsigmoid_16x8
commit 0f047cd174

@@ -2,6 +2,10 @@
<img src="https://www.tensorflow.org/images/tf_logo_social.png">
</div>

[](https://badge.fury.io/py/tensorflow)
[](https://badge.fury.io/py/tensorflow)

**`Documentation`** |
------------------- |
[](https://www.tensorflow.org/api_docs/) |
@@ -223,6 +223,7 @@ tf_cuda_cc_test(
        ":c_api_test_util",
        "//tensorflow/c:c_test_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
@@ -371,6 +372,22 @@ tf_cuda_cc_test(
    ],
)

cc_library(
    name = "custom_device_testutil",
    testonly = True,
    srcs = ["custom_device_testutil.cc"],
    hdrs = ["custom_device_testutil.h"],
    visibility = ["//tensorflow:internal"],
    deps = [
        ":c_api",
        ":c_api_experimental",
        ":c_api_test_util",
        "//tensorflow/c:c_api",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
    ],
)

tf_cc_test(
    name = "custom_device_test",
    size = "small",
@@ -381,6 +398,7 @@ tf_cc_test(
        ":c_api",
        ":c_api_experimental",
        ":c_api_test_util",
        ":custom_device_testutil",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/cc/profiler",
@@ -448,6 +466,7 @@ filegroup(
        ],
        exclude = [
            "c_api_experimental.cc",
            "*c_api_tfrt*",
            "*test*",
            "*dlpack*",
        ],
@@ -21,6 +21,10 @@ limitations under the License.
#include <string>
#include <vector>

// clang-format off
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
@@ -31,12 +35,14 @@ limitations under the License.
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
#ifdef PLATFORM_GOOGLE
#include "tensorflow/c/eager/c_api_tfrt.h"
#endif
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
@@ -676,6 +682,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
    status->status = tensorflow::Status::OK();
    return new TFE_Context{new tfrt::ContextInterface()};
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;
#endif
  }
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
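For context, the sketch below shows how a caller could exercise this new TFRT branch through the C API. It is a minimal, illustrative sketch: it assumes TFE_ContextOptionsSetTfrt (used later in this commit's tests) is available, and the fall-back-to-default-runtime behavior is a choice of this example, not something the commit prescribes.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"

// Hypothetical helper: request a TFRT-backed context and, when the build
// reports Unimplemented (the #else branch above), retry with the default
// runtime instead.
TFE_Context* CreateContextMaybeTfrt(bool use_tfrt, TF_Status* status) {
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (ctx == nullptr && TF_GetCode(status) == TF_UNIMPLEMENTED) {
    TFE_ContextOptions* default_opts = TFE_NewContextOptions();
    ctx = TFE_NewContext(default_opts, status);
    TFE_DeleteContextOptions(default_opts);
  }
  return ctx;
}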
@@ -1669,6 +1684,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
};
}  // namespace

extern "C" {

void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                              const char* device_name, void* device_info,
                              TF_Status* status) {
@@ -1679,3 +1696,5 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
  status->status =
      context->RegisterCustomDevice(device_name, std::move(custom_device));
}

}  // extern "C"
@@ -515,9 +515,11 @@ typedef struct TFE_CustomDevice {
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                              const char* device_name, void* device_info,
                              TF_Status* status);
TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
                                                    TFE_CustomDevice device,
                                                    const char* device_name,
                                                    void* device_info,
                                                    TF_Status* status);

TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
                                                     const char* function_name,
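As a usage sketch of the newly exported registration entry point: the callback signatures below mirror the logging-device helpers elsewhere in this commit, but the stub bodies, names, and device string are hypothetical placeholders, not part of the commit.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"

// Trivial stand-in callbacks with the signatures TFE_CustomDevice expects.
TFE_TensorHandle* StubCopyToDevice(TFE_Context* context, TFE_TensorHandle* tensor,
                                   TF_Status* status, void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED, "Copy onto stub device not implemented.");
  return nullptr;
}
TFE_TensorHandle* StubCopyFromDevice(TFE_Context* context, TFE_TensorHandle* tensor,
                                     const char* target_device_name,
                                     TF_Status* status, void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED, "Copy off of stub device not implemented.");
  return nullptr;
}
void StubExecute(TFE_Context* context, int num_inputs, TFE_TensorHandle** inputs,
                 const char* operation_name, const TFE_OpAttrs* attributes,
                 int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
                 void* device_info) {
  TF_SetStatus(s, TF_UNIMPLEMENTED, "Execution on stub device not implemented.");
}
void StubDeleteDevice(void* device_info) {}

// Wire the callbacks into the struct and register it under a CUSTOM device name.
void RegisterStubDevice(TFE_Context* ctx, TF_Status* status) {
  TFE_CustomDevice device;
  device.copy_tensor_to_device = &StubCopyToDevice;
  device.copy_tensor_from_device = &StubCopyFromDevice;
  device.delete_device = &StubDeleteDevice;
  device.execute = &StubExecute;
  TFE_RegisterCustomDevice(ctx, device,
                           "/job:localhost/replica:0/task:0/device:CUSTOM:0",
                           /*device_info=*/nullptr, status);
}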
@@ -19,6 +19,10 @@ limitations under the License.

#include <string>

// clang-format off
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -584,9 +588,10 @@ TEST(CAPI, TensorHandleDevices) {
  TFE_DeleteContext(ctx);
}

void ExecuteAdd(bool async, bool forward_input) {
void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, tfrt);
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -651,7 +656,6 @@ void ExecuteAdd(bool async, bool forward_input) {
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteTensorHandle(retval);
  TFE_DeleteContext(ctx);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  float result[100 * 100] = {0};
@@ -661,12 +665,42 @@ void ExecuteAdd(bool async, bool forward_input) {
  for (int i = 0; i < 100 * 100; ++i) {
    EXPECT_EQ(2.0f, result[i]);
  }
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
TEST(CAPI, ExecuteAdd) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input*/ false,
      /*tfrt*/ false);
}
TEST(CAPI, ExecuteAddAsync) {
  ExecuteAdd(
      /*async=*/true,
      /*forward_input*/ false,
      /*tfrt*/ false);
}
TEST(CAPI, ExecuteAddForward) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input*/ true,
      /*tfrt*/ false);
}
TEST(CAPI, ExecuteAddForwardAsync) {
  ExecuteAdd(
      /*async=*/true,
      /*forward_input*/ true,
      /*tfrt*/ false);
}
#ifdef PLATFORM_GOOGLE
// TODO(b/153349425): Add add forwarding tests for TFRT
TEST(CAPI, ExecuteAddTfrt) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input*/ false,
      /*tfrt*/ true);
}
#endif

void Execute_MatMul_CPU(bool async) {
  TF_Status* status = TF_NewStatus();
@@ -20,129 +20,11 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"

namespace {

struct LoggingDevice {
  tensorflow::string device_name;
  tensorflow::string underlying_device;
  // Set to true whenever a TensorHandle is copied onto the device
  bool* arrived_flag;
  // Set to true whenever an operation is executed
  bool* executed_flag;
};

struct LoggedTensor {
  TFE_TensorHandle* tensor;
  LoggedTensor() = delete;
  explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
  ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};

void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
  delete reinterpret_cast<LoggedTensor*>(data);
}

TFE_TensorHandle* MakeLoggedTensorHandle(
    TFE_Context* context, const tensorflow::string& logging_device_name,
    std::unique_ptr<LoggedTensor> t, TF_Status* status) {
  std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  auto dtype = TFE_TensorHandleDataType(t->tensor);
  return TFE_NewTensorHandleFromDeviceMemory(
      context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
      t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}

TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
                                      TFE_TensorHandle* tensor,
                                      TF_Status* status, void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
      tensor, context, dev->underlying_device.c_str(), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  auto dst = std::make_unique<LoggedTensor>(t);
  *(dev->arrived_flag) = true;
  return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
                                status);
}

TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
                                              TFE_TensorHandle* tensor,
                                              const char* target_device_name,
                                              TF_Status* status,
                                              void* device_info) {
  TF_SetStatus(status, TF_INTERNAL,
               "Trying to copy a tensor out of a logging device.");
  return nullptr;
}

void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
                          TFE_TensorHandle** inputs, const char* operation_name,
                          const TFE_OpAttrs* attributes, int* num_outputs,
                          TFE_TensorHandle** outputs, TF_Status* s,
                          void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_Op* op(TFE_NewOp(context, operation_name, s));
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpAddAttrs(op, attributes);
  TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
  for (int j = 0; j < num_inputs; ++j) {
    TFE_TensorHandle* input = inputs[j];
    const char* input_device = TFE_TensorHandleDeviceName(input, s);
    if (TF_GetCode(s) != TF_OK) return;
    if (dev->device_name == input_device) {
      LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(input, s));
      if (TF_GetCode(s) != TF_OK) return;
      TFE_OpAddInput(op, t->tensor, s);
    } else {
      TFE_OpAddInput(op, input, s);
    }
    if (TF_GetCode(s) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
  TFE_Execute(op, op_outputs.data(), num_outputs, s);
  TFE_DeleteOp(op);
  if (TF_GetCode(s) != TF_OK) return;
  std::vector<TFE_TensorHandle*> unwrapped_outputs;
  for (auto* handle : op_outputs) {
    unwrapped_outputs.push_back(handle);
  }
  for (int i = 0; i < *num_outputs; ++i) {
    auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
    outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
                                        std::move(logged_tensor), s);
  }
  *(dev->executed_flag) = true;
}

void DeleteLoggingDevice(void* device_info) {
  delete reinterpret_cast<LoggingDevice*>(device_info);
}

void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag, bool* executed_flag,
                           TF_Status* status) {
  TFE_CustomDevice custom_device;
  custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
  custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
  custom_device.delete_device = &DeleteLoggingDevice;
  custom_device.execute = &LoggingDeviceExecute;
  LoggingDevice* device = new LoggingDevice;
  device->arrived_flag = arrived_flag;
  device->executed_flag = executed_flag;
  device->device_name = name;
  device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}

TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
      tensorflow::string(
          TFE_TensorHandleBackingDeviceName(var_value, status.get())));
  TFE_TensorHandle* var_value_unpacked =
      reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(var_value, status.get()))
          ->tensor;
      UnpackTensorHandle(var_value, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
      TFE_TensorHandleResolve(var_value_unpacked, status.get()),
@@ -394,5 +274,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
      << TF_Message(status.get());
}

}  // namespace
tensorflow/c/eager/custom_device_testutil.cc (new file, 172 lines)
@ -0,0 +1,172 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// A simple logging device to test custom device registration.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct LoggingDevice {
|
||||
tensorflow::string device_name;
|
||||
tensorflow::string underlying_device;
|
||||
// Set to true whenever a TensorHandle is copied onto the device
|
||||
bool* arrived_flag;
|
||||
// Set to true whenever an operation is executed
|
||||
bool* executed_flag;
|
||||
};
|
||||
|
||||
struct LoggedTensor {
|
||||
TFE_TensorHandle* tensor;
|
||||
LoggedTensor() = delete;
|
||||
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
|
||||
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
|
||||
};
|
||||
|
||||
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
delete reinterpret_cast<LoggedTensor*>(data);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||
TFE_Context* context, const tensorflow::string& logging_device_name,
|
||||
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
|
||||
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
auto dtype = TFE_TensorHandleDataType(t->tensor);
|
||||
return TFE_NewTensorHandleFromDeviceMemory(
|
||||
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
|
||||
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status, void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||
tensor, context, dev->underlying_device.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
auto dst = std::make_unique<LoggedTensor>(t);
|
||||
*(dev->arrived_flag) = true;
|
||||
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
|
||||
status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Trying to copy a tensor out of a logging device.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_Op* op(TFE_NewOp(context, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddAttrs(op, attributes);
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
TFE_TensorHandle* input = inputs[j];
|
||||
const char* input_device = TFE_TensorHandleDeviceName(input, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
if (dev->device_name == input_device) {
|
||||
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(input, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddInput(op, t->tensor, s);
|
||||
} else {
|
||||
TFE_OpAddInput(op, input, s);
|
||||
}
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
|
||||
TFE_Execute(op, op_outputs.data(), num_outputs, s);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
std::vector<TFE_TensorHandle*> unwrapped_outputs;
|
||||
for (auto* handle : op_outputs) {
|
||||
unwrapped_outputs.push_back(handle);
|
||||
}
|
||||
for (int i = 0; i < *num_outputs; ++i) {
|
||||
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
|
||||
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
|
||||
std::move(logged_tensor), s);
|
||||
}
|
||||
*(dev->executed_flag) = true;
|
||||
}
|
||||
|
||||
void DeleteLoggingDevice(void* device_info) {
|
||||
delete reinterpret_cast<LoggingDevice*>(device_info);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
custom_device.delete_device = &DeleteLoggingDevice;
|
||||
custom_device.execute = &LoggingDeviceExecute;
|
||||
LoggingDevice* device = new LoggingDevice;
|
||||
device->arrived_flag = arrived_flag;
|
||||
device->executed_flag = executed_flag;
|
||||
device->device_name = name;
|
||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
|
||||
TF_Status* status) {
|
||||
return reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
|
||||
->tensor;
|
||||
}
|
||||
|
||||
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
|
||||
bool* executed_flag, TFE_CustomDevice** device,
|
||||
void** device_info) {
|
||||
TFE_CustomDevice* custom_device = new TFE_CustomDevice;
|
||||
custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
custom_device->delete_device = &DeleteLoggingDevice;
|
||||
custom_device->execute = &LoggingDeviceExecute;
|
||||
*device = custom_device;
|
||||
LoggingDevice* logging_device = new LoggingDevice;
|
||||
logging_device->arrived_flag = arrived_flag;
|
||||
logging_device->executed_flag = executed_flag;
|
||||
logging_device->device_name = name;
|
||||
logging_device->underlying_device =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
*device_info = reinterpret_cast<void*>(logging_device);
|
||||
}
|
tensorflow/c/eager/custom_device_testutil.h (new file, 36 lines)
@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_

// A simple logging device to test custom device registration.
#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"

void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag, bool* executed_flag,
                           TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
                           bool* executed_flag, TFE_CustomDevice** device,
                           void** device_info);
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
                                     TF_Status* status);

#endif  // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
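A rough sketch of how a test might drive these helpers, following the pattern used by custom_device_test.cc in this commit; the test name and the absence of any op execution are choices of this example, not part of the commit.

#include <memory>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/test.h"

TEST(CUSTOM_DEVICE, RegistersLoggingDeviceSketch) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* context = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  bool arrived = false;   // flipped by the logging device on copy-to-device
  bool executed = false;  // flipped by the logging device on op execution
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_DeleteContext(context);
}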
tensorflow/c/eager/parallel_device/BUILD (new file, 38 lines)
@@ -0,0 +1,38 @@
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
)

package(
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "parallel_device",
    srcs = ["parallel_device.cc"],
    hdrs = ["parallel_device.h"],
    deps = [
        "//tensorflow/c:c_api",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:variant",
    ],
)

tf_cc_test(
    name = "parallel_device_test",
    srcs = ["parallel_device_test.cc"],
    deps = [
        ":parallel_device",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
tensorflow/c/eager/parallel_device/parallel_device.cc (new file, 597 lines)
@ -0,0 +1,597 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
namespace {
|
||||
|
||||
// Functor for making unique_ptrs slightly more ergonomic. Using
|
||||
// decltype(delete_fn) in the unique_ptr's second template argument requires
|
||||
// passing a function pointer to delete_fn when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) const {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||
|
||||
class OpDeleter {
|
||||
public:
|
||||
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
|
||||
};
|
||||
|
||||
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
|
||||
|
||||
class ExecutorDeleter {
|
||||
public:
|
||||
void operator()(TFE_Executor* to_delete) const {
|
||||
TFE_DeleteExecutor(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
|
||||
|
||||
class ParallelTensor;
|
||||
|
||||
using MaybeParallelTensorOwned =
|
||||
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
|
||||
using MaybeParallelTensorUnowned =
|
||||
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
|
||||
|
||||
// Creates a vector of `count` new executors (threads).
|
||||
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
|
||||
std::vector<ExecutorPtr> executors;
|
||||
executors.reserve(count);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
|
||||
}
|
||||
return executors;
|
||||
}
|
||||
|
||||
// A representation of the custom device passed in and out of the TFE custom
|
||||
// device APIs, providing context about the parallel device to
|
||||
// ParallelDeviceExecute.
|
||||
class ParallelDevice {
|
||||
public:
|
||||
ParallelDevice(const std::string& name,
|
||||
const std::vector<std::string>& devices);
|
||||
|
||||
// Helper to copy a tensor handle from another device once for each component
|
||||
// of the ParallelDevice.
|
||||
//
|
||||
// Sets a bad status and returns a nullptr if `tensor` is already on the
|
||||
// ParallelDevice, or if the individual copies fail.
|
||||
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) const;
|
||||
|
||||
// Takes a description of a single operation being executed on the
|
||||
// ParallelDevice, and in turn runs one operation per component device with
|
||||
// its corresponding inputs from the input ParallelTensors (or
|
||||
// implicitly-mirrored tensors on other devices). Wraps the resulting
|
||||
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
|
||||
// output of the original operation.
|
||||
//
|
||||
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
|
||||
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
|
||||
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
|
||||
// tensor, but other operations will implicitly broadcast non-parallel input
|
||||
// tensors across the ParallelDevice's component devices.
|
||||
//
|
||||
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
|
||||
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
|
||||
// causes `Execute` to return non-parallel tensors.
|
||||
//
|
||||
// Attributes are forwarded to executed operations unmodified.
|
||||
//
|
||||
// The returned optional has a value if and only if `status` evaluates to
|
||||
// TF_OK.
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const;
|
||||
|
||||
// Implements the parallel case for `Execute`, where all of the outputs of the
|
||||
// operation are ParallelTensors, and all inputs are either ParallelTensors or
|
||||
// should be implicitly broadcast. This means the operation is not
|
||||
// TPUReplicatedInput or TPUReplicatedOutput.
|
||||
//
|
||||
// The returned optional has a value if and only if `status` evaluates to
|
||||
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
|
||||
// if sanity checks on dtypes/metadata fail.
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ExecuteParallelOperation(TFE_Context* context,
|
||||
std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const;
|
||||
|
||||
const std::string& device_name() const { return device_name_; }
|
||||
|
||||
private:
|
||||
// The name of the parallel device
|
||||
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
|
||||
const std::string device_name_;
|
||||
// A sequence of device names, indicating which devices replicated operations
|
||||
// are forwarded to.
|
||||
const std::vector<std::string> underlying_devices_;
|
||||
// A sequence of TFE_Executors, one per device, for executing operations in
|
||||
// parallel.
|
||||
const std::vector<ExecutorPtr> executors_;
|
||||
};
|
||||
|
||||
// The internal representation of a TFE_TensorHandle placed on a
|
||||
// ParallelDevice. Contains a tuple of tensors, one on each of the
|
||||
// `underlying_devices_` of the ParallelDevice.
|
||||
class ParallelTensor {
|
||||
public:
|
||||
// Construct a ParallelTensor from TensorHandles placed on the component
|
||||
// devices of a ParallelDevice.
|
||||
static std::unique_ptr<ParallelTensor> FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status);
|
||||
|
||||
// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
|
||||
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
|
||||
std::unique_ptr<ParallelTensor> t,
|
||||
TF_Status* status);
|
||||
|
||||
size_t num_tensors() const { return tensors_.size(); }
|
||||
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
|
||||
|
||||
private:
|
||||
ParallelTensor(const ParallelDevice& device,
|
||||
std::vector<TensorHandlePtr> tensors,
|
||||
std::vector<int64_t> shape, const TF_DataType dtype)
|
||||
: device_(device),
|
||||
tensors_(std::move(tensors)),
|
||||
shape_(std::move(shape)),
|
||||
dtype_(dtype) {}
|
||||
|
||||
const ParallelDevice& device_;
|
||||
const std::vector<TensorHandlePtr> tensors_;
|
||||
const std::vector<int64_t> shape_;
|
||||
const TF_DataType dtype_;
|
||||
};
|
||||
|
||||
ParallelDevice::ParallelDevice(const std::string& name,
|
||||
const std::vector<std::string>& devices)
|
||||
: device_name_(name),
|
||||
underlying_devices_(devices),
|
||||
executors_(MakeExecutors(underlying_devices_.size())) {}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
|
||||
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
|
||||
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
|
||||
if (device_name_ == current_device) {
|
||||
std::string message(absl::StrCat(
|
||||
"Tried to copy a TensorHandle to its existing device: ", device_name_));
|
||||
TF_SetStatus(status, TF_INTERNAL, message.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (const std::string& underlying_device_name : underlying_devices_) {
|
||||
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||
tensor, context, underlying_device_name.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
components.emplace_back(t);
|
||||
}
|
||||
return ParallelTensor::FromTensorHandles(*this, std::move(components),
|
||||
status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const {
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
|
||||
// TODO(allenl): We should remove "TPU" from these op names at the very least,
|
||||
// or consider other ways of packing/unpacking parallel tensors.
|
||||
if (operation_name == std::string("TPUReplicatedInput")) {
|
||||
// Special-cased operation for packing per-device tensors into one parallel
|
||||
// tensor.
|
||||
if (inputs.size() != underlying_devices_.size()) {
|
||||
std::string message(absl::StrCat(
|
||||
"The parallel device ", device_name_, " expected ",
|
||||
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
|
||||
inputs.size()));
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
|
||||
return result;
|
||||
}
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(inputs.size());
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
|
||||
std::string message(absl::StrCat(
|
||||
"Expected all inputs to TPUReplicatedInput to be non-parallel "
|
||||
"TensorHandles. The input ",
|
||||
i,
|
||||
" was a parallel tensor (already "
|
||||
"placed on the parallel device)."));
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
|
||||
return result;
|
||||
}
|
||||
components.emplace_back(TFE_TensorHandleCopySharingTensor(
|
||||
absl::get<TFE_TensorHandle*>(inputs[i]), status));
|
||||
}
|
||||
std::vector<MaybeParallelTensorOwned> result_content;
|
||||
result_content.reserve(1);
|
||||
result_content.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
} else if (operation_name == std::string("TPUReplicatedOutput")) {
|
||||
// Special-cased operation for un-packing one parallel tensor into
|
||||
// per-device tensors.
|
||||
OpPtr op(TFE_NewOp(context, operation_name, status));
|
||||
TFE_OpAddAttrs(op.get(), attributes);
|
||||
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
if (expected_outputs != underlying_devices_.size()) {
|
||||
std::string message(absl::StrCat(
|
||||
"The parallel device ", device_name_, " expected ",
|
||||
underlying_devices_.size(),
|
||||
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
|
||||
return result;
|
||||
}
|
||||
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"Expected the input to "
|
||||
"TPUReplicatedOutput to be a parallel tensor (placed on the "
|
||||
"parallel device).");
|
||||
return result;
|
||||
}
|
||||
ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
|
||||
std::vector<MaybeParallelTensorOwned> outputs;
|
||||
outputs.reserve(t->num_tensors());
|
||||
for (int i = 0; i < t->num_tensors(); ++i) {
|
||||
TensorHandlePtr this_output(
|
||||
TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
|
||||
outputs.emplace_back(std::move(this_output));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
result.emplace(std::move(outputs));
|
||||
return result;
|
||||
}
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
maybe_parallel_results(
|
||||
ExecuteParallelOperation(context, std::move(inputs), operation_name,
|
||||
attributes, expected_max_outputs, status));
|
||||
if (!maybe_parallel_results.has_value()) return result;
|
||||
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
|
||||
std::move(maybe_parallel_results.value()));
|
||||
std::vector<MaybeParallelTensorOwned> result_content;
|
||||
result_content.reserve(parallel_results.size());
|
||||
for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
|
||||
result_content.push_back(
|
||||
MaybeParallelTensorOwned(std::move(parallel_result)));
|
||||
}
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ParallelDevice::ExecuteParallelOperation(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const {
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
|
||||
// Compute per-device per-output tensors
|
||||
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
|
||||
per_device_output_tensors.reserve(underlying_devices_.size());
|
||||
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
|
||||
// setting the thread-local executor like this.
|
||||
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
|
||||
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
|
||||
TFE_ContextSetExecutorForThread(context, previous_executor);
|
||||
TFE_DeleteExecutor(previous_executor);
|
||||
});
|
||||
int first_op_output_count;
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
TFE_Executor* executor = executors_[device_index].get();
|
||||
// Note that the `reset_executor` cleanup sets the thread's executor back to
|
||||
// the value before this function ran.
|
||||
TFE_ContextSetExecutorForThread(context, executor);
|
||||
OpPtr op(TFE_NewOp(context, operation_name, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
TFE_OpAddAttrs(op.get(), attributes);
|
||||
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
|
||||
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
|
||||
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
|
||||
// to each parallel operation.
|
||||
//
|
||||
// TODO(allenl): There may be smarter ways to do this copy in some
|
||||
// cases, i.e. with a collective broadcast. We'll need to be careful
|
||||
// about things that are taken as inputs on the host or on their
|
||||
// existing device (for multi-device functions).
|
||||
TFE_OpAddInput(op.get(),
|
||||
absl::get<TFE_TensorHandle*>(inputs[input_index]),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
} else {
|
||||
// Parallel tensors are divided between operations by device.
|
||||
TFE_OpAddInput(op.get(),
|
||||
absl::get<ParallelTensor*>(inputs[input_index])
|
||||
->tensor(device_index),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
|
||||
int real_num_outputs = expected_max_outputs;
|
||||
// For nested devices, the inner device sees the async executor we've
|
||||
// set. Inner parallel devices will just overwrite this with their own and
|
||||
// then set it back to ours before returning. This means parallel devices
|
||||
// which consist of several aliased parallel devices would hypothetically
|
||||
// deadlock if the outer parallel device ran one collective with a group
|
||||
// size equal to the total number of aliased physical devices. Currently
|
||||
// physical devices cannot participate in a single collective reduction
|
||||
// multiple times, so this would fail earlier.
|
||||
//
|
||||
// TODO(allenl): Keep a map from outer executor to list of inner executors
|
||||
// rather than a single list of executors so aliased nested parallel devices
|
||||
// don't re-use an executor.
|
||||
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
|
||||
if (device_index == 0) {
|
||||
first_op_output_count = real_num_outputs;
|
||||
} else {
|
||||
if (real_num_outputs != first_op_output_count) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Parallel ops produced different numbers of tensors.");
|
||||
return result;
|
||||
}
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
std::vector<TensorHandlePtr> this_outputs;
|
||||
this_outputs.reserve(real_num_outputs);
|
||||
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
|
||||
this_outputs.emplace_back(op_outputs[output_num]);
|
||||
}
|
||||
per_device_output_tensors.push_back(std::move(this_outputs));
|
||||
}
|
||||
// For each output of the original operation, pack the per-device
|
||||
// TensorHandles we've computed into a single parallel TensorHandle.
|
||||
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
|
||||
per_device_outputs.reserve(first_op_output_count);
|
||||
for (int i = 0; i < first_op_output_count; ++i) {
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int j = 0; j < underlying_devices_.size(); ++j) {
|
||||
components.push_back(std::move(per_device_output_tensors[j][i]));
|
||||
}
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
result.emplace(std::move(per_device_outputs));
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
|
||||
std::vector<int64_t> shape(
|
||||
TFE_TensorHandleNumDims(components[0].get(), status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
if (tensor_dim != shape[i]) {
|
||||
// TODO(allenl): Allow shapes to differ.
|
||||
TF_SetStatus(status, TF_UNIMPLEMENTED,
|
||||
"Components of a ParallelTensor must currently all have "
|
||||
"the same shape");
|
||||
return nullptr;
|
||||
}
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
|
||||
parallel_device, std::move(components), std::move(shape), dtype));
|
||||
}
|
||||
|
||||
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
|
||||
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
|
||||
// reference counts drop to zero.
|
||||
void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
delete reinterpret_cast<ParallelTensor*>(data);
|
||||
}
|
||||
|
||||
TensorHandlePtr ParallelTensor::AsTensorHandle(
|
||||
TFE_Context* context, std::unique_ptr<ParallelTensor> t,
|
||||
TF_Status* status) {
|
||||
// The resulting TensorHandle owns an opaque pointer to "device memory", which
|
||||
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
|
||||
// deleted, it will call ParallelTensorDeallocator to free the struct.
|
||||
ParallelTensor* t_released = t.release();
|
||||
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
|
||||
context, t_released->device_.device_name().c_str(), t_released->dtype_,
|
||||
t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
|
||||
&ParallelTensorDeallocator, nullptr, status));
|
||||
}
|
||||
|
||||
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
|
||||
// registration.
|
||||
//
|
||||
// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle containing
|
||||
// a ParallelTensor with one copy of `tensor` for each device in the
|
||||
// ParallelDevice.
|
||||
//
|
||||
// Since this function is used to satisfy the TFE_CustomDevice C API,
|
||||
// device_info is passed in using a C-style generic. It must always be a
|
||||
// ParallelDevice.
|
||||
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status, void* device_info) {
|
||||
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
|
||||
std::unique_ptr<ParallelTensor> parallel_tensor(
|
||||
dev->CopyToParallelDevice(context, tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
|
||||
status)
|
||||
.release();
|
||||
}
|
||||
|
||||
// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
|
||||
// registration.
|
||||
//
|
||||
// Currently this is an error, and un-packing ParallelTensors must be performed
|
||||
// explicitly by running a TPUReplicatedOutput operation on the parallel device.
|
||||
//
|
||||
// TODO(allenl): There are some use-cases that are only supported by copying to
|
||||
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
|
||||
// need to return something here or address these use-cases one by one.
|
||||
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Trying to copy a tensor out of a parallel device.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// For TFE_CustomDevice::execute in the parallel device registration.
|
||||
//
|
||||
// Since this function is used to satisfy the TFE_CustomDevice C API,
|
||||
// device_info is passed in using a C-style generic. It must always be a
|
||||
// ParallelDevice.
|
||||
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* status,
|
||||
void* device_info) {
|
||||
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
|
||||
std::vector<MaybeParallelTensorUnowned> typed_inputs;
|
||||
typed_inputs.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
const char* tensor_handle_device =
|
||||
TFE_TensorHandleDeviceName(inputs[i], status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
if (dev->device_name() == tensor_handle_device) {
|
||||
// We assume that any tensors already placed on this device are
|
||||
// ParallelTensors.
|
||||
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
|
||||
TFE_TensorHandleDevicePointer(inputs[i], status)));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
} else {
|
||||
typed_inputs.emplace_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
|
||||
dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
|
||||
*num_outputs, status));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
if (!maybe_typed_outputs.has_value()) {
|
||||
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<MaybeParallelTensorOwned> typed_outputs(
|
||||
std::move(maybe_typed_outputs.value()));
|
||||
|
||||
if (typed_outputs.size() > *num_outputs) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"The allocated output buffer was too small.");
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < typed_outputs.size(); ++i) {
|
||||
MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
|
||||
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
|
||||
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
|
||||
} else {
|
||||
outputs[i] = ParallelTensor::AsTensorHandle(
|
||||
context,
|
||||
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
|
||||
typed_output)),
|
||||
status)
|
||||
.release();
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
}
|
||||
*num_outputs = typed_outputs.size();
|
||||
}
|
||||
|
||||
// For TFE_CustomDevice::delete_device in the parallel device registration.
|
||||
//
|
||||
// Since this function is used to satisfy the TFE_CustomDevice C API,
|
||||
// device_info is passed in using a C-style generic. It must always be a
|
||||
// ParallelDevice.
|
||||
void DeleteParallelDevice(void* device_info) {
|
||||
delete reinterpret_cast<ParallelDevice*>(device_info);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
|
||||
const char** underlying_devices,
|
||||
int num_underlying_devices, TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToParallelDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
|
||||
custom_device.delete_device = &DeleteParallelDevice;
|
||||
custom_device.execute = &ParallelDeviceExecute;
|
||||
std::vector<std::string> underlying_devices_vector;
|
||||
underlying_devices_vector.reserve(num_underlying_devices);
|
||||
for (int device_index = 0; device_index < num_underlying_devices;
|
||||
++device_index) {
|
||||
underlying_devices_vector.push_back(underlying_devices[device_index]);
|
||||
}
|
||||
ParallelDevice* d =
|
||||
new ParallelDevice(device_name, underlying_devices_vector);
|
||||
TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
|
||||
}
|
||||
|
||||
} // namespace eager
|
||||
} // namespace tensorflow
|
tensorflow/c/eager/parallel_device/parallel_device.h (new file, 62 lines)
@@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_

#include "tensorflow/c/eager/c_api.h"

namespace tensorflow {
namespace eager {

// Register a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
// on each underlying device.
//
// For example if `device_name` is
//   "/job:localhost/replica:0/task:0/device:CUSTOM:0"
// and `underlying_devices` is
//   {"/job:localhost/replica:0/task:0/device:GPU:0",
//    "/job:localhost/replica:0/task:0/device:GPU:1"}
// Then executing an operation on CUSTOM:0 will execute it on GPU:0 and GPU:1.
//
// Implicit copies onto `device_name` are allowed, replicating the value once
// per device in `underlying_devices`. Implicit copies off of the device throw
// an error.
//
// All component tensors must have the same dtype. Currently they must also have
// the same shape, although this requirement may be relaxed in the future.
//
// `device_name` must not name an existing physical or custom device (see
// the documentation for TFE_RegisterCustomDevice for more information).
//
// Tensors may be copied on or off the device explicitly using
// TPUReplicatedInput and TPUReplicatedOutput respectively. For example, with
// two component devices, running `x = TPUReplicatedInput(inputs=[a, b])` on the
// parallel device creates a parallel tensor `x` with `a` on the first of
// `underlying_devices` and `b` on the second. Running `a_unpacked, b_unpacked =
// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
// into its components.
//
// `context` owns the parallel device. `underlying_devices` must stay valid
// while the parallel device is in use.
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
                            const char** underlying_devices,
                            int num_underlying_devices, TF_Status* status);

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
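For illustration, a minimal sketch of registering this device over two CPU devices. The device names follow the examples in the header comment above; the helper name is hypothetical, and in a real program the second CPU device would have to exist (for instance by raising the CPU device count in the context options) before this call succeeds.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include "tensorflow/c/tf_status.h"

// Hypothetical helper: register a parallel device backed by two CPU devices.
void RegisterTwoDeviceParallelCPU(TFE_Context* context, TF_Status* status) {
  const char* underlying[] = {
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  tensorflow::eager::RegisterParallelDevice(
      context, "/job:localhost/replica:0/task:0/device:CUSTOM:0", underlying,
      /*num_underlying_devices=*/2, status);
}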
tensorflow/c/eager/parallel_device/parallel_device_test.cc (new file, 917 lines)
@ -0,0 +1,917 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
// NOTE(allenl): These tests currently go through TFE_Execute and so are
|
||||
// integration testing rather than purely testing the parallel device. They
|
||||
// correspond fairly well to the implementation, but testing the C++ directly is
|
||||
// another option.
|
||||
|
||||
// Functor for making unique_ptr to TFE_TensorHandle slightly more
|
||||
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
|
||||
// template argument requires passing a function pointer to
|
||||
// TFE_DeleteTensorHandle when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
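// A minimal sketch of the ergonomics this buys, assuming a valid TF_Tensor*
// `tensor` and TF_Status* `status` exist:
//   TensorHandlePtr handle(TFE_NewTensorHandle(tensor, status));
// versus the decltype form used for TFE_Op below, which repeats the deleter:
//   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
//       raw(TFE_NewTensorHandle(tensor, status), TFE_DeleteTensorHandle);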
|
||||
|
||||
// A helper for performing common operations on variables. A much more
|
||||
// restricted stand-in for tf.Variable in Python.
|
||||
class Variable {
|
||||
public:
|
||||
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
|
||||
// indication of the dtype of the variable's value.
|
||||
//
|
||||
// Note that creating this resource-dtype handle can fail, so `Create` is a
|
||||
// separate static method which returns a status.
|
||||
Variable(TFE_TensorHandle* handle, TF_DataType type)
|
||||
: handle_(handle), type_(type) {}
|
||||
|
||||
// Helper for constructing a resource handle and wrapping it in a `Variable`
|
||||
// object.
|
||||
static Variable* Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status);
|
||||
// Dereferences the backing buffer for the variable. Note that since this can
|
||||
// fail (it runs operations), it must be called explicitly and the resulting
|
||||
// `status` checked.
|
||||
void Destroy(TFE_Context* context, TF_Status* status);
|
||||
|
||||
// Reads from the variable.
|
||||
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
|
||||
// Assigns a new value to the variable.
|
||||
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
|
||||
// Adds `value` to the existing value of the variable.
|
||||
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status);
|
||||
|
||||
private:
|
||||
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
|
||||
// AssignSub, ...).
|
||||
void GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status);
|
||||
|
||||
// A handle for the resource-dtype tensor pointing to the variable's
|
||||
// buffer.
|
||||
TFE_TensorHandle* handle_;
|
||||
// The dtype of the variable's buffer (input dtype for assignments, output
|
||||
// dtype of read operations).
|
||||
TF_DataType type_;
|
||||
};
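// A sketch of the intended lifecycle (illustrative; `context`, `status`, and a
// float handle `value` are assumed to exist):
//   Variable* var = Variable::Create(
//       context, TF_FLOAT, /*dims=*/nullptr, /*num_dims=*/0,
//       "/job:localhost/replica:0/task:0/device:CPU:0", status);
//   var->Assign(context, value, status);
//   TensorHandlePtr read = var->Read(context, status);
//   var->Destroy(context, status);  // must run before deleting the object
//   delete var;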
|
||||
|
||||
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
|
||||
const int64_t* dims, const int num_dims,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type);
|
||||
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
|
||||
TFE_OpSetAttrString(op.get(), "container", "", 0);
|
||||
// Use the special GUID for no buffer sharing
|
||||
//
|
||||
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
|
||||
// only reasonable way to make variables with no aliasing using the eager C
|
||||
// API.
|
||||
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
||||
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
|
||||
no_sharing.length());
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return new Variable(var_handle, type);
|
||||
}
|
||||
|
||||
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
|
||||
// Free the backing buffer for the variable.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
// Delete the variable handle itself.
|
||||
TFE_DeleteTensorHandle(handle_);
|
||||
}
|
||||
|
||||
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(var_value);
|
||||
}
|
||||
|
||||
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), value, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
|
||||
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignAddVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
|
||||
TF_Status* status) {
|
||||
GeneralAssignment("AssignVariableOp", context, value, status);
|
||||
}
|
||||
|
||||
// Passed to `TF_NewTensor` to indicate how an array of floats should be
|
||||
// deleted.
|
||||
static void FloatDeallocator(void* data, size_t, void* arg) {
|
||||
delete[] static_cast<float*>(data);
|
||||
}
|
||||
|
||||
// Creates a TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
|
||||
const int num_bytes = sizeof(float);
|
||||
float* values = new float[1];
|
||||
values[0] = v;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Creates a rank-one TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
||||
TF_Status* status) {
|
||||
const int num_bytes = v.size() * sizeof(float);
|
||||
float* values = new float[v.size()];
|
||||
memcpy(values, v.data(), num_bytes);
|
||||
int64_t dims = v.size();
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
|
||||
&FloatDeallocator, nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
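// A minimal sketch (illustrative; `status` is assumed to be a TF_Status*):
//   TensorHandlePtr scalar = FloatTensorHandle(4.f, status);
//   TensorHandlePtr vector = VectorFloatTensorHandle({1.f, 2.f, 3.f}, status);
//   if (TF_GetCode(status) != TF_OK) { /* handle the error */ }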
|
||||
|
||||
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
void ExtractPerDeviceValues(
|
||||
TFE_Context* context, TFE_TensorHandle* input,
|
||||
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
|
||||
TFE_OpAddInput(op.get(), input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
TFE_TensorHandle* result_handles[num_replicas];
|
||||
int num_retvals = num_replicas;
|
||||
TFE_Execute(op.get(), result_handles, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
(*components)[i].reset(result_handles[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
|
||||
template <std::size_t num_replicas>
|
||||
TensorHandlePtr CreatePerDeviceValues(
|
||||
TFE_Context* context,
|
||||
const std::array<TFE_TensorHandle*, num_replicas>& components,
|
||||
const char* device, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
TFE_OpAddInput(op.get(), components[i], status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
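// A sketch of a pack/unpack round trip (illustrative; `context`, `status`, and
// component handles `a` and `b` are assumed to exist, and the parallel device
// below to be registered already):
//   std::array<TFE_TensorHandle*, 2> components{a, b};
//   TensorHandlePtr packed = CreatePerDeviceValues(
//       context, components,
//       "/job:localhost/replica:0/task:0/device:CUSTOM:0", status);
//   std::array<TensorHandlePtr, 2> unpacked;
//   ExtractPerDeviceValues(context, packed.get(), &unpacked, status);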
|
||||
|
||||
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
TFE_TensorHandle* second, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpAddInput(op.get(), second, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
const char* first_device = TFE_TensorHandleDeviceName(first, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), first_device, status);
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
|
||||
|
||||
// Assert that `handle` is equal to `expected_value`.
|
||||
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(expected_value,
|
||||
*static_cast<float*>(TF_TensorData(value_zero.get())));
|
||||
}
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
|
||||
// `first_device` and `second_device`.
|
||||
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
const char* second_device) {
|
||||
// Register the custom device
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{first_device, second_device};
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context, device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a variable handle (uninitialized to start) placed on the parallel
|
||||
// device.
|
||||
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
|
||||
to_delete->Destroy(context, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
delete to_delete;
|
||||
};
|
||||
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
|
||||
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
|
||||
status.get()),
|
||||
variable_deleter);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Assign an initial value to the variable, implicitly mirroring it to each
|
||||
// component device.
|
||||
{
|
||||
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
variable->Assign(context, initial_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read from the variable and verify that we have a parallel tensor.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 20.);
|
||||
AssertScalarFloatEq(components[1].get(), 20.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
|
||||
// Add a parallel tensor with different values on each device to the variable.
|
||||
{
|
||||
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
|
||||
value_two.get()};
|
||||
TensorHandlePtr combined_value =
|
||||
CreatePerDeviceValues(context, components, device_name, status.get());
|
||||
variable->AssignAdd(context, combined_value.get(), status.get());
|
||||
}
|
||||
|
||||
// Read the variable and verify that each component has the right modified
|
||||
// value.
|
||||
{
|
||||
TensorHandlePtr read = variable->Read(context, status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 23.);
|
||||
AssertScalarFloatEq(components[1].get(), 18.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
BasicTestsForTwoDevices(context.get(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
BasicTestsForTwoDevices(context.get(),
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Skip the test if no TPU is available.
|
||||
std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
|
||||
TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool has_tpu = false;
|
||||
for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
|
||||
++device_index) {
|
||||
std::string device_type =
|
||||
TF_DeviceListType(devices.get(), device_index, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
if (device_type == "TPU") {
|
||||
has_tpu = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (has_tpu) {
|
||||
BasicTestsForTwoDevices(context.get(),
|
||||
"/job:localhost/replica:0/task:0/device:TPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:TPU:0");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestExplicitCopies) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
const char* first_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
underlying_devices.push_back(first_device_name);
|
||||
const char* second_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1";
|
||||
underlying_devices.push_back(second_device_name);
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Copying onto a parallel device is OK.
|
||||
TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice(
|
||||
cpu_value.get(), context.get(), device_name, status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
const char* backing_device =
|
||||
TFE_TensorHandleBackingDeviceName(device_value.get(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(std::string(device_name), backing_device);
|
||||
|
||||
// Un-pack the parallel tensor to verify that the copy was successful.
|
||||
{
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context.get(), device_value.get(), &components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// The value of the original tensor is replicated on each device.
|
||||
AssertScalarFloatEq(components[0].get(), 3.);
|
||||
AssertScalarFloatEq(components[1].get(), 3.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
|
||||
// Copies off of parallel devices must be explicit.
|
||||
TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
|
||||
device_value.get(), context.get(), first_device_name, status.get()));
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestDifferentShapes) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create two vectors with different lengths
|
||||
std::vector<float> size_two_value{1., 2.};
|
||||
std::vector<float> size_three_value{1., 2., 3.};
|
||||
TensorHandlePtr size_two(
|
||||
VectorFloatTensorHandle(size_two_value, status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TensorHandlePtr size_three(
|
||||
VectorFloatTensorHandle(size_three_value, status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Try to combine these values into a single parallel tensor.
|
||||
std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
|
||||
TensorHandlePtr combined_value = CreatePerDeviceValues(
|
||||
context.get(), components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
3),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a parallel device with two CPUs
|
||||
const char* first_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> first_underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), first_device_name, first_underlying_devices.data(),
|
||||
first_underlying_devices.size(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a second parallel device with the first parallel device and one
|
||||
// additional CPU.
|
||||
const char* second_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:1";
|
||||
std::vector<const char*> second_underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CUSTOM:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:2"};
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), second_device_name, second_underlying_devices.data(),
|
||||
second_underlying_devices.size(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a tensor on the first parallel device
|
||||
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
|
||||
TensorHandlePtr first_combined_value = CreatePerDeviceValues(
|
||||
context.get(), components, first_device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Nest the first parallel tensor into a second
|
||||
TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
components[0] = first_combined_value.get();
|
||||
components[1] = value_three.get();
|
||||
TensorHandlePtr second_combined_value = CreatePerDeviceValues(
|
||||
context.get(), components, second_device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr three(FloatTensorHandle(3., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TensorHandlePtr multiply_result(Multiply(context.get(),
|
||||
second_combined_value.get(),
|
||||
three.get(), status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Un-pack the parallel tensor to verify that the operation was
|
||||
// successful. The resulting structure should be:
|
||||
// second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
|
||||
std::array<TensorHandlePtr, 2> second_components;
|
||||
ExtractPerDeviceValues(context.get(), multiply_result.get(),
|
||||
&second_components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(second_components[1].get(), 9.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
second_components[0].get(), status.get());
|
||||
ASSERT_EQ(second_underlying_devices[0], first_device);
|
||||
std::string second_device = TFE_TensorHandleBackingDeviceName(
|
||||
second_components[1].get(), status.get());
|
||||
ASSERT_EQ(second_underlying_devices[1], second_device);
|
||||
|
||||
// Un-pack the first parallel device's tensor too
|
||||
std::array<TensorHandlePtr, 2> first_components;
|
||||
ExtractPerDeviceValues(context.get(), second_components[0].get(),
|
||||
&first_components, status.get());
|
||||
AssertScalarFloatEq(first_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(first_components[1].get(), 6.);
|
||||
|
||||
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
|
||||
status.get());
|
||||
ASSERT_EQ(first_underlying_devices[0], first_device);
|
||||
second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
|
||||
status.get());
|
||||
ASSERT_EQ(first_underlying_devices[1], second_device);
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestInvalidPacking) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
|
||||
{
|
||||
// Try to pack two TensorHandles onto a parallel device with a single
|
||||
// component.
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
|
||||
value_two.get()};
|
||||
TensorHandlePtr combined_value = CreatePerDeviceValues(
|
||||
context.get(), components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
|
||||
{
|
||||
// Try to extract the wrong number of components from a parallel tensor
|
||||
std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
|
||||
TensorHandlePtr combined_value = CreatePerDeviceValues(
|
||||
context.get(), correct_components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::array<TensorHandlePtr, 2> incorrect_components;
|
||||
ExtractPerDeviceValues(context.get(), combined_value.get(),
|
||||
&incorrect_components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
|
||||
{
|
||||
// Try to pass a ParallelTensor to TPUReplicatedInput
|
||||
std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
|
||||
TensorHandlePtr combined_value = CreatePerDeviceValues(
|
||||
context.get(), correct_components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::array<TFE_TensorHandle*, 1> incorrect_components{combined_value.get()};
|
||||
TensorHandlePtr recombined_value = CreatePerDeviceValues(
|
||||
context.get(), incorrect_components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
|
||||
{
|
||||
// Try to pass a non-parallel tensor to TPUReplicatedOutput
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
|
||||
TFE_DeleteOp);
|
||||
if (TF_GetCode(status.get()) != TF_OK) return;
|
||||
TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
|
||||
TFE_OpAddInput(op.get(), value_one.get(), status.get());
|
||||
if (TF_GetCode(status.get()) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
if (TF_GetCode(status.get()) != TF_OK) return;
|
||||
|
||||
TFE_TensorHandle* result_handles;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handles, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
}
|
||||
|
||||
TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
|
||||
int group_size, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
const char* device = TFE_TensorHandleDeviceName(input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
|
||||
TFE_OpSetAttrInt(op.get(), "group_size", group_size);
|
||||
TFE_OpSetAttrInt(op.get(), "group_key", 0);
|
||||
TFE_OpSetAttrInt(op.get(), "instance_key", 0);
|
||||
const std::string merge_op("Add");
|
||||
TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
|
||||
merge_op.length());
|
||||
const std::string final_op("Id");
|
||||
TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
|
||||
final_op.length());
|
||||
TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);
|
||||
|
||||
TFE_OpAddInput(op.get(), input, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return TensorHandlePtr(result_handle);
|
||||
}
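// As TestCollective below suggests, passing a parallel tensor as `input` runs
// the reduce once per component with the same group/instance keys, so the
// components reduce with each other; e.g. (names as in that test):
//   TensorHandlePtr reduced =
//       CollectiveSum(context, parallel_value.get(), /*group_size=*/2, status);
// leaves every component equal to the sum of the original components.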
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestCollective) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a tensor on the parallel device
|
||||
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
|
||||
TensorHandlePtr parallel_value = CreatePerDeviceValues(
|
||||
context.get(), components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Run a collective sum, so each component should now be the same.
|
||||
TensorHandlePtr reduced(
|
||||
CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::array<TensorHandlePtr, 2> result_components;
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 3.);
|
||||
}
|
||||
|
||||
void RegisterCollectiveMulFunction(TFE_Context* context,
|
||||
const char* function_name, int group_size,
|
||||
TF_Status* status) {
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
|
||||
TF_DeleteGraph);
|
||||
TF_OperationDescription* placeholder_desc =
|
||||
TF_NewOperation(body.get(), "Placeholder", "Placeholder");
|
||||
TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
|
||||
TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TF_Output x{placeholder_op, 0};
|
||||
|
||||
TF_OperationDescription* reduce_desc =
|
||||
TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
|
||||
TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
|
||||
TF_SetAttrInt(reduce_desc, "group_size", group_size);
|
||||
TF_SetAttrInt(reduce_desc, "group_key", 0);
|
||||
TF_SetAttrInt(reduce_desc, "instance_key", 0);
|
||||
|
||||
const std::string merge_op("Mul");
|
||||
TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
|
||||
merge_op.length());
|
||||
const std::string final_op("Id");
|
||||
TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
|
||||
final_op.length());
|
||||
TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
|
||||
TF_AddInput(reduce_desc, x);
|
||||
TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TF_Operation* operations[]{placeholder_op, reduce_op};
|
||||
TF_Output y{reduce_op, 0};
|
||||
const char* output_name = "y";
|
||||
std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
|
||||
TF_GraphToFunction(
|
||||
/* fn_body */ body.get(), /* fn_name */ function_name,
|
||||
/* append_hash_to_fn_name */ 0, /* num_opers */ 2,
|
||||
/* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
|
||||
/* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
|
||||
/* opts */ nullptr, /* description */ "", /* status */ status),
|
||||
TF_DeleteFunction);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_ContextAddFunction(context, function.get(), status);
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::vector<const char*> underlying_devices;
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
|
||||
tensorflow::eager::RegisterParallelDevice(
|
||||
context.get(), device_name, underlying_devices.data(),
|
||||
underlying_devices.size(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* function_name = "test_reduce_mul";
|
||||
RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
|
||||
TensorHandlePtr parallel_value = CreatePerDeviceValues(
|
||||
context.get(), components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* raw_result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TensorHandlePtr reduced(raw_result_handle);
|
||||
|
||||
std::array<TensorHandlePtr, 2> result_components;
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
|
||||
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
@ -16,6 +16,12 @@ cc_library(
|
||||
deps = ["//tensorflow/core:test_main"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "quantize_header",
|
||||
srcs = ["quantize.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfcompile_lib",
|
||||
srcs = [
|
||||
@ -27,6 +33,7 @@ cc_library(
|
||||
"codegen.h",
|
||||
"compile.h",
|
||||
"flags.h",
|
||||
"quantize.h",
|
||||
],
|
||||
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
@ -37,7 +44,6 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "llvm-c/Target.h"
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
||||
#include "tensorflow/compiler/aot/quantize.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
@ -46,6 +46,14 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;
|
||||
|
||||
bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
|
||||
if (*quantize_xla) return false;
|
||||
*quantize_xla = fn;
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Compiles the XLA computation into executable code.
|
||||
@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
||||
} else {
|
||||
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
||||
}
|
||||
if (flags.experimental_quantize) {
|
||||
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
|
||||
|
||||
if (flags.experimental_quantize && *quantize_xla) {
|
||||
TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
|
||||
}
|
||||
|
||||
if (!flags.out_session_module.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
||||
computation.Snapshot());
|
||||
|
@ -13,21 +13,29 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
||||
#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
|
||||
#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
// Quantizes the model in the computation.
|
||||
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
|
||||
xla::XlaComputation* computation);
|
||||
using QuantizeXlaFn = std::function<Status(const tf2xla::Config& config,
|
||||
xla::XlaComputation* computation)>;
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
// Sets the static quantization function to `fn` if it hasn't been set.
|
||||
// Returns false if the static function has already been set.
|
||||
bool RegisterQuantizeFn(const QuantizeXlaFn& fn);
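//
// A sketch of registering a callback (illustrative; `MyXlaQuantize` is a
// hypothetical function matching the QuantizeXlaFn signature):
//   Status MyXlaQuantize(const tf2xla::Config& config,
//                        xla::XlaComputation* computation) {
//     // ... rewrite `computation` in place ...
//     return Status::OK();
//   }
//   const bool registered = RegisterQuantizeFn(MyXlaQuantize);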
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
|
@ -296,10 +296,10 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
|
||||
}
|
||||
GraphDebugInfo debug_info;
|
||||
return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
|
||||
compile_options.use_tuple_arg,
|
||||
*options.flib_def, debug_info,
|
||||
options.shape_representation_fn, result);
|
||||
return CompileGraphToXlaHlo(
|
||||
*graph, {arg_shapes.data(), arg_shapes.size()},
|
||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||
};
|
||||
return CompileImpl(options, name, args, compile_op,
|
||||
/*compile_threshold=*/absl::nullopt,
|
||||
|
@ -58,6 +58,11 @@ cc_library(
|
||||
"//tensorflow/python:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
# Link the jit lib to register the JIT devices required to run the
|
||||
# xla-legalize-tf-with-tf2xla pass.
|
||||
"//tensorflow/compiler/jit",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
|
||||
@ -65,7 +70,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
|
||||
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
|
||||
"//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
@ -90,8 +94,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:xla_lower",
|
||||
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
|
||||
"//tensorflow/compiler/mlir/xla:xla_test_passes",
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -307,7 +307,7 @@ cc_library(
|
||||
"transforms/optimize_functional_ops.cc",
|
||||
"transforms/prepare_composite_functions_tf.cc",
|
||||
"transforms/prepare_tf.cc",
|
||||
"transforms/runtime_type_verify.cc",
|
||||
"transforms/runtime_verify.cc",
|
||||
"transforms/split_merged_operands.cc",
|
||||
"transforms/trim_functions_tf.cc",
|
||||
"transforms/while_loop_outline.cc",
|
||||
@ -537,6 +537,15 @@ tf_native_cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
tf_native_cc_binary(
|
||||
name = "json_to_flatbuffer",
|
||||
srcs = ["json_to_flatbuffer.cc"],
|
||||
deps = [
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "emit_error_reporter",
|
||||
srcs = [
|
||||
|
@ -36,7 +36,8 @@ struct PassConfig {
|
||||
form_clusters(false),
|
||||
unfold_batch_matmul(true),
|
||||
legalize_tf_while(true),
|
||||
shape_inference(true) {}
|
||||
shape_inference(true),
|
||||
runtime_verification(true) {}
|
||||
|
||||
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
|
||||
// added, which produces TF Lite ops.
|
||||
@ -65,6 +66,8 @@ struct PassConfig {
|
||||
bool legalize_tf_while;
|
||||
// Whether to do shape inference.
|
||||
bool shape_inference;
|
||||
// Whether to do TFLite runtime verification.
|
||||
bool runtime_verification;
|
||||
};
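// A minimal sketch of opting out of the new check, assuming a PassConfig
// instance `pass_config` constructed as usual elsewhere:
//   pass_config.runtime_verification = false;  // skip TFLite runtime checks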
|
||||
|
||||
} // namespace TFL
|
||||
|
@ -441,7 +441,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
|
||||
mlir::tblgen::FmtContext verify_ctx;
|
||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
|
||||
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
|
||||
"failure_on_operand_type_mismatch) {\n";
|
||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||
verify_ctx.withOp("top");
|
||||
@ -466,6 +466,25 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
"operand");
|
||||
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
|
||||
"result");
|
||||
|
||||
for (auto &trait : op.getTraits()) {
|
||||
if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
|
||||
continue;
|
||||
}
|
||||
if (trait.getDef().getValueAsString("trait") !=
|
||||
"OpTrait::TFLRuntimeOpTrait") {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto *val = trait.getDef().getValue("tflRuntimePredicate");
|
||||
if (!val) continue;
|
||||
|
||||
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
|
||||
os << tgfmt(
|
||||
" if (!($0)) {\n "
|
||||
" return ::mlir::LogicalResult::Failure;\n }\n",
|
||||
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
|
||||
}
|
||||
os << " return top.verify();\n}\n";
|
||||
}
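// For reference, the emitted verifier for a hypothetical op class `AddOp` with
// one TFL runtime predicate roughly takes this shape (`<pred>` stands for the
// tgfmt-expanded predicate):
//   ::mlir::LogicalResult AddOp::VerifyTflRuntimeConstraints(
//       ::mlir::Operation *op, bool failure_on_operand_type_mismatch) {
//     auto top = cast<AddOp>(op); (void)top;
//     // ... operand/result type checks ...
//     if (!(<pred>)) {
//       return ::mlir::LogicalResult::Failure;
//     }
//     return top.verify();
//   }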
|
||||
|
||||
|
@ -16,6 +16,19 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
|
||||
|
||||
// tfl.abs
|
||||
template <>
|
||||
class TFLiteCostEstimator<AbsOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.add
|
||||
template <>
|
||||
class TFLiteCostEstimator<AddOp, hardware::GPU> {
|
||||
@ -149,6 +162,19 @@ class TFLiteCostEstimator<HardSwishOp, hardware::GPU> {
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.log
|
||||
template <>
|
||||
class TFLiteCostEstimator<LogOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.logistic
|
||||
template <>
|
||||
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
|
||||
@ -240,6 +266,32 @@ class TFLiteCostEstimator<PadOp, hardware::GPU> {
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.pow
|
||||
template <>
|
||||
class TFLiteCostEstimator<PowOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.prelu
|
||||
template <>
|
||||
class TFLiteCostEstimator<PReluOp, hardware::GPU> {
|
||||
public:
|
||||
static double GetCost(mlir::Operation* op) {
|
||||
llvm::errs() << "No defined cost function for op: "
|
||||
<< op->getName().getStringRef().str();
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
static bool IsSupported(mlir::Operation* op) { return true; }
|
||||
};
|
||||
|
||||
// tfl.relu
|
||||
template <>
|
||||
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
|
||||
|
@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
|
||||
let methods = [
|
||||
StaticInterfaceMethod<
|
||||
[{Returns whether the op's operands/results are supported by runtime.}],
|
||||
"LogicalResult", "VerifyTflRuntimeTypes",
|
||||
"LogicalResult", "VerifyTflRuntimeConstraints",
|
||||
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
|
||||
>,
|
||||
];
|
||||
|
@ -46,6 +46,30 @@ namespace mlir {
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
|
||||
namespace TFL {
|
||||
|
||||
// Returns true when the two given types have the same shape or a broadcastable
|
||||
// shape within the given rank. If either shape is non-static, this method
|
||||
// returns true.
|
||||
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
|
||||
int max_bcast_rank) {
|
||||
// Ignore shape checking on the non-static shapes for model compatibility.
|
||||
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
|
||||
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
|
||||
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
|
||||
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
|
||||
|
||||
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
|
||||
return true;
|
||||
|
||||
SmallVector<int64_t, 4> result_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
|
||||
rhs_shaped_type.getShape(),
|
||||
result_shape)) {
|
||||
return false;
|
||||
}
|
||||
return lhs_shaped_type.getRank() <= max_bcast_rank &&
|
||||
rhs_shaped_type.getRank() <= max_bcast_rank;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlowLiteDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
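A few illustrative shape pairs for the helper above, assuming max_bcast_rank = 5 (these examples are not taken from the change itself):

// tensor<2x3xf32>         vs tensor<2x3xf32>  -> true  (identical static shapes)
// tensor<8x1x3xf32>       vs tensor<7x3xf32>  -> true  (broadcastable to 8x7x3, ranks <= 5)
// tensor<1x2x3x4x5x6xf32> vs tensor<6xf32>    -> false (lhs rank 6 exceeds max_bcast_rank)
// tensor<2x3xf32>         vs tensor<4x3xf32>  -> false (dims 2 and 4 cannot broadcast)
// tensor<?x3xf32>         vs tensor<2x3xf32>  -> true  (non-static shape, check skipped)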
@ -316,7 +340,7 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
|
||||
const int num_elements = result_shape_type.getNumElements();
|
||||
new_values.reserve(num_elements);
|
||||
|
||||
for (APFloat old_value : dense_elements.getValues<APFloat>()) {
|
||||
for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
|
||||
new_values.push_back(calculate(old_value));
|
||||
}
|
||||
|
||||
@ -844,7 +868,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!shape_elements) return nullptr;
|
||||
|
||||
SmallVector<int64_t, 4> shape_data;
|
||||
for (auto it : shape_elements.getValues<APInt>()) {
|
||||
for (const auto &it : shape_elements.getValues<APInt>()) {
|
||||
shape_data.push_back(it.getSExtValue());
|
||||
}
|
||||
result_type =
|
||||
@ -1798,7 +1822,7 @@ static LogicalResult Verify(TransposeOp op) {
|
||||
|
||||
int index = 0;
|
||||
llvm::SmallVector<int64_t, 4> axes;
|
||||
for (auto axis_int : perm.getValues<APInt>()) {
|
||||
for (const auto &axis_int : perm.getValues<APInt>()) {
|
||||
const int64_t axis = axis_int.getSExtValue();
|
||||
if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
|
||||
return op.emitOpError(
|
||||
|
@ -106,6 +106,22 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
|
||||
class DerivedTFLiteTypeAttr<code body> :
|
||||
DerivedAttr<"tflite::TensorType", body>;
|
||||
|
||||
// TFL Runtime op trait predicate.
|
||||
class TFL_RuntimePredOpTrait<string desc, Pred pred> :
|
||||
GenInternalOpTrait<"TFLRuntimeOpTrait"> {
|
||||
Pred tflRuntimePredicate = pred;
|
||||
string tflRuntimeDescription = desc;
|
||||
}
|
||||
|
||||
class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
|
||||
int i, int j, int max_bcast_rank> :
|
||||
TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
|
||||
" have the same shape or broadcastable shapes within the rank " #
|
||||
max_bcast_rank,
|
||||
CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
|
||||
"$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
|
||||
").getType(), " # max_bcast_rank # ")">>;
|
||||
|
||||
// These additional types/type constraints here are used to decouple the ops
|
||||
// from runtime support for the ops. Prefer to use these types when defining
|
||||
// new TF_Ops for uniformity.
|
||||
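For concreteness, instantiating the trait above as TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5> (as TFL_AddOp does later in this change) concatenates the fragments into a predicate whose C++ condition is roughly the following; $_op is replaced with the verified operation when the verifier is emitted:

// TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape(
//     $_op.getOperand(0).getType(), $_op.getOperand(1).getType(), 5)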
@ -344,7 +360,10 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
||||
// TFL op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def TFL_AbsOp : TFL_Op<"abs", [
|
||||
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
NoQuantizableResult,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Absolute value operator";
|
||||
|
||||
let description = [{
|
||||
@ -360,10 +379,9 @@ an output element, this operation computes \\(y = |x|\\).
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
|
||||
NoSideEffect,
|
||||
Commutative,
|
||||
TFL_GpuTargetOp]> {
|
||||
def TFL_AddOp : TFL_Op<"add", [
|
||||
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
|
||||
ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
|
||||
let summary = "Addition operator";
|
||||
|
||||
let description = [{
|
||||
@ -371,11 +389,11 @@ def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$lhs,
|
||||
AnyTensor:$rhs,
|
||||
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
|
||||
TFL_AFAttr:$fused_activation_function);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
@ -1527,7 +1545,10 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
|
||||
}
|
||||
|
||||
def TFL_LogOp: TFL_Op<"log", [
|
||||
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
NoQuantizableResult,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Natural logarithm operator";
|
||||
|
||||
let description = [{
|
||||
@ -2072,7 +2093,10 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
|
||||
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
|
||||
NoSideEffect,
|
||||
NoQuantizableResult,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Power operator";
|
||||
|
||||
let description = [{
|
||||
@ -2092,7 +2116,7 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuanti
|
||||
let builders = [TFL_BroadcastableBinaryBuilder];
|
||||
}
|
||||
|
||||
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect]> {
|
||||
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> {
|
||||
let summary = "Parameterized Relu operator";
|
||||
|
||||
let description = [{
|
||||
|
63
tensorflow/compiler/mlir/lite/json_to_flatbuffer.cc
Normal file
@ -0,0 +1,63 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "flatbuffers/idl.h" // from @flatbuffers
|
||||
#include "flatbuffers/util.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
// load FlatBuffer schema (.fbs) and JSON from disk
|
||||
if (argc < 3) {
|
||||
std::cerr << "Missing input arguments. Usage:\n"
|
||||
<< argv[0] << " <schema file> <json file>\n\n";
|
||||
return 1;
|
||||
}
|
||||
const char* schema_path = argv[1];
|
||||
const char* json_path = argv[2];
|
||||
std::string schema;
|
||||
std::string json;
|
||||
|
||||
const bool status =
|
||||
flatbuffers::LoadFile(schema_path, /*binary=*/false, &schema) &&
|
||||
flatbuffers::LoadFile(json_path, /*binary=*/false, &json);
|
||||
if (!status) {
|
||||
std::cerr << "couldn't load files!\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// parse the schema first, so we can use it to parse the data afterwards
|
||||
flatbuffers::Parser parser;
|
||||
const bool schema_parse_result =
|
||||
parser.Parse(schema.c_str()) && parser.Parse(json.c_str());
|
||||
if (!schema_parse_result) {
|
||||
std::cerr << "Parse error.\n";
|
||||
return 1;
|
||||
}
|
||||
const size_t length = parser.builder_.GetSize();
|
||||
const size_t n =
|
||||
std::fwrite(parser.builder_.GetBufferPointer(), 1, length, stdout);
|
||||
if (n != length) {
|
||||
std::cerr << "print to stdout filed.\n";
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
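A hypothetical invocation of the new tool, assuming it is built as a binary named json_to_flatbuffer and using illustrative paths; the serialized flatbuffer is written to stdout:

  json_to_flatbuffer tensorflow/lite/schema/schema.fbs model.json > model.tflite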
@ -88,7 +88,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
pass_config.shape_inference = false;
|
||||
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||
pass_config, result);
|
||||
|
@ -16,9 +16,12 @@ limitations under the License.
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
|
||||
@ -41,6 +44,77 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
|
||||
mlir::OwningModuleRef* module) {
|
||||
mlir::FuncOp entry_function = nullptr;
|
||||
for (auto func : module->get().getOps<mlir::FuncOp>()) {
|
||||
if (auto tf_attrs =
|
||||
func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
|
||||
// TODO(jaesung): There could be multiple entry functions. Handle such cases
|
||||
// if the need arises.
|
||||
if (entry_function != nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"There should be only one tf.entry_function");
|
||||
}
|
||||
entry_function = func;
|
||||
}
|
||||
}
|
||||
if (entry_function == nullptr) {
|
||||
return errors::InvalidArgument("no tf.entry_function found");
|
||||
}
|
||||
|
||||
// Get the list of input Op names from the function attribute.
|
||||
mlir::DictionaryAttr tf_attrs =
|
||||
entry_function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
|
||||
llvm::SmallVector<llvm::StringRef, 4> function_input_names;
|
||||
function_input_names.reserve(model_flags.input_arrays().size());
|
||||
auto input_attr = tf_attrs.get("inputs");
|
||||
if (!input_attr) {
|
||||
return errors::InvalidArgument("no inputs attribute found");
|
||||
}
|
||||
auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
|
||||
input_names.split(function_input_names, ",");
|
||||
if (function_input_names.size() != model_flags.input_arrays().size()) {
|
||||
return errors::InvalidArgument(
|
||||
"input array size mismatch: got ", function_input_names.size(),
|
||||
", expected: ", model_flags.input_arrays().size());
|
||||
}
|
||||
llvm::StringSet<> function_input_names_set;
|
||||
function_input_names_set.insert(function_input_names.begin(),
|
||||
function_input_names.end());
|
||||
for (const auto& input_array : model_flags.input_arrays()) {
|
||||
if (function_input_names_set.count(input_array.name()) == 0) {
|
||||
return errors::InvalidArgument("input array name (", input_array.name(),
|
||||
") does not exist in the given graph");
|
||||
}
|
||||
}
|
||||
|
||||
// Get the list of output Op names from the function attribute.
|
||||
llvm::SmallVector<llvm::StringRef, 4> function_output_names;
|
||||
function_output_names.reserve(model_flags.output_arrays().size());
|
||||
auto output_attr = tf_attrs.get("outputs");
|
||||
if (!output_attr) {
|
||||
return errors::InvalidArgument("no outputs attribute found");
|
||||
}
|
||||
auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
|
||||
output_names.split(function_output_names, ",");
|
||||
if (function_output_names.size() != model_flags.output_arrays().size()) {
|
||||
return errors::InvalidArgument(
|
||||
"output array size mismatch: got ", function_output_names.size(),
|
||||
", expected: ", model_flags.output_arrays().size());
|
||||
}
|
||||
llvm::StringSet<> function_output_names_set;
|
||||
function_output_names_set.insert(function_output_names.begin(),
|
||||
function_output_names.end());
|
||||
for (const auto& output_array : model_flags.output_arrays()) {
|
||||
if (function_output_names_set.count(output_array) == 0) {
|
||||
return errors::InvalidArgument("output array name (", output_array,
|
||||
") does not exist in the given graph");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
string* result) {
|
||||
@ -77,11 +151,15 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
model_flags.saved_model_version(), tags,
|
||||
exported_names, &context));
|
||||
|
||||
if (!model_flags.input_arrays().empty() ||
|
||||
!model_flags.output_arrays().empty()) {
|
||||
TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
|
||||
}
|
||||
|
||||
mlir::TFL::PassConfig pass_config(quant_specs);
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
pass_config.shape_inference = true;
|
||||
|
||||
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
|
||||
toco_flags, std::move(module), pass_config, result);
|
||||
|
@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
||||
if (pass_config.legalize_tf_while) {
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
}
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
|
||||
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
|
@ -14,7 +14,10 @@ package(
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = ["//tensorflow/compiler/mlir/..."],
|
||||
packages = [
|
||||
"//learning/brain/experimental/mlir/quantization/...",
|
||||
"//tensorflow/compiler/mlir/...",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files([
|
||||
|
@ -55,7 +55,8 @@ namespace quant {
|
||||
using QuantParamsEntry = QuantizationInfo::QuantParams;
|
||||
|
||||
namespace {
|
||||
class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
|
||||
class ImportQuantStatsPass
|
||||
: public PassWrapper<ImportQuantStatsPass, FunctionPass> {
|
||||
public:
|
||||
explicit ImportQuantStatsPass(OperationToName op_to_name)
|
||||
: op_to_name_(op_to_name) {}
|
||||
@ -193,7 +194,7 @@ void ImportQuantStatsPass::runOnFunction() {
|
||||
}
|
||||
|
||||
// Creates an instance of the default quant parameters pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
|
||||
OperationToName op_to_name, const std::string &stats_str) {
|
||||
auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
|
||||
if (pass->ParseQuantStats(stats_str)) return nullptr;
|
||||
@ -203,7 +204,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
||||
// Creates an instance pass to import quantization stats to the operations in
|
||||
// the function. A custom method to get the name from the op is used because
|
||||
// different dialect ops might have different ways to assign the name.
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
|
||||
auto get_name_func = [](Operation *op) {
|
||||
Location loc = op->getLoc();
|
||||
|
@ -27,13 +27,13 @@ using OperationToName = std::function<llvm::StringRef(Operation* op)>;
|
||||
// Creates an instance pass to import quantization stats to the operations in
|
||||
// the function. A custom method to get the name from the op is used because
|
||||
// different dialect ops might have different ways to assign the name.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
|
||||
OperationToName op_to_name, const std::string& stats_str);
|
||||
|
||||
// Creates an instance pass to import quantization stats to the operations in
|
||||
// the function. A custom method to get the name from the op is used because
|
||||
// different dialect ops might have different ways to assign the name.
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str);
|
||||
|
||||
} // namespace quant
|
||||
|
@ -79,7 +79,7 @@ TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
|
||||
SmallVector<double, 4> new_scales;
|
||||
new_scales.reserve(scales.size());
|
||||
auto scales_iter = scales.begin();
|
||||
for (auto f : factor_values) {
|
||||
for (const auto& f : factor_values) {
|
||||
new_scales.push_back(*(scales_iter++) *
|
||||
std::fabs(FloatAttr::getValueAsDouble(f)));
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
// Legalize the tf ops to the quant ops, so the quantization passes can work.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass();
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
@ -27,7 +27,7 @@ namespace TF {
|
||||
namespace {
|
||||
|
||||
// Legalize TF quantization emulation ops to that in Quant ops dialect.
|
||||
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
|
||||
struct LegalizeTFToQuant : public PassWrapper<LegalizeTFToQuant, FunctionPass> {
|
||||
explicit LegalizeTFToQuant() = default;
|
||||
LegalizeTFToQuant(const LegalizeTFToQuant &) {}
|
||||
|
||||
@ -151,7 +151,7 @@ void LegalizeTFToQuant::runOnFunction() {
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass() {
|
||||
return std::make_unique<LegalizeTFToQuant>();
|
||||
}
|
||||
|
||||
|
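The same migration pattern (FunctionPass<T> to PassWrapper<T, FunctionPass>, OpPassBase<FuncOp> to OperationPass<FuncOp>) repeats across these files. A minimal sketch of the target shape, assuming the MLIR headers of this revision and a made-up ExamplePass:

class ExamplePass : public mlir::PassWrapper<ExamplePass, mlir::FunctionPass> {
 public:
  // Rewrite the current mlir::FuncOp in place.
  void runOnFunction() override {}
};

std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> CreateExamplePass() {
  return std::make_unique<ExamplePass>();
}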
@ -1,112 +0,0 @@
|
||||
load(
|
||||
"//third_party/mlir:tblgen.bzl",
|
||||
"gentbl",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
":friends",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//tensorflow/compiler/aot/...",
|
||||
"//tensorflow/compiler/mlir/...",
|
||||
"//tensorflow/compiler/mlir/lite/...",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_xla_quantization_passes",
|
||||
srcs = [
|
||||
"cpu_kernel_fusion.cc",
|
||||
"generated_cpu_kernel_fusion.inc",
|
||||
"materialize.cc",
|
||||
"op_quant_spec.inc",
|
||||
"propagate.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"passes.h",
|
||||
],
|
||||
deps = [
|
||||
":cpu_device_target",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_context",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/xla/client/lib:quantize",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_device_target",
|
||||
srcs = [
|
||||
"cpu_device_target.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"cpu_device_target.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/lite/quantization:device_target",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_context",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "quantize",
|
||||
srcs = [
|
||||
"quantize.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"quantize.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "cpu_kernel_fusion_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-rewriters",
|
||||
"generated_cpu_kernel_fusion.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "cpu_kernel_fusion.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_ops_td_files",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
],
|
||||
)
|
@ -1,67 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
|
||||
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
namespace ph = std::placeholders;
|
||||
|
||||
CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
|
||||
RegisterKernel("generic.concat", {qi8_, qi8_, qi8_},
|
||||
quant::ScaleConstraintType::OutputInputSameScale);
|
||||
|
||||
// TODO(fengliuai): All the combinations need to be listed. We need to
|
||||
// improve this.
|
||||
RegisterKernel("generic.reshape", {qi8_, any_},
|
||||
quant::ScaleConstraintType::OutputInputSameScale);
|
||||
RegisterKernel("generic.reshape", {any_, qi8_},
|
||||
quant::ScaleConstraintType::OutputInputSameScale);
|
||||
|
||||
RegisterKernel("generic.mul", {qi8_, qi8_, qi8_},
|
||||
quant::ScaleConstraintType::OutputInputFreeScale);
|
||||
RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_},
|
||||
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
|
||||
this, ph::_1, ph::_2, ph::_3, ph::_4));
|
||||
RegisterKernel("generic.matmul_add", {qi8_, qi8n_, any_, qi8_},
|
||||
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
|
||||
this, ph::_1, ph::_2, ph::_3, ph::_4));
|
||||
}
|
||||
|
||||
LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale(
|
||||
quant::QuantizeContext* ctx, Operation* op,
|
||||
quant::AdjacentOperations* new_items, bool* changed) {
|
||||
auto bias_params = ctx->GetOperandParams(op, 2);
|
||||
if (!EmptyParams(bias_params)) {
|
||||
return success();
|
||||
}
|
||||
std::vector<quant::QuantParams> op_types{ctx->GetOperandParams(op, 0),
|
||||
ctx->GetOperandParams(op, 1)};
|
||||
auto bias_scale = GetUniformQuantizedTypeForBias(op_types);
|
||||
if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) {
|
||||
*changed = true;
|
||||
new_items->push_back(op->getOperand(2).getDefiningOp());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
@ -1,40 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
|
||||
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
// Target specs for cpu kernels
|
||||
class CpuDeviceTarget : public quant::DeviceTarget {
|
||||
public:
|
||||
explicit CpuDeviceTarget(MLIRContext* ctx);
|
||||
|
||||
private:
|
||||
LogicalResult HandleMultiplyAccumulateScale(
|
||||
quant::QuantizeContext* ctx, Operation* op,
|
||||
quant::AdjacentOperations* new_items, bool* changed);
|
||||
};
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
|
@ -1,346 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/quantize.h"
|
||||
|
||||
#define DEBUG_TYPE "quant-kernel-fusion"
|
||||
|
||||
constexpr int kFakeQuantOperandsNum = 5;
|
||||
constexpr int kFakeQuantPerChannelOperandsNum = 6;
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
namespace {
|
||||
|
||||
TypeAttr GetQuantSpec(Operation* op) {
|
||||
auto fake_quant = llvm::dyn_cast_or_null<CustomCallOp>(op);
|
||||
if (!fake_quant || fake_quant.getNumOperands() < kFakeQuantOperandsNum ||
|
||||
fake_quant.getNumOperands() > kFakeQuantPerChannelOperandsNum ||
|
||||
fake_quant.call_target_name() != "fake_quant_with_min_max_vars")
|
||||
return {};
|
||||
|
||||
DenseFPElementsAttr min, max;
|
||||
DenseIntElementsAttr bit_width, narrow_range, quant_dim;
|
||||
if (!matchPattern(fake_quant.getOperand(1), m_Constant(&min)) ||
|
||||
!matchPattern(fake_quant.getOperand(2), m_Constant(&max)) ||
|
||||
!matchPattern(fake_quant.getOperand(3), m_Constant(&bit_width)) ||
|
||||
!matchPattern(fake_quant.getOperand(4), m_Constant(&narrow_range)))
|
||||
return {};
|
||||
|
||||
auto bit_width_val = (*bit_width.attr_value_begin()).cast<IntegerAttr>();
|
||||
auto narrow_range_val = (*narrow_range.int_value_begin()).getSExtValue();
|
||||
int quant_dim_val = -1;
|
||||
if (fake_quant.getNumOperands() == kFakeQuantPerChannelOperandsNum &&
|
||||
matchPattern(fake_quant.getOperand(kFakeQuantPerChannelOperandsNum - 1),
|
||||
m_Constant(&quant_dim))) {
|
||||
quant_dim_val = (*quant_dim.int_value_begin()).getSExtValue();
|
||||
}
|
||||
|
||||
OpBuilder builder(op);
|
||||
Type input_type =
|
||||
fake_quant.getOperand(0).getType().cast<ShapedType>().getElementType();
|
||||
return quant::GetQuantizedTypeAttr(
|
||||
builder, input_type, min, max, quant_dim_val, bit_width_val,
|
||||
builder.getBoolAttr(narrow_range_val), /*is_signed=*/true);
|
||||
}
|
||||
|
||||
// Collects input values from outside for 'ops'.
|
||||
void CollectInputs(llvm::ArrayRef<Operation*> ops,
|
||||
llvm::SmallVectorImpl<Value>* inputs,
|
||||
llvm::SmallVectorImpl<Attribute>* input_specs) {
|
||||
for (Operation* op : ops) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (std::find(inputs->begin(), inputs->end(), operand) != inputs->end()) {
|
||||
continue;
|
||||
}
|
||||
if (Operation* def_op = operand.getDefiningOp()) {
|
||||
if (std::find(ops.begin(), ops.end(), def_op) == ops.end()) {
|
||||
inputs->push_back(operand);
|
||||
}
|
||||
} else { // argument value
|
||||
inputs->push_back(operand);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (Value input : *inputs) {
|
||||
ShapedType input_type = input.getType().cast<ShapedType>();
|
||||
if (TypeAttr spec = GetQuantSpec(input.getDefiningOp())) {
|
||||
input_specs->push_back(spec);
|
||||
} else {
|
||||
input_specs->push_back(TypeAttr::get(input_type.getElementType()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collects values that are produced by 'ops' and have use outside of 'ops'.
|
||||
// TODO(fengliuai): if it is a single user and QDQ, write that to the specs.
|
||||
void CollectRets(llvm::ArrayRef<Operation*> ops,
|
||||
llvm::SmallVectorImpl<Value>* rets,
|
||||
llvm::SmallVectorImpl<Type>* ret_types,
|
||||
llvm::SmallVectorImpl<Attribute>* ret_specs) {
|
||||
for (Operation* op : ops) {
|
||||
// The constant will not be shared outside the region.
|
||||
if (llvm::isa<ConstantOp>(op)) continue;
|
||||
|
||||
for (Value result : op->getResults()) {
|
||||
for (Operation* user : result.getUsers()) {
|
||||
// If there are any users outside of 'ops'
|
||||
if (std::find(ops.begin(), ops.end(), user) == ops.end()) {
|
||||
ShapedType ret_type = result.getType().cast<ShapedType>();
|
||||
rets->push_back(result);
|
||||
ret_types->push_back(ret_type);
|
||||
if (TypeAttr spec = GetQuantSpec(user)) {
|
||||
ret_specs->push_back(spec);
|
||||
} else {
|
||||
ret_specs->push_back(TypeAttr::get(ret_type.getElementType()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum FusedActivationFunc { NONE, RELU, RELU1, RELU6 };
|
||||
|
||||
#define FLOAT_EQ(value, x) fabs(value - x) <= 1e-6
|
||||
|
||||
// If the op is max(in, 0.0), we consider it to come from Relu, so both this op
|
||||
// and the constant 0.0 will be fused.
|
||||
// If the op is clamp(0.0, in, 1.0) or clamp(0.0, in, 6.0), we consider it to
|
||||
// come from Relu1 or Relu6, so all the constants and this op will be fused.
|
||||
// Returns the activation function type.
|
||||
FusedActivationFunc FuseReluX(Operation* op,
|
||||
llvm::SmallVectorImpl<Operation*>* fused) {
|
||||
if (auto max = llvm::dyn_cast<xla_hlo::MaxOp>(op)) {
|
||||
Value min_val = max.rhs();
|
||||
llvm::SmallVector<Operation*, 4> broadcast_ops;
|
||||
if (auto broadcast = llvm::dyn_cast_or_null<xla_hlo::BroadcastInDimOp>(
|
||||
min_val.getDefiningOp())) {
|
||||
min_val = broadcast.operand();
|
||||
broadcast_ops.push_back(broadcast);
|
||||
}
|
||||
DenseFPElementsAttr min;
|
||||
if (!matchPattern(min_val, m_Constant(&min))) {
|
||||
// In case the min value is lhs.
|
||||
min_val = max.lhs();
|
||||
broadcast_ops.clear();
|
||||
if (auto broadcast = llvm::dyn_cast_or_null<xla_hlo::BroadcastInDimOp>(
|
||||
min_val.getDefiningOp())) {
|
||||
min_val = broadcast.operand();
|
||||
broadcast_ops.push_back(broadcast);
|
||||
}
|
||||
if (!matchPattern(min_val, m_Constant(&min))) {
|
||||
return NONE;
|
||||
}
|
||||
}
|
||||
if (!min.isSplat() ||
|
||||
!(FLOAT_EQ(min.getSplatValue().cast<FloatAttr>().getValueAsDouble(),
|
||||
0.0))) {
|
||||
return NONE;
|
||||
}
|
||||
|
||||
// Include the constant 0.0 as well, to avoid being quantized.
|
||||
fused->push_back(min_val.getDefiningOp());
|
||||
fused->append(broadcast_ops.begin(), broadcast_ops.end());
|
||||
fused->push_back(max);
|
||||
return RELU;
|
||||
}
|
||||
|
||||
if (auto clamp = llvm::dyn_cast<xla_hlo::ClampOp>(op)) {
|
||||
DenseFPElementsAttr lower, upper;
|
||||
if (!matchPattern(clamp.min(), m_Constant(&lower)) ||
|
||||
!matchPattern(clamp.max(), m_Constant(&upper)) || !lower.isSplat() ||
|
||||
!upper.isSplat() ||
|
||||
!(FLOAT_EQ(lower.getSplatValue().cast<FloatAttr>().getValueAsDouble(),
|
||||
0.0))) {
|
||||
return NONE;
|
||||
}
|
||||
|
||||
double upper_value =
|
||||
upper.getSplatValue().cast<FloatAttr>().getValueAsDouble();
|
||||
if (FLOAT_EQ(upper_value, 1.0) || FLOAT_EQ(upper_value, 6.0)) {
|
||||
fused->push_back(clamp.min().getDefiningOp());
|
||||
fused->push_back(clamp.max().getDefiningOp());
|
||||
fused->push_back(op);
|
||||
return (FLOAT_EQ(upper_value, 1.0) ? RELU1 : RELU6);
|
||||
}
|
||||
}
|
||||
return NONE;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 0> FuseOps(PatternRewriter* rewriter,
|
||||
const std::initializer_list<Value>& results,
|
||||
StringRef kernel) {
|
||||
// Collect all the operations to be fused.
|
||||
llvm::SmallVector<Operation*, 4> fused;
|
||||
llvm::SmallVector<Location, 4> locs;
|
||||
fused.reserve(results.size());
|
||||
locs.reserve(results.size());
|
||||
for (auto value : results) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
fused.push_back(op);
|
||||
locs.push_back(op->getLoc());
|
||||
}
|
||||
|
||||
Operation* root = fused.back();
|
||||
|
||||
FusedActivationFunc act_func = FusedActivationFunc::NONE;
|
||||
// If there is Relu, Relu1 or Relu6, fuse it as well.
|
||||
if (results.size() > 0 && std::rbegin(results)->hasOneUse()) {
|
||||
act_func = FuseReluX(*std::rbegin(results)->user_begin(), &fused);
|
||||
}
|
||||
|
||||
// Collect inputs from outside to 'ops'.
|
||||
llvm::SmallVector<Value, 4> inputs;
|
||||
llvm::SmallVector<Attribute, 4> input_specs;
|
||||
CollectInputs(fused, &inputs, &input_specs);
|
||||
|
||||
// Collect outputs from 'ops' to outside.
|
||||
llvm::SmallVector<Value, 4> rets;
|
||||
llvm::SmallVector<Type, 4> ret_types;
|
||||
llvm::SmallVector<Attribute, 4> ret_specs;
|
||||
CollectRets(fused, &rets, &ret_types, &ret_specs);
|
||||
|
||||
// TODO(fengliuai): make activation function an attribute.
|
||||
std::string kernel_name;
|
||||
switch (act_func) {
|
||||
case RELU:
|
||||
kernel_name = llvm::Twine(kernel, "_relu").str();
|
||||
break;
|
||||
case RELU1:
|
||||
kernel_name = llvm::Twine(kernel, "_relu1").str();
|
||||
break;
|
||||
case RELU6:
|
||||
kernel_name = llvm::Twine(kernel, "_relu6").str();
|
||||
break;
|
||||
default:
|
||||
kernel_name = kernel.str();
|
||||
}
|
||||
|
||||
// Create the region op with the return.
|
||||
auto region = rewriter->create<quant::QuantizeRegionOp>(
|
||||
rewriter->getFusedLoc(locs), ret_types, inputs,
|
||||
rewriter->getArrayAttr(input_specs), rewriter->getArrayAttr(ret_specs),
|
||||
kernel_name);
|
||||
auto* body = new Block();
|
||||
region.body().push_back(body);
|
||||
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(body);
|
||||
BlockAndValueMapping mapping;
|
||||
|
||||
// Make block arguments and add it to the block value mapping.
|
||||
for (Value input : inputs) {
|
||||
mapping.map(input, body->addArgument(input.getType()));
|
||||
}
|
||||
|
||||
// Clone the operations 'ops' to the region.
|
||||
for (Operation* op : fused) {
|
||||
builder.clone(*op, mapping);
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> new_rets;
|
||||
new_rets.reserve(rets.size());
|
||||
for (auto ret : llvm::enumerate(rets)) {
|
||||
Value new_ret = mapping.lookupOrNull(ret.value());
|
||||
assert(new_ret && "couldn't find return value.");
|
||||
new_rets.push_back(new_ret);
|
||||
ret.value().replaceAllUsesWith(region.getResult(ret.index()));
|
||||
}
|
||||
builder.create<quant::ReturnOp>(builder.getUnknownLoc(), new_rets);
|
||||
|
||||
LLVM_DEBUG({
|
||||
assert(region.verify().Success && "failed to create quant region.");
|
||||
llvm::dbgs() << "\ncreated region: ";
|
||||
region.print(llvm::dbgs());
|
||||
llvm::dbgs() << "\n\n\n";
|
||||
});
|
||||
|
||||
// All uses of the fused ops are replaced, so the values in this vector
|
||||
// will not be used.
|
||||
SmallVector<Value, 0> new_values(root->getNumResults(), region.getResult(0));
|
||||
return new_values;
|
||||
}
|
||||
|
||||
struct CpuKernelFusionPass : public FunctionPass<CpuKernelFusionPass> {
|
||||
explicit CpuKernelFusionPass() = default;
|
||||
CpuKernelFusionPass(const CpuKernelFusionPass&) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/generated_cpu_kernel_fusion.inc"
|
||||
|
||||
void CpuKernelFusionPass::runOnFunction() {
|
||||
Operation* op = getOperation();
|
||||
MLIRContext* ctx = op->getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
populateWithGenerated(ctx, &patterns);
|
||||
applyPatternsGreedily(op->getRegions(), patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the xla_hlo cpu kernel fusion pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateCpuKernelFusionPass() {
|
||||
return std::make_unique<CpuKernelFusionPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<CpuKernelFusionPass> pass(
|
||||
"xla-hlo-cpu-fusion", "Fuse xla hlo ops into cpu kernels");
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
@ -1,65 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
|
||||
class Fused1Ops<string kernel> : NativeCodeCall<
|
||||
"FuseOps(&$_builder, {$0}, \"" # kernel # "\")">;
|
||||
class Fused2Ops<string kernel> : NativeCodeCall<
|
||||
"FuseOps(&$_builder, {$0, $1}, \"" # kernel # "\")">;
|
||||
class Fused3Ops<string kernel> : NativeCodeCall<
|
||||
"FuseOps(&$_builder, {$0, $1, $2}, \"" # kernel # "\")">;
|
||||
class Fused4Ops<string kernel> : NativeCodeCall<
|
||||
"FuseOps(&$_builder, {$0, $1, $2, $3}, \"" # kernel # "\")">;
|
||||
|
||||
// We shouldn't revisit those ops which have been fused. This constraint is
|
||||
// required because the greedy pattern rewriter will visit and match any new
|
||||
// ops. So when the source patterns are matched and wrapped by the quant region
|
||||
// op, these ops will be matched again. To prevent this, this constraint is
|
||||
// added to bypass any ops inside a quant region.
|
||||
def NeedsToBeFused : Constraint<CPred<
|
||||
"!$0.getDefiningOp()->getParentOfType<quant::QuantizeRegionOp>()">>;
|
||||
|
||||
// dummy example
|
||||
def : Pat<(HLO_AddOp:$add (HLO_MulOp:$mul $_, $_, $_), $_, $_),
|
||||
(Fused2Ops<"generic.mul_add"> $mul, $add),
|
||||
[(NeedsToBeFused $add)]>;
|
||||
|
||||
// reduce_window: maxpool, avgpool
|
||||
def : Pat<(HLO_ReduceWindowOp:$reduce $_, $_, $_, $_, $_, $_, $_),
|
||||
(Fused1Ops<"generic.reduce_window"> $reduce),
|
||||
[(NeedsToBeFused $reduce)]>;
|
||||
|
||||
// reshape
|
||||
def : Pat<(HLO_ReshapeOp:$reshape $_), (Fused1Ops<"generic.reshape"> $reshape),
|
||||
[(NeedsToBeFused $reshape)]>;
|
||||
|
||||
// broadcast
|
||||
def : Pat<(HLO_BroadcastInDimOp:$broadcast $_, $_),
|
||||
(Fused1Ops<"generic.broadcast"> $broadcast),
|
||||
[(NeedsToBeFused $broadcast)]>;
|
||||
|
||||
// dot -> add
|
||||
def : Pat<(HLO_AddOp:$add (HLO_DotOp:$dot $_, $_, $_), $_, $_),
|
||||
(Fused2Ops<"generic.biased_dot"> $dot, $add),
|
||||
[(NeedsToBeFused $add)]>;
|
||||
|
||||
// conv -> add
|
||||
def : Pat<(HLO_AddOp:$add
|
||||
(HLO_ConvOp:$conv $_, $_, $_, $_, $_, $_, $_, $_, $_, $_), $_, $_),
|
||||
(Fused2Ops<"generic.biased_conv"> $conv, $add),
|
||||
[(NeedsToBeFused $add)]>;
|
@ -1,174 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass quantizes the constants and rewrites the
|
||||
// quantization ops with xla_hlo primitive ops.
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/quantize.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The pass to materialize the quantization results by xla primitive ops.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
namespace {
|
||||
|
||||
// This pattern matches the "constant->qcast->dcast" pattern and replaces it by
|
||||
// "quantized constant->xla_hlo.dequantize". If it only matches the
|
||||
// "non-constant->qcast->dcast" pattern, it will remove both the "qcast->dcast".
|
||||
// We chain the pattern as a whole to bypass the type checks of the normal
|
||||
// xla_hlo ops.
|
||||
// TODO(fengliuai): make this pass work for bf16 input.
|
||||
class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
|
||||
public:
|
||||
explicit RewriteDequantize(int64_t size, MLIRContext *context)
|
||||
: OpRewritePattern<quant::DequantizeCastOp>(context), size_(size) {}
|
||||
|
||||
LogicalResult matchAndRewrite(quant::DequantizeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// quant.dcast
|
||||
// xla_hlo dequantize only takes min/max, so let's recover them from
|
||||
// the quantization parameters.
|
||||
Value dcast = op.arg();
|
||||
auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType());
|
||||
if (!type || !type.isa<quant::UniformQuantizedType>()) {
|
||||
return failure();
|
||||
}
|
||||
auto qtype = type.cast<quant::UniformQuantizedType>();
|
||||
double scale = qtype.getScale();
|
||||
int64_t zero_point = qtype.getZeroPoint();
|
||||
float min = scale * (qtype.getStorageTypeMin() - zero_point);
|
||||
float max = scale * (qtype.getStorageTypeMax() - zero_point);
|
||||
|
||||
// quant.qcast
|
||||
auto qcast =
|
||||
llvm::dyn_cast_or_null<quant::QuantizeCastOp>(dcast.getDefiningOp());
|
||||
if (!qcast) return failure();
|
||||
|
||||
// constant
|
||||
DenseFPElementsAttr attr;
|
||||
// If it isn't a floating-point constant or the size is too small, let's
|
||||
// remove the quantization. Also the last dimension size should be a
|
||||
// multiple of 4, so the shape isn't broken during packing and unpacking.
|
||||
if (!matchPattern(qcast.arg(), m_Constant(&attr)) ||
|
||||
attr.getNumElements() <= size_ ||
|
||||
attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) {
|
||||
op.getResult().replaceAllUsesWith(qcast.arg());
|
||||
return success();
|
||||
}
|
||||
// TODO(fengliuai): implement transpose if it has high dimension.
|
||||
|
||||
// Create the quantized result
|
||||
auto quantized_result =
|
||||
quant::Quantize(attr, qtype).dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
if (!quantized_result) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Pack the uint8 bits to uint32. The shape is changed from
|
||||
// [n0, n1, ..., nk] to [n0, n1, ..., nk / 4].
|
||||
std::vector<uint8_t> raw_data;
|
||||
for (auto d : quantized_result.getValues<uint8_t>()) {
|
||||
raw_data.push_back(d);
|
||||
}
|
||||
// The packing might increase the data size by paddings.
|
||||
auto packed_data = xla::PackToUint32<uint8_t>(raw_data);
|
||||
auto packed_shape = attr.getType().getShape().vec();
|
||||
int lower_dims = std::accumulate(
|
||||
packed_shape.begin(),
|
||||
std::next(packed_shape.begin(), packed_shape.size() - 1), 1,
|
||||
std::multiplies<int>());
|
||||
packed_shape[packed_shape.size() - 1] = packed_data.size() / lower_dims;
|
||||
auto packed_type =
|
||||
RankedTensorType::get(packed_shape, rewriter.getIntegerType(32));
|
||||
|
||||
auto packed_quantized_result =
|
||||
DenseElementsAttr::get<uint32_t>(packed_type, packed_data);
|
||||
auto quantized_constant =
|
||||
rewriter.create<ConstantOp>(qcast.getLoc(), packed_quantized_result);
|
||||
|
||||
// Create the xla dequantize op with bf16 output
|
||||
auto dequantized_type = RankedTensorType::get(attr.getType().getShape(),
|
||||
rewriter.getBF16Type());
|
||||
auto dequantize = rewriter.create<DequantizeOp>(
|
||||
qcast.getLoc(), dequantized_type, quantized_constant,
|
||||
rewriter.getF32FloatAttr(min), rewriter.getF32FloatAttr(max),
|
||||
rewriter.getStringAttr("MIN_COMBINED"), rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false));
|
||||
|
||||
// Convert bf16 output back to f32
|
||||
rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getResult().getType(),
|
||||
dequantize);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t size_;
|
||||
};
|
||||
|
||||
// Materialize the quantization results by hlo primitive ops.
|
||||
struct MaterializeToXlaPass : public FunctionPass<MaterializeToXlaPass> {
|
||||
explicit MaterializeToXlaPass() = default;
|
||||
MaterializeToXlaPass(const MaterializeToXlaPass &) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void MaterializeToXlaPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
MLIRContext *ctx = &getContext();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
// TODO(fengliuai): make the size 6 configurable.
|
||||
patterns.insert<RewriteDequantize>(6, ctx);
|
||||
|
||||
applyPatternsGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the xla_hlo dialect quantization materialization pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass() {
|
||||
return std::make_unique<MaterializeToXlaPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<MaterializeToXlaPass> pass(
|
||||
"xla-hlo-materialize-quant",
|
||||
"Materialize the quantization results by xla primitve ops");
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
@ -1,7 +0,0 @@
|
||||
// TODO(fengliuai): automatically generate this file
|
||||
// TODO(fengliuai): add all the xla_hlo ops
|
||||
|
||||
static std::unique_ptr<quant::OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
|
||||
auto spec = absl::make_unique<quant::OpQuantSpec>();
|
||||
return spec;
|
||||
}
|
@ -1,37 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
// Propagate the quantization information to all the tensors according to the
|
||||
// op quant spec.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass();
|
||||
|
||||
// Rewrite the graph and quantize the constant.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass();
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
|
@ -1,107 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass applies quantization propagation on xla_hlo dialect.
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> disable_per_channel(
|
||||
"xla-disable-per-channel", llvm::cl::value_desc("bool"),
|
||||
llvm::cl::desc("Whether disable per-channel quantized weights."),
|
||||
llvm::cl::init(false));
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The quantization propagation Pass.
|
||||
//
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
namespace {
|
||||
|
||||
// Applies quantization propagation on the input function. During the
|
||||
// propagation, two constraints are respected:
|
||||
// - The quantization type (params) of the ops in the function
|
||||
// - The quantization spec for the ops
|
||||
// The propagation should assign quantization types to all the tensors while
|
||||
// respecting these two constraints.
|
||||
struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
|
||||
explicit PropagateQuantPass() = default;
|
||||
PropagateQuantPass(const PropagateQuantPass &) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc"
|
||||
|
||||
void PropagateQuantPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
// TODO(fengliuai): deprecate this old code generation path.
|
||||
// XLA only support uint8/uint16 quantization for now.
|
||||
ApplyQuantizationParamsPropagation(func, /*is_signed*/ false,
|
||||
disable_per_channel, GetOpQuantSpec);
|
||||
|
||||
CpuDeviceTarget spec(&getContext());
|
||||
quant::QuantizeContext ctx(func, spec);
|
||||
|
||||
std::vector<quant::QuantizeRegionOp> work_list = ctx.GetAllOps();
|
||||
bool changed = false;
|
||||
while (!work_list.empty()) {
|
||||
quant::QuantizeRegionOp op = work_list.back();
|
||||
work_list.pop_back();
|
||||
|
||||
llvm::SmallVector<Operation *, 4> new_items;
|
||||
if (failed(ctx.Handle(op, &new_items, &changed))) {
|
||||
// The IR is still valid, thus we shouldn't fail.
|
||||
signalPassFailure();
|
||||
}
|
||||
for (auto item : new_items) {
|
||||
if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
|
||||
work_list.push_back(reg);
|
||||
}
|
||||
}
|
||||
|
||||
if (!changed) return;
|
||||
|
||||
if (failed(ctx.Finalize())) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the xla_hlo dialect quantization propagation pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass() {
|
||||
return std::make_unique<PropagateQuantPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<PropagateQuantPass> pass(
|
||||
"xla-hlo-propagate-quant", "Propagate quantization information");
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
@ -1,74 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

namespace mlir {
namespace xla_hlo {

static void RegisterDialects() {
  static bool init_once = []() {
    mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
    mlir::registerDialect<mlir::StandardOpsDialect>();
    return true;
  }();
  (void)init_once;
}

// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
                               xla::XlaComputation* computation) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
                      computation->Snapshot());

  RegisterDialects();
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  auto status = xla::ConvertHloToMlirHlo(
      module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
  if (!status.ok()) {
    LOG(ERROR) << "Hlo module import failed: " << status;
    return status;
  }

  PassManager pm(&context);
  pm.addPass(createCanonicalizerPass());
  pm.addPass(createInlinerPass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<FuncOp>(createCSEPass());

  mlir::StatusScopedDiagnosticHandler diag_handler(&context);
  LogicalResult result = pm.run(module.get());
  (void)result;

  module->dump();

  return tensorflow::Status::OK();
}

}  // namespace xla_hlo
}  // namespace mlir
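XlaQuantize above is the C++ entry point; a minimal, hedged sketch of a caller, assuming `config` and `computation` come from the usual tf2xla flow (the wrapper name is illustrative):

#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"

// Illustrative caller: forwards a tf2xla Config and an XlaComputation built
// elsewhere to the quantization entry point and returns its status.
tensorflow::Status QuantizeComputation(const tensorflow::tf2xla::Config& config,
                                       xla::XlaComputation* computation) {
  return mlir::xla_hlo::XlaQuantize(config, computation);
}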
@ -1,35 +0,0 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [
|
||||
":graph_config_files",
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
|
||||
},
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/aot:tfcompile",
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
# Bundle together all the graph files that are used by the tests.
|
||||
filegroup(
|
||||
name = "graph_config_files",
|
||||
srcs = glob(
|
||||
["**/*.pbtxt"],
|
||||
),
|
||||
)
|
@ -1,199 +0,0 @@
|
||||
// RUN: tf-opt -xla-hlo-cpu-fusion %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @mul_add_source
|
||||
func @mul_add_source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
|
||||
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %1 : tensor<4xf32>
|
||||
|
||||
// CHECK: %[[region:.*]] = "quant.region"(%arg0, %arg1, %arg2) ( {
|
||||
// CHECK: ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
|
||||
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
// CHECK: "quant.return"(%[[add]]) : (tensor<4xf32>) -> ()
|
||||
// CHECK: }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: return %[[region]] : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @mul_add_annotated
|
||||
func @mul_add_annotated(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) {
|
||||
%cst = constant dense<0.0> : tensor<f32>
|
||||
%cst_0 = constant dense<255.0> : tensor<f32>
|
||||
%cst_1 = constant dense<8> : tensor<i32>
|
||||
%cst_2 = constant dense<false> : tensor<i1>
|
||||
%qin = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
|
||||
has_side_effect = false, name = "custom-call.1"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
%qw = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
|
||||
has_side_effect = false, name = "custom-call.2"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
%0 = "xla_hlo.multiply"(%qin, %qw) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
%r = "xla_hlo.custom_call"(%1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
|
||||
has_side_effect = false, name = "custom-call.3"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
return %r : tensor<2x4xf32>
|
||||
|
||||
// CHECK: %[[region:.*]] = "quant.region"
|
||||
// CHECK: ^bb0(%arg3: tensor<2x4xf32>, %arg4: tensor<2x4xf32>, %arg5: tensor<2x4xf32>): // no predecessors
|
||||
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
|
||||
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
|
||||
// CHECK: "quant.return"(%[[add]]) : (tensor<2x4xf32>) -> ()
|
||||
// CHECK: }) {input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32],
|
||||
// CHECK-SAME: logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]} :
|
||||
// CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[r:.*]] = "xla_hlo.custom_call"(%[[region]]
|
||||
// CHECK: return %[[r]] : tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @reduce_window
|
||||
func @reduce_window(%arg0: tensor<1x28x28x32xf32>, %arg1: tensor<f32>) -> (tensor<1x14x14x32xf32>) {
|
||||
%0 = "xla_hlo.reduce_window"(%arg0, %arg1) ({
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
|
||||
%1 = xla_hlo.maximum %arg2, %arg3 : tensor<f32>
|
||||
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
base_dilations = dense<1> : tensor<4xi64>,
|
||||
padding = dense<0> : tensor<4x2xi64>,
|
||||
window_dilations = dense<1> : tensor<4xi64>,
|
||||
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
|
||||
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
|
||||
} : (tensor<1x28x28x32xf32>, tensor<f32>) -> tensor<1x14x14x32xf32>
|
||||
return %0 : tensor<1x14x14x32xf32>
|
||||
|
||||
// CHECK: "quant.region"(%arg0, %arg1) ( {
|
||||
// CHECK: ^bb0(%arg2: tensor<1x28x28x32xf32>, %arg3: tensor<f32>): // no predecessors
|
||||
// CHECK: %[[rw:.*]] = "xla_hlo.reduce_window"(%arg2, %arg3) ( {
|
||||
// CHECK: ^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>): // no predecessors
|
||||
// CHECK: %[[max:.*]] = xla_hlo.maximum %arg4, %arg5 : tensor<f32>
|
||||
// CHECK: "xla_hlo.return"(%[[max]]) : (tensor<f32>) -> ()
|
||||
// CHECK: })
|
||||
// CHECK: "quant.return"(%[[rw]])
|
||||
// CHECK: }) {input_specs = [f32, f32], logical_kernel = "generic.reduce_window", output_specs = [f32]}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @reshape
|
||||
func @reshape(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32>
|
||||
return %0 : tensor<1x3136xf32>
|
||||
|
||||
// CHECK: "quant.region"(%arg0)
|
||||
// CHECK: logical_kernel = "generic.reshape"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @broadcast
|
||||
func @broadcast(%arg0: tensor<64xf32>) -> (tensor<1x14x14x64xf32>) {
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<1x14x14x64xf32>
|
||||
return %0 : tensor<1x14x14x64xf32>
|
||||
|
||||
// CHECK: "quant.region"(%arg0)
|
||||
// CHECK: logical_kernel = "generic.broadcast"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @biased_dot
|
||||
func @biased_dot(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x10xf32>, %arg2: tensor<1x10xf32>) -> (tensor<1x10xf32>) {
|
||||
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1024xf32>, tensor<1024x10xf32>) -> tensor<1x10xf32>
|
||||
%1 = xla_hlo.add %0, %arg2 : tensor<1x10xf32>
|
||||
return %1 : tensor<1x10xf32>
|
||||
|
||||
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
|
||||
// CHECK: xla_hlo.dot
|
||||
// CHECK: xla_hlo.add
|
||||
// CHECK: logical_kernel = "generic.biased_dot"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @biased_conv
|
||||
func @biased_conv(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>) {
|
||||
%0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>,
|
||||
padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32>
|
||||
%1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32>
|
||||
return %1 : tensor<1x14x14x64xf32>
|
||||
|
||||
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
|
||||
// CHECK: xla_hlo.conv
|
||||
// CHECK: xla_hlo.add
|
||||
// CHECK: logical_kernel = "generic.biased_conv"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @biased_dot_relu
|
||||
func @biased_dot_relu(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x10xf32>, %arg2: tensor<1x10xf32>) -> (tensor<1x10xf32>) {
|
||||
%cst = constant dense<0.0> : tensor<1x10xf32>
|
||||
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1024xf32>, tensor<1024x10xf32>) -> tensor<1x10xf32>
|
||||
%1 = xla_hlo.add %0, %arg2 : tensor<1x10xf32>
|
||||
%2 = xla_hlo.maximum %1, %cst : tensor<1x10xf32>
|
||||
return %2 : tensor<1x10xf32>
|
||||
|
||||
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
|
||||
// CHECK: constant
|
||||
// CHECK: xla_hlo.dot
|
||||
// CHECK: xla_hlo.add
|
||||
// CHECK: xla_hlo.maximum
|
||||
// CHECK: logical_kernel = "generic.biased_dot_relu"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @biased_conv_relu
|
||||
func @biased_conv_relu(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>) {
|
||||
%cst = constant dense<0.0> : tensor<1x14x14x64xf32>
|
||||
%0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>,
|
||||
padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32>
|
||||
%1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32>
|
||||
%2 = xla_hlo.maximum %1, %cst : tensor<1x14x14x64xf32>
|
||||
return %2 : tensor<1x14x14x64xf32>
|
||||
|
||||
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
|
||||
// CHECK: constant
|
||||
// CHECK: xla_hlo.conv
|
||||
// CHECK: xla_hlo.add
|
||||
// CHECK: xla_hlo.maximum
|
||||
// CHECK: logical_kernel = "generic.biased_conv_relu"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @biased_conv_relu_shared
|
||||
func @biased_conv_relu_shared(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>) {
|
||||
%cst = constant dense<0.0> : tensor<1x14x14x64xf32>
|
||||
%0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>,
|
||||
padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32>
|
||||
%1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32>
|
||||
%2 = xla_hlo.maximum %1, %cst : tensor<1x14x14x64xf32>
|
||||
return %cst, %2 : tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>
|
||||
|
||||
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
|
||||
// CHECK: constant
|
||||
// CHECK: xla_hlo.conv
|
||||
// CHECK: xla_hlo.add
|
||||
// CHECK: %[[max:.*]] = xla_hlo.maximum
|
||||
// CHECK: "quant.return"(%[[max]])
|
||||
// CHECK: logical_kernel = "generic.biased_conv_relu"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @biased_conv_relu6
|
||||
func @biased_conv_relu6(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>) {
|
||||
%min = constant dense<0.0> : tensor<1x14x14x64xf32>
|
||||
%max = constant dense<6.0> : tensor<1x14x14x64xf32>
|
||||
%0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>,
|
||||
padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
|
||||
} : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32>
|
||||
%1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32>
|
||||
%2 = "xla_hlo.clamp"(%min, %1, %max) : (tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>) -> tensor<1x14x14x64xf32>
|
||||
return %2 : tensor<1x14x14x64xf32>
|
||||
|
||||
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
|
||||
// CHECK: constant
|
||||
// CHECK: constant
|
||||
// CHECK: xla_hlo.conv
|
||||
// CHECK: xla_hlo.add
|
||||
// CHECK: xla_hlo.clamp
|
||||
// CHECK: logical_kernel = "generic.biased_conv_relu6"
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --experimental_quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure
|
||||
|
||||
# TODO(fengliuai): update this file with the progress of the implementation
|
||||
// CHECK: func @main
|
||||
// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
|
||||
// CHECK: %cst_1 = constant dense<8> : tensor<i32>
|
||||
// CHECK: %cst_2 = constant dense<false> : tensor<i1>
|
||||
// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
|
||||
// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
|
||||
// CHECK: return %4 : tuple<tensor<2x4xf32>>
|
||||
// CHECK: }
|
@ -1,26 +0,0 @@
|
||||
feed {
|
||||
id { node_name: "input0" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "input1" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
|
||||
fetch {
|
||||
id { node_name: "Add/FakeQuantWithMinMaxVars" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
|
||||
conversion_options {
|
||||
custom_fake_quant_op_calls: true
|
||||
}
|
@ -1,218 +0,0 @@
|
||||
node: {
|
||||
name: "Add/FakeQuantWithMinMaxVars"
|
||||
op: "FakeQuantWithMinMaxVars"
|
||||
input: "Add"
|
||||
input: "Add/FakeQuantWithMinMaxVars/min"
|
||||
input: "Add/FakeQuantWithMinMaxVars/max"
|
||||
attr: {
|
||||
key: "num_bits"
|
||||
value: {
|
||||
i: 8
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "narrow_range"
|
||||
value: {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "Add/FakeQuantWithMinMaxVars/min"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "Add/FakeQuantWithMinMaxVars/max"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 127.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "input0/FakeQuantWithMinMaxVars"
|
||||
input: "input1/FakeQuantWithMinMaxVars"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input0/FakeQuantWithMinMaxVars"
|
||||
op: "FakeQuantWithMinMaxVars"
|
||||
input: "input0"
|
||||
input: "input0/FakeQuantWithMinMaxVars/min"
|
||||
input: "input0/FakeQuantWithMinMaxVars/max"
|
||||
attr: {
|
||||
key: "num_bits"
|
||||
value: {
|
||||
i: 8
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "narrow_range"
|
||||
value: {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input0/FakeQuantWithMinMaxVars/min"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input0/FakeQuantWithMinMaxVars/max"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 127.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1/FakeQuantWithMinMaxVars"
|
||||
op: "FakeQuantWithMinMaxVars"
|
||||
input: "input1"
|
||||
input: "input1/FakeQuantWithMinMaxVars/min"
|
||||
input: "input1/FakeQuantWithMinMaxVars/max"
|
||||
attr: {
|
||||
key: "num_bits"
|
||||
value: {
|
||||
i: 8
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "narrow_range"
|
||||
value: {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1/FakeQuantWithMinMaxVars/min"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1/FakeQuantWithMinMaxVars/max"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 127.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
@ -1,54 +0,0 @@
|
||||
// RUN: tf-opt -xla-hlo-materialize-quant %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @quantize_rewrite
|
||||
func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
||||
// CHECK: %[[qcst:.*]] = constant dense<{{\[\[}}21004416], [-1056997248]]> : tensor<2x1xi32>
|
||||
// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32,
|
||||
// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16>
|
||||
// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32>
|
||||
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[cast]] : tensor<2x4xf32>
|
||||
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
|
||||
|
||||
%w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32>
|
||||
%q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
|
||||
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
|
||||
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
|
||||
return %mul: tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @quantize_small
|
||||
func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
|
||||
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32>
|
||||
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<1x4xf32>
|
||||
// CHECK-NEXT: return %[[mul]] : tensor<1x4xf32>
|
||||
|
||||
%w = constant dense<1.0> : tensor<1x4xf32>
|
||||
%q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
|
||||
%dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<1x4xf32>
|
||||
%mul = xla_hlo.multiply %arg0, %dq : tensor<1x4xf32>
|
||||
return %mul: tensor<1x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @quantize_non_cst
|
||||
func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
||||
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %arg0 : tensor<2x4xf32>
|
||||
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
|
||||
|
||||
%q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
|
||||
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
|
||||
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
|
||||
return %mul: tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @quantize_non_4x
|
||||
func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
|
||||
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32>
|
||||
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<2x5xf32>
|
||||
// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32>
|
||||
|
||||
%w = constant dense<1.0> : tensor<2x5xf32>
|
||||
%q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
|
||||
%dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x5xf32>
|
||||
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x5xf32>
|
||||
return %mul: tensor<2x5xf32>
|
||||
}
|
@ -1,69 +0,0 @@
|
||||
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mul_add_source_no_params
|
||||
func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
|
||||
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
|
||||
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
|
||||
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
"quant.return"(%add) : (tensor<4xf32>) -> ()
|
||||
}) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} :
|
||||
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %region : tensor<4xf32>
|
||||
|
||||
// CHECK: input_specs = [f32, f32, f32]
|
||||
// CHECK-SAME: output_specs = [f32]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mul_add_annotated_no_narrow_range
|
||||
func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
|
||||
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
|
||||
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
|
||||
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
"quant.return"(%add) : (tensor<4xf32>) -> ()
|
||||
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8:f32, 1.0:-128>, f32],
|
||||
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
|
||||
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %region : tensor<4xf32>
|
||||
|
||||
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32]
|
||||
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @mul_add_annotated
|
||||
func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
|
||||
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
|
||||
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
|
||||
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
|
||||
"quant.return"(%add) : (tensor<4xf32>) -> ()
|
||||
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8<-127:127>:f32, 1.0:-128>, f32],
|
||||
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
|
||||
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %region : tensor<4xf32>
|
||||
|
||||
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
|
||||
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @same_scale_1_1
|
||||
func @same_scale_1_1(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) {
|
||||
%region = "quant.region"(%arg0) ( {
|
||||
^bb0(%arg1: tensor<1x7x7x64xf32>): // no predecessors
|
||||
%r = "xla_hlo.reshape"(%arg1) : (tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>)
|
||||
"quant.return"(%r) : (tensor<1x3136xf32>) -> ()
|
||||
}) {input_specs = [!quant.uniform<i8:f32, 1.0>], logical_kernel = "generic.reshape", output_specs = [f32]} : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32>
|
||||
return %region : tensor<1x3136xf32>
|
||||
|
||||
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
|
||||
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
|
||||
}
|
@ -1,25 +0,0 @@
|
||||
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @mul
|
||||
func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
|
||||
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[dq]] : tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32>
|
||||
%w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32>
|
||||
%mul = xla_hlo.multiply %arg0, %w : tensor<2x2xf32>
|
||||
return %mul: tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @add
|
||||
func @add(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[b:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
|
||||
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[b]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>
|
||||
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>) -> tensor<2xf32>
|
||||
// CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg0, %[[dq]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %[[add]] : tensor<2x2xf32>
|
||||
%b = constant dense<1.0> : tensor<2xf32>
|
||||
%add = "xla_hlo.add"(%arg0, %b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
|
||||
return %add: tensor<2x2xf32>
|
||||
}
|
@ -39,7 +39,7 @@ versions {
|
||||
# CHECK-LABEL: func @main
|
||||
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-SAME: control_outputs = ""
|
||||
# CHECK-SAME inputs = "input0,input1"
|
||||
# CHECK-SAME: inputs = "input0,input1"
|
||||
# CHECK-SAME: outputs = "output"
|
||||
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
|
||||
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>
|
||||
|
@ -12,6 +12,7 @@ glob_lit_tests(
|
||||
test_file_exts = [
|
||||
"mlir",
|
||||
"cc",
|
||||
"json",
|
||||
],
|
||||
)
|
||||
|
||||
@ -24,6 +25,8 @@ filegroup(
|
||||
":importer_test_min_max",
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
|
||||
"//tensorflow/compiler/mlir/lite:json_to_flatbuffer",
|
||||
"//tensorflow/lite/schema:schema.fbs",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,83 @@
|
||||
// RUN: json_to_flatbuffer %p/../../../../../lite/schema/schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
// CHECK: %cst = constant unit
|
||||
// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>
|
||||
// CHECK: return %[[RES0]] : tensor<256x32x32x16xf32>
|
||||
|
||||
{
|
||||
version: 3,
|
||||
operator_codes: [
|
||||
{
|
||||
builtin_code: "CONV_2D",
|
||||
}
|
||||
],
|
||||
subgraphs: [
|
||||
{
|
||||
tensors: [
|
||||
{
|
||||
shape: [
|
||||
256,
|
||||
32,
|
||||
32,
|
||||
3
|
||||
],
|
||||
name: "arg0",
|
||||
quantization: {
|
||||
}
|
||||
},
|
||||
{
|
||||
shape: [
|
||||
16,
|
||||
3,
|
||||
3,
|
||||
3
|
||||
],
|
||||
name: "arg1",
|
||||
quantization: {
|
||||
}
|
||||
},
|
||||
{
|
||||
shape: [
|
||||
0
|
||||
],
|
||||
name: "cst"
|
||||
},
|
||||
{
|
||||
shape: [
|
||||
256,
|
||||
32,
|
||||
32,
|
||||
16
|
||||
],
|
||||
name: "output",
|
||||
quantization: {
|
||||
}
|
||||
},
|
||||
],
|
||||
inputs: [
|
||||
0,
|
||||
1
|
||||
],
|
||||
outputs: [
|
||||
3
|
||||
],
|
||||
operators: [
|
||||
{
|
||||
inputs: [
|
||||
0,
|
||||
1,
|
||||
-1
|
||||
],
|
||||
outputs: [
|
||||
3
|
||||
],
|
||||
builtin_options_type: "Conv2DOptions",
|
||||
builtin_options: {
|
||||
}
|
||||
}
|
||||
],
|
||||
name: "main"
|
||||
}
|
||||
],
|
||||
description: "MLIR Converted."
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail
|
||||
// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s --dump-input=fail
|
||||
|
||||
// Inline a function that contains only tfl ops.
|
||||
func @func_with_tfl_ops(%arg0 : tensor<2xi32>) -> tensor<2xi32> {
|
||||
|
@ -1,5 +1,5 @@
|
||||
// RUN: tf-opt --tfl-legalize-tf-while %s -o - | FileCheck %s --dump-input-on-failure
|
||||
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline --mlir-disable-inline-simplify | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
|
||||
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline="disable-simplify" | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
|
||||
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline | FileCheck %s --dump-input-on-failure --check-prefix=CANON
|
||||
|
||||
func @while_main(%arg0: tensor<?x256x256xf32>) -> (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>) attributes {tf.entry_function = {inputs = "input", outputs = "Identity,Identity_1,Identity_2"}} {
|
||||
|
@ -9,6 +9,20 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddHighDimsHaveSameShape
|
||||
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
|
||||
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddTooHighBroadcastableDims
|
||||
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %2: tensor<1xf32>
|
||||
@ -1448,7 +1462,7 @@ func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK return [[MUL]] : tensor<3x3xf32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
|
||||
@ -1459,5 +1473,5 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
|
||||
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK return [[MUL]] : tensor<3x3xi32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xi32>
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
|
||||
namespace mlir {
|
||||
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFExecutorToControlDialectConversion();
|
||||
} // namespace mlir
|
||||
|
||||
@ -134,6 +134,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(
|
||||
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
|
||||
|
||||
if (pass_config.shape_inference) {
|
||||
// Add a shape inference pass to optimize away the unnecessary casts.
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
}
|
||||
// Legalize while early to allow further constant folding.
|
||||
// TODO(jpienaar): This may not actually matter as we do canonicalization
|
||||
// after the legalize below, for now it needs to be below the above passes
|
||||
@ -160,11 +164,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// constant ops.
|
||||
pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
|
||||
|
||||
if (pass_config.shape_inference) {
|
||||
// Add a shape inference pass to optimize away the unnecessary casts.
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
}
|
||||
|
||||
// The below passes only make sense if Builtin TFLite ops are enabled
|
||||
// for emission.
|
||||
if (pass_config.emit_builtin_tflite_ops) {
|
||||
@ -173,7 +172,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
|
||||
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
||||
// This pass operates on TensorFlow ops but is triggered after legalization
|
||||
// so that it can target constants introduced once TensorFlow Identity ops
|
||||
@ -255,7 +255,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
||||
// TFLite dialect passes.
|
||||
pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::TFL::CreateLegalizeTFPass());
|
||||
pm.addPass(
|
||||
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
|
||||
pm.addPass(mlir::TFL::CreateOptimizePass());
|
||||
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
|
||||
|
||||
@ -268,7 +269,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
||||
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
|
||||
}
|
||||
|
||||
// Registers a pass pipeline for the standard TFL passes.
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
@ -214,7 +216,7 @@ int main(int argc, char **argv) {
|
||||
if (pass_config.legalize_tf_while) {
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
}
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
|
||||
|
||||
std::string result;
|
||||
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
|
@ -40,7 +40,7 @@ limitations under the License.
|
||||
|
||||
namespace mlir {
|
||||
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFExecutorToControlDialectConversion();
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -44,7 +44,8 @@ namespace TFL {
|
||||
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
|
||||
|
||||
namespace {
|
||||
class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
|
||||
class DefaultQuantParamsPass
|
||||
: public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
|
||||
public:
|
||||
explicit DefaultQuantParamsPass(double default_min, double default_max)
|
||||
: default_min_(default_min), default_max_(default_max) {}
|
||||
@ -220,7 +221,7 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
|
||||
}
|
||||
|
||||
// Creates an instance of the default quant parameters pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
|
||||
double default_min, double default_max) {
|
||||
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ namespace TFL {
|
||||
|
||||
namespace {
|
||||
|
||||
struct DenseToSparse : public FunctionPass<DenseToSparse> {
|
||||
struct DenseToSparse : public PassWrapper<DenseToSparse, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
@ -63,7 +63,7 @@ void DenseToSparse::runOnFunction() {
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect DenseToSparse pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass() {
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDenseToSparsePass() {
|
||||
return absl::make_unique<DenseToSparse>();
|
||||
}
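The mechanical change running through these transforms, replacing CRTP bases such as FunctionPass<T> and OperationPass<T, ModuleOp> with PassWrapper and returning OperationPass<...> from the factories, follows one skeleton. A minimal sketch with illustrative names (not a pass from the tree):

#include <memory>

#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project

namespace mlir {
namespace TFL {
namespace {

// Illustrative function pass using the PassWrapper base adopted in this change.
struct ExamplePass : public PassWrapper<ExamplePass, FunctionPass> {
  void runOnFunction() override {
    // Inspect or rewrite getFunction() here.
  }
};

}  // namespace

// Factory returning the interface type now used by the pass declarations.
std::unique_ptr<OperationPass<FuncOp>> CreateExamplePass() {
  return std::make_unique<ExamplePass>();
}

// Registers the pass under an illustrative flag for tools such as tf-opt.
static PassRegistration<ExamplePass> example_pass(
    "tfl-example-pass", "Illustrative PassWrapper-based pass");

}  // namespace TFL
}  // namespace mlir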
|
||||
|
||||
|
@ -18,7 +18,8 @@ namespace mlir {
|
||||
namespace TFL {
|
||||
namespace {
|
||||
|
||||
struct IdentifyDilatedConvPass : public FunctionPass<IdentifyDilatedConvPass> {
|
||||
struct IdentifyDilatedConvPass
|
||||
: public PassWrapper<IdentifyDilatedConvPass, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
|
@ -679,7 +679,8 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
|
||||
return success();
|
||||
}
|
||||
|
||||
struct ExtractOphintPass : public OperationPass<ExtractOphintPass, ModuleOp> {
|
||||
struct ExtractOphintPass
|
||||
: public PassWrapper<ExtractOphintPass, OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
void Verify();
|
||||
|
||||
@ -752,7 +753,7 @@ void ExtractOphintPass::Verify() {
|
||||
|
||||
/// Creates an instance of the TensorFlow Lite dialect ExtractOphintPass
|
||||
/// pass.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass() {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateExtractOphintPass() {
|
||||
return std::make_unique<ExtractOphintPass>();
|
||||
}
|
||||
|
||||
|
@ -69,7 +69,7 @@ constexpr char kUnidirectionalSequenceLstm[] = "UnidirectionalSequenceLstm";
|
||||
// |
|
||||
// OutputOp1
|
||||
struct LegalizeOphintFuncOpPass
|
||||
: public OperationPass<LegalizeOphintFuncOpPass, ModuleOp> {
|
||||
: public PassWrapper<LegalizeOphintFuncOpPass, OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
@ -284,7 +284,7 @@ void LegalizeOphintFuncOpPass::runOnOperation() {
|
||||
|
||||
/// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
|
||||
/// pass.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass() {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeOphintFuncOpPass() {
|
||||
return std::make_unique<LegalizeOphintFuncOpPass>();
|
||||
}
|
||||
|
||||
|
@ -70,8 +70,21 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
|
||||
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
||||
|
||||
// Legalize operations in functions.
|
||||
struct LegalizeTF : public FunctionPass<LegalizeTF> {
|
||||
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
public:
|
||||
LegalizeTF() = default;
|
||||
LegalizeTF(const LegalizeTF&) {}
|
||||
explicit LegalizeTF(bool run_tfl_runtime_verification) {
|
||||
run_tfl_runtime_verification_ = run_tfl_runtime_verification;
|
||||
}
|
||||
|
||||
/// Performs the lowering to TFLite dialect.
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
Option<bool> run_tfl_runtime_verification_{
|
||||
*this, "run-tfl-runtime-verification",
|
||||
llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
|
||||
};
|
||||
|
||||
// Returns true if all tensor values in `values` have static and identical shapes.
|
||||
@ -314,7 +327,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
|
||||
// can't do any padding. Instead we just return it.
|
||||
return attribute;
|
||||
}
|
||||
for (auto idx : dense_elem_attr.getIntValues()) {
|
||||
for (const auto& idx : dense_elem_attr.getIntValues()) {
|
||||
padded_val.push_back(idx.getSExtValue());
|
||||
}
|
||||
auto attr_dim_count = ranked_attr_type.getShape()[0];
|
||||
@ -440,7 +453,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
|
||||
if (!matchPattern(tf_matrix_diag_v2_or_v3_op.padding_value(),
|
||||
m_Constant(&padding_value)))
|
||||
return false;
|
||||
for (auto value : padding_value.getValues<APInt>()) {
|
||||
for (const auto& value : padding_value.getValues<APInt>()) {
|
||||
if (value != 0) return false;
|
||||
}
|
||||
|
||||
@ -741,13 +754,19 @@ void LegalizeTF::runOnFunction() {
|
||||
// graph.
|
||||
target.addLegalOp<mlir::ConstantOp>();
|
||||
target.addLegalOp<ConstOp>();
|
||||
target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
|
||||
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
|
||||
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
|
||||
if (!tfl_op) return false;
|
||||
return succeeded(tfl_op.VerifyTflRuntimeTypes(
|
||||
tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
|
||||
}));
|
||||
if (run_tfl_runtime_verification_) {
|
||||
target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
|
||||
Optional<ConversionTarget::DynamicLegalityCallbackFn>(
|
||||
[](Operation* op) {
|
||||
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
|
||||
if (!tfl_op) return false;
|
||||
return succeeded(tfl_op.VerifyTflRuntimeConstraints(
|
||||
tfl_op.getOperation(),
|
||||
/*failure_on_operand_type_mismatch=*/false));
|
||||
}));
|
||||
} else {
|
||||
target.addLegalDialect<TensorFlowLiteDialect>();
|
||||
}
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what applying greedy patterns does.
// Look if there is a function that tries until it converges.
|
||||
@ -763,8 +782,9 @@ void LegalizeTF::runOnFunction() {
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFPass() {
|
||||
return std::make_unique<LegalizeTF>();
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
|
||||
bool run_tfl_runtime_verification) {
|
||||
return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeTF> pass(
|
||||
|
@ -31,7 +31,8 @@ namespace {
|
||||
|
||||
// Legalize TF While to TFL While with calls to the original functions from the
|
||||
// cond and body regions.
|
||||
struct LegalizeWhile : public OperationPass<LegalizeWhile, ModuleOp> {
|
||||
struct LegalizeWhile
|
||||
: public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
|
||||
void RunOnFunction(FuncOp func);
|
||||
|
||||
void runOnOperation() override {
|
||||
@ -76,7 +77,7 @@ void LegalizeWhile::RunOnFunction(FuncOp func) {
|
||||
}
|
||||
|
||||
// Creates an instance of the TensorFlow While to TFLite While pass.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass() {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass() {
|
||||
return std::make_unique<LegalizeWhile>();
|
||||
}
|
||||
|
||||
|
@ -42,7 +42,8 @@ namespace {
|
||||
// AnyQuantizedType, thus bitwidth, narrow_range, etc are included. The op also
|
||||
// defines the op quantization traits, which are used to propagate the
|
||||
// quantization parameters by the following passes.
|
||||
struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
|
||||
struct LoadQuantizationRecipe
|
||||
: public PassWrapper<LoadQuantizationRecipe, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
@ -215,7 +216,7 @@ void LoadQuantizationRecipe::runOnFunction() {
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
|
||||
// pass.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLoadQuantizationRecipePass() {
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLoadQuantizationRecipePass() {
|
||||
return absl::make_unique<LoadQuantizationRecipe>();
|
||||
}

@ -82,7 +82,7 @@ class TensorListPatternRewriter : public PatternRewriter {

/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
: public OperationPass<LowerStaticTensorListPass, ModuleOp> {
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
void runOnOperation() override;

// Apply type and op changes within a function.
@ -720,7 +720,7 @@ struct ConvertTensorListStack
RankedTensorType::get({-1}, rewriter.getIntegerType(32));
auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
SmallVector<int64_t, 8> output_shape = {op.num_elements().getSExtValue()};
for (auto dim : dense_elem_attr.getIntValues())
for (const auto &dim : dense_elem_attr.getIntValues())
output_shape.push_back(dim.getSExtValue());
RankedTensorType result_type =
RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
@ -906,7 +906,7 @@ void LowerStaticTensorListPass::runOnOperation() {

/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass.
std::unique_ptr<OpPassBase<ModuleOp>> TFL::CreateLowerStaticTensorListPass() {
std::unique_ptr<OperationPass<ModuleOp>>
TFL::CreateLowerStaticTensorListPass() {
return std::make_unique<LowerStaticTensorListPass>();
}

@ -74,7 +74,7 @@ bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
using ::llvm::cast;

// Optimize TFLite operations in functions.
struct Optimize : public FunctionPass<Optimize> {
struct Optimize : public PassWrapper<Optimize, FunctionPass> {
void runOnFunction() override;
};

@ -650,7 +650,7 @@ struct ConvertTrivialTransposeOpToReshapeOp

auto input_shape = input_type.getShape();
SmallVector<int64_t, 8> perm_values;
for (auto dim : perm_values_attr.getIntValues())
for (const auto &dim : perm_values_attr.getIntValues())
perm_values.push_back(dim.getSExtValue());

// This should never happen unless the input graph is malformed.
@ -725,7 +725,7 @@ void Optimize::runOnFunction() {
} // namespace

// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateOptimizePass() {
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass() {
return std::make_unique<Optimize>();
}

@ -36,7 +36,7 @@ using FuncSet = llvm::SmallSet<FuncOp, 4>;

// Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass
: public OperationPass<OptimizeFunctionalOpsPass, ModuleOp> {
: public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
};

@ -198,7 +198,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() {
}
} // namespace

std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
return std::make_unique<OptimizeFunctionalOpsPass>();
}

@ -24,75 +24,79 @@ namespace mlir {
class FuncOp;
class ModuleOp;
template <typename T>
class OpPassBase;
class OperationPass;

namespace TFL {
class QuantizationSpecs;

// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFPass();
// When the given run_tfl_runtime_verification value is true, it will check each
// TFL builtin op towards the TFL runtime capability and the incompatible TF ops
// will be left in the graph without getting legalized.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
bool run_tfl_runtime_verification);

// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateOptimizePass();
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();

// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareTFPass(
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
bool unfold_batch_matmul);

// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLowerStaticTensorListPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass();

// Creates an instance of the TensorFlow Lite dialect Quantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass();
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass();

// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
const QuantizationSpecs& quant_specs);

// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
bool emit_quant_adaptor_ops);

// Creates an instance of the TensorFlow Lite dialect TrimFunctions
// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
llvm::ArrayRef<std::string> trim_funcs_whitelist);

// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass();
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass();

// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateExtractOphintPass();

// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
// pass. The composite op is created from the ophint extraction pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeOphintFuncOpPass();

// Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass.
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass();
std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass();

// Creates an instance of the TensorFlow Lite dialect OptimizeFunctionalOpsPass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();

// Creates an instance of the TensorFlow Lite dialect pass to add default
// quantization parameters.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max);

// Creates an instance of the TensorFlow Lite dialect pass to convert dense
// tensor to sparse format.
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass();
std::unique_ptr<OperationPass<FuncOp>> CreateDenseToSparsePass();

// Creates function pass to legalize TF While to TFL While.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass();
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass();

// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();

// Verifies runtime supports types used.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass();
// Verifies runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();

} // namespace TFL
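
Call sites that build pipelines from these factories only see the changed return types; the calls themselves keep their shape. A hedged sketch of wiring a few of the declarations above into an mlir::PassManager follows; the include path, the nesting calls, and the particular pass ordering are assumptions for illustration, not the converter's actual pipeline:

    #include "mlir/IR/MLIRContext.h"
    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/lite/transforms/passes.h"  // assumed path

    void BuildExamplePipeline(mlir::MLIRContext* context) {
      mlir::PassManager pm(context);

      // Module-level passes operate on the whole module.
      pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());

      // Function-level passes are nested so they run on every FuncOp.
      pm.addNestedPass<mlir::FuncOp>(
          mlir::TFL::CreatePrepareTFPass(/*unfold_batch_matmul=*/true));
      pm.addNestedPass<mlir::FuncOp>(
          mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
      pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOptimizePass());
      pm.addNestedPass<mlir::FuncOp>(mlir::TFL::CreateRuntimeVerifyPass());

      // pm.run(module) would then execute the assembled pipeline.
    }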

@ -30,7 +30,7 @@ namespace TFL {
namespace {

// Applies all the clean up steps after quantization.
class PostQuantizePass : public FunctionPass<PostQuantizePass> {
class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
public:
// Constructor used by the PassRegistration. This will remove the adaptor ops.
explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}
@ -135,7 +135,7 @@ void PostQuantizePass::runOnFunction() {
} // namespace

// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
bool emit_quant_adaptor_ops) {
return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops);
}

@ -94,7 +94,8 @@ class ConvertEmbeddedLookupFunc {
// body with the corresponding fused TFLite op. The replacement need not always
// be a fused op, though that is the primary use case.
class PrepareCompositeFunctionsPass
: public OperationPass<PrepareCompositeFunctionsPass, ModuleOp> {
: public PassWrapper<PrepareCompositeFunctionsPass,
OperationPass<ModuleOp>> {
public:
explicit PrepareCompositeFunctionsPass() {}

@ -211,7 +212,7 @@ void PrepareCompositeFunctionsPass::runOnOperation() {
}
} // namespace

std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
return std::make_unique<PrepareCompositeFunctionsPass>();
}

@ -66,7 +66,8 @@ namespace {
// across ops. This step is necessary for post-training quantization and also
// making the quantization rule for some operations in the quantization-aware
// training quantization simpler.
class PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
class PrepareQuantizePass
: public PassWrapper<PrepareQuantizePass, FunctionPass> {
public:
// Constructor used by the PassRegistration and enforce uint8 quantization.
explicit PrepareQuantizePass() {
@ -281,7 +282,7 @@ void PrepareQuantizePass::runOnFunction() {
} // namespace

// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
const QuantizationSpecs& quant_specs) {
return std::make_unique<PrepareQuantizePass>(quant_specs);
}

@ -71,7 +71,7 @@ namespace TFL {
namespace {

// Prepare TF operations in functions for subsequent legalization.
class PrepareTFPass : public FunctionPass<PrepareTFPass> {
class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
public:
explicit PrepareTFPass() : unfold_batch_matmul_(true) {}
explicit PrepareTFPass(bool unfold_batch_matmul)
@ -652,7 +652,7 @@ void PrepareTFPass::runOnFunction() {
} // namespace

// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareTFPass(
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
bool unfold_batch_matmul) {
return std::make_unique<PrepareTFPass>(unfold_batch_matmul);
}

@ -75,7 +75,7 @@ struct TFLFullQuantization
};

// Applies quantization on the model in TFL dialect.
struct QuantizePass : public FunctionPass<QuantizePass> {
struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
void runOnFunction() override;
};

@ -93,7 +93,7 @@ void QuantizePass::runOnFunction() {
} // namespace

// Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass() {
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass() {
return std::make_unique<QuantizePass>();
}

@ -22,33 +22,32 @@ namespace mlir {
namespace TFL {
namespace {

// This pass verifies that the operands and results types are supported by
// TFLite runtime.
class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
// This pass verifies that the TFL ops meet the TFL runtime constraints.
class RuntimeVerifyPass
: public mlir::PassWrapper<RuntimeVerifyPass, FunctionPass> {
public:
explicit RuntimeTypeVerifyPass() {}
explicit RuntimeVerifyPass() {}

private:
void runOnFunction() override;
};

void RuntimeTypeVerifyPass::runOnFunction() {
void RuntimeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
if (failed(op.VerifyTflRuntimeTypes(
op.getOperation(),
/*failure_on_operand_type_mismatch=*/true)))
if (failed(op.VerifyTflRuntimeConstraints(
op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
signalPassFailure();
});
}
} // namespace

// Verifies runtime supports types used.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass() {
return std::make_unique<RuntimeTypeVerifyPass>();
// Verifies TFL runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass() {
return std::make_unique<RuntimeVerifyPass>();
}

static PassRegistration<RuntimeTypeVerifyPass> pass(
"tfl-runtime-verify", "TFLite runtime verification");
static PassRegistration<RuntimeVerifyPass> pass("tfl-runtime-verify",
"TFLite runtime verification");

} // namespace TFL
} // namespace mlir

@ -66,7 +66,8 @@ namespace mlir {
namespace TFL {
namespace {

struct SplitMergedOperandsPass : public FunctionPass<SplitMergedOperandsPass> {
struct SplitMergedOperandsPass
: public PassWrapper<SplitMergedOperandsPass, FunctionPass> {
void runOnFunction() override;
};

@ -119,7 +120,7 @@ void SplitMergedOperandsPass::runOnFunction() {

/// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands
/// pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass() {
return std::make_unique<SplitMergedOperandsPass>();
}

@ -45,7 +45,7 @@ namespace {
// The pass to trim functions before we legalize to TFL
// dialect using the specified whitelist.
class TrimFunctionsPass
: public mlir::OperationPass<TrimFunctionsPass, ModuleOp> {
: public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
public:
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
@ -120,7 +120,7 @@ void TrimFunctionsPass::Verify() {

// Creates an instance of the TensorFlow Lite dialect TrimFunctions
/// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
llvm::ArrayRef<std::string> trim_funcs_whitelist) {
return std::make_unique<TrimFunctionsPass>(trim_funcs_whitelist);
}

@ -38,7 +38,7 @@ namespace {
// This pass outlines the cond/body region of the TFL WhileOp into functions and
// replaces the regions with calls to these outlined functions.
class WhileOutlinePass
: public mlir::OperationPass<WhileOutlinePass, ModuleOp> {
: public mlir::PassWrapper<WhileOutlinePass, OperationPass<ModuleOp>> {
public:
explicit WhileOutlinePass() {}

@ -241,7 +241,7 @@ void WhileOutlinePass::runOnOperation() {
}

// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
return std::make_unique<WhileOutlinePass>();
}

@ -71,7 +71,7 @@ tool_dirs = config.mlir_tf_tools_dirs + [
tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
'mlir-tflite-runner', 'tfcompile'
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

@ -1078,6 +1078,7 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
@ -1118,6 +1119,7 @@ tf_cc_test(
srcs = ["utils/compile_mlir_util_test.cc"],
deps = [
":compile_mlir_util",
"//tensorflow/compiler/jit",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:test",
@ -1329,6 +1331,7 @@ cc_library(
deps = [
":tensorflow",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",

@ -108,8 +108,6 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
>();

addOperations<ParallelExecuteOp>();

addInterfaces<TFInlinerInterface>();
}

@ -161,22 +159,8 @@ LogicalResult Verify(ParallelExecuteOp op) {
int output_index = 0;
for (auto& region_and_index : llvm::enumerate(regions)) {
auto& region = region_and_index.value();
auto region_index = region_and_index.index();

// Each region must include a single block of ops and must not be empty.
if (region.empty()) {
return op.emitOpError()
<< "regions must not be empty. "
<< "Found an empty region (" << region_index << ").";
}

if (!has_single_element(region)) {
return op.emitOpError()
<< "regions must be composed of a single block of operations."
<< "Expected region (" << region_index << ") with 1 block.";
}

auto* region_terminator = region.front().getTerminator();

// Check that output types of regions match return operand types.
for (auto result_type : region_terminator->getOperandTypes()) {
if (result_type !=
@ -214,8 +198,6 @@ void ParallelExecuteOp::build(Builder* builder, OperationState& state,
state.addTypes(output_types);
}

LogicalResult ParallelExecuteOp::verify() { return Verify(*this); }

Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) {
return getOperation()->getRegion(index).front();
}

@ -43,47 +43,6 @@ class TensorFlowDeviceDialect : public Dialect {
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h.inc"

// TODO(b/148642767): Use tablegen to define tf_device.parallel_execute op once
// variadic regions can be expressed in tablegen.
//
// ParallelExecute op concurrently executes variadic number of regions. Regions
// must represent separate sets of instructions to execute concurrently. In
// order to represent concurrently executed regions with dependencies, multiple
// ParallelExecute ops can be used instead. As so, regions within
// ParallelExecute op must not have control/data dependencies. While explicit
// dependencies between regions are disallowed, ParallelExecute op does not
// prevent implicit communication between regions (e.g. communication via
// send/recvs). In this case, users of ParallelExecute op must provide correct
// control dependencies between regions to guarantee correctness. Regions in
// ParallelExecute may include Resource ops. In the case where different regions
// include ops access the same resource, the users of the ParallelExecute op
// must provide mechanism (via send/recvs or via control dependencies) to
// guarantee correct ordering. Sequential ordering of ops within a region is
// guaranteed. Also, sequential ordering of ops before/after ParallelExecute ops
// are guaranteed. That is, execution of regions inside ParallelExecute op is
// blocked until all inputs to all regions are materialized and ops following
// ParallelExecute op are blocked until all regions are executed.
class ParallelExecuteOp
: public Op<ParallelExecuteOp,
OpTrait::SingleBlockImplicitTerminator<ReturnOp>::Impl> {
public:
using Op::Op;

static void build(Builder* builder, OperationState& state, int num_regions,
llvm::ArrayRef<Type> output_types);

static StringRef getOperationName() { return "tf_device.parallel_execute"; }

LogicalResult verify();
Block& GetRegionBlockWithIndex(unsigned index);
Operation::result_range GetRegionOutputs(unsigned region_index);

// Checks if a tf_device.parallel_execute index'th region block wraps a single
// operation and the single operation results are perfectly forwarded to the
// region block's return.
bool RegionWrapsSingleOp(unsigned index);
};

} // namespace tf_device
} // namespace mlir

@ -125,6 +125,55 @@ def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> {
}];
}

def TfDevice_ParallelExecuteOp : TfDevice_Op<"parallel_execute",
[SingleBlockImplicitTerminator<"ReturnOp">]> {
let description = [{
ParallelExecute op concurrently executes variadic number of regions. Regions
must represent separate sets of instructions to execute concurrently. In
order to represent concurrently executed regions with dependencies, multiple
ParallelExecute ops can be used instead. As so, regions within
ParallelExecute op must not have control/data dependencies.

While explicit dependencies between regions are disallowed, ParallelExecute
op does not prevent implicit communication between regions (e.g.
communication via send/recvs). In this case, users of ParallelExecute op
must provide correct control dependencies between regions to guarantee
correctness. Regions in ParallelExecute may include Resource ops.

In the case where different regions include ops access the same resource,
the users of the ParallelExecute op must provide mechanism (via send/recvs
or via control dependencies) to guarantee correct ordering. Sequential
ordering of ops within a region is guaranteed. Also, sequential ordering of
ops before/after ParallelExecute ops are guaranteed. That is, execution of
regions inside ParallelExecute op is blocked until all inputs to all regions
are materialized and ops following ParallelExecute op are blocked until all
regions are executed.
}];

let results = (outs
Variadic<AnyType>:$execute_outputs
);

let regions = (region VariadicRegion<SizedRegion<1>>:$regions);

let extraClassDeclaration = [{
Block& GetRegionBlockWithIndex(unsigned index);
Operation::result_range GetRegionOutputs(unsigned region_index);

// Checks if a tf_device.parallel_execute index'th region block wraps a
// single operation and the single operation results are perfectly forwarded
// to the region block's return.
bool RegionWrapsSingleOp(unsigned index);
}];

let builders = [
OpBuilder<"Builder* builder, OperationState& state, int num_regions,"
"llvm::ArrayRef<Type> output_types">,
];

let verifier = [{ return Verify(*this); }];
}
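
The textual form implied by the description and the SizedRegion<1> constraint above is the one the tf_device tests later in this diff exercise. A small illustrative sketch follows; the op names inside the regions are placeholders and the operand/result types are made up:

    %0:2 = "tf_device.parallel_execute"() ( {
      // Region #0: a single block with no control/data dependency on region #1.
      %a = "tf.OpA"() : () -> tensor<f32>
      tf_device.return %a : tensor<f32>
    }, {
      // Region #1: executes concurrently with region #0.
      %b = "tf.OpB"() : () -> tensor<i32>
      tf_device.return %b : tensor<i32>
    }) : () -> (tensor<f32>, tensor<i32>)

Each region is required to be a single block terminated by tf_device.return, and the per-region return values are concatenated into $execute_outputs in region order.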

def TfDevice_ReplicateOp :
TfDevice_Op<"replicate", [SingleBlockImplicitTerminator<"ReturnOp">]> {
let summary = "Wraps an N-way replicated computation.";

@ -208,7 +208,7 @@ static Type InferReductionOpType(Value input, Value reduction_indices,

int64_t num_reduce_dim = 0;
llvm::SmallVector<bool, 4> is_reduce_dim(rank, false);
for (APInt index : indices.getValues<APInt>()) {
for (const APInt &index : indices.getValues<APInt>()) {
int64_t dim = GetDimForAxis(index.getSExtValue(), rank);
// Invalid input.
if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty);
@ -404,11 +404,11 @@ static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
if (perm0.getNumElements() != perm1.getNumElements()) return false;

SmallVector<int64_t, 8> perm0_values;
for (auto value : perm0.getIntValues())
for (const auto &value : perm0.getIntValues())
perm0_values.push_back(value.getSExtValue());

SmallVector<int64_t, 8> perm1_values;
for (auto value : perm1.getIntValues())
for (const auto &value : perm1.getIntValues())
perm1_values.push_back(value.getSExtValue());

for (int i = 0; i < perm0_values.size(); ++i) {
@ -2548,12 +2548,15 @@ static LogicalResult Verify(SizeOp op) {
// SliceOp
//===----------------------------------------------------------------------===//

// Verifies that,
// Verifies that:
//
// - operands begin and size are 1D with the same number of elements.
// - if the input is a ranked tensor, the rank of the input equals the number
// of elements in operands begin and size.
// - if begin are constants, 0 <= begin[i] < input_ty.getShape()[i]
// - if begin are constants, that
// 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
// - if begins aren't constant but the input is a ranked tensor, that
// size[i] <= input_ty.getShape()[i]
//
static LogicalResult Verify(SliceOp op) {
RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
@ -2587,7 +2590,7 @@ static LogicalResult Verify(SliceOp op) {
bool constant_slice_sizes =
matchPattern(op.size(), m_Constant(&slice_sizes));
int dim = 0;
for (APInt raw_begin_index : begin_indices.getValues<APInt>()) {
for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
int64_t begin_index = raw_begin_index.getSExtValue();
int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
int64_t slice_size = constant_slice_sizes
@ -2603,6 +2606,20 @@ static LogicalResult Verify(SliceOp op) {
}
++dim;
}
} else if (input_ty) {
// If the inputs are ranked, we can do a few more sanity checks.
DenseIntElementsAttr slice_sizes;
if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
auto input_shape = input_ty.getShape();
for (int64_t i = 0; i < input_ty.getRank(); ++i) {
int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
int64_t input_size = input_shape[i];
if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
"is unknown at compile time";
}
}
}
}

return success();
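
The new else-if branch above is what the testSlice_unknown_begin_* cases added further down in this diff exercise; mirroring those tests, a constant size that exceeds the corresponding input dimension is now rejected even when begin is not a constant:

    // %arg0 : tensor<4xi32>, %begins : tensor<1xi64> (begin is not a constant)
    %sizes = "tf.Const"() {value = dense<[5]> : tensor<1xi64>} : () -> tensor<1xi64>
    // error: requires size[i] <= Di, even if begin[i] is unknown at compile time
    %0 = "tf.Slice"(%arg0, %begins, %sizes)
        : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32>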

@ -3340,7 +3357,7 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x,
x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
} else {
const_shape.reserve(attr_shape.getNumElements());
for (auto dim : attr_shape)
for (const auto &dim : attr_shape)
const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
}
return TransposeOp::build(

@ -27,6 +27,7 @@ module {

// CHECK-LABEL: func @tpu0_func
// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
// CHECK-SAME: sym_visibility = "private"
// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]])
// CHECK: return %[[TPU0_FUNC_B_OUTPUT]]
}

@ -93,7 +93,7 @@ library {
# CHECK: return
# CHECK: func @test_func_name0
# CHECK-SAME: tf.resource_arg_unique_id = 0
# CHECK-SAME tf.resource_arg_unique_id = 0
# CHECK-SAME: tf.resource_arg_unique_id = 0
# CHECK: tf_executor.graph
# CHECK: tf_executor.fetch
# CHECK: return

@ -1,4 +1,4 @@
// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail
// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s --dump-input=fail

// Test that simple TF operations can be inlined.

@ -48,7 +48,7 @@ func @transpose_resnet_layer(%arg0: tensor<?x224x224x3xf32>, // input
} : (tensor<?x3x230x230xf32>, tensor<7x7x3x64xf32>) -> tensor<?x64x112x112xf32>

// CHECK: %[[CONV0:[0-9]*]] = "tf.Conv2D"
// CHECK-SAME %[[PAD]]
// CHECK-SAME: %[[PAD]]
// CHECK-SAME: data_format = "NHWC"
// CHECK-SAME: strides = [1, 2, 2, 1]

@ -163,12 +163,12 @@ func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<
}

func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.pow %arg0, %arg0 : tensor<2xf32>
%0 = xla_hlo.power %arg0, %arg0 : tensor<2xf32>
return %0 : tensor<2xf32>
}

func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = xla_hlo.pow %arg0, %arg0 : tensor<?xf32>
%0 = xla_hlo.power %arg0, %arg0 : tensor<?xf32>
return %0 : tensor<?xf32>
}

@ -184,7 +184,7 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te
%8 = xla_hlo.constant dense<1> : tensor<3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<3xi32>
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.neg"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@ -203,7 +203,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
%8 = xla_hlo.constant dense<1> : tensor<2x3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32>
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.neg"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = xla_hlo.divide %11, %12 : tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@ -461,32 +461,32 @@ func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
}

func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "xla_hlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "xla_hlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

@ -551,17 +551,17 @@ func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
}

func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "xla_hlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

@ -577,17 +577,17 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
}

func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "xla_hlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

@ -677,6 +677,11 @@ func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
return %0 : tensor<i64>
}

func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
return %0 : tensor<3xcomplex<f32>>
}

// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// CHECK-LABEL: func @biasAdd_NHWC(
@ -1481,3 +1486,10 @@ func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
// CHECK: [[VAL_366:%.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
// CHECK: return [[VAL_366]] : tensor<i64>
// CHECK: }

// CHECK-LABEL: func @complex(
// CHECK-SAME: [[VAL_367:%.*]]: tensor<3xf32>, [[VAL_368:%.*]]: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
// CHECK: [[VAL_369:%.*]] = "tf.Complex"([[VAL_367]], [[VAL_368]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
// CHECK: return [[VAL_369]] : tensor<3xcomplex<f32>>
// CHECK: }

@ -7,6 +7,7 @@ glob_lit_tests(
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"preserve-entry-func-names.mlir": ["nomac"], # TODO(b/148403706): flaky on Mac, to be fixed.
"tf_add.mlir": ["nomac"], # TODO(b/148403706): flaky on Mac, to be fixed.
},
test_file_exts = ["mlir"],
)

@ -1,4 +1,4 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s --dump-input-on-failure

func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {

@ -116,8 +116,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {

// CHECK-LABEL: func @shape_from_while_to_cond_body_functions
func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>, %arg1: tensor<!tf.resource<tensor<4xf32>>>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<4xf32> {
// CHECK "tf.While"
// CHECK-SAME (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>)
// CHECK: "tf.While"
// CHECK-SAME: (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>)
%0:3 = "tf.While"(%arg0, %arg1, %arg2) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>)
return %0#0 : tensor<4xf32>
}

@ -1682,6 +1682,23 @@ func @testSlice_begin_out_of_bound(%arg0: tensor<4xi32>) -> tensor<2xi32> {

// -----

func @testSlice_unknown_begin_out_of_bounds(%arg0: tensor<4xi32>, %begins: tensor<1xi64>) -> tensor<3xi32> {
%sizes = "tf.Const"() {value = dense<[5]> : tensor<1xi64>} : () -> (tensor<1xi64>)
// expected-error @+1 {{requires size[i] <= Di, even if begin[i] is unknown at compile time}}
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}

// -----

func @testSlice_unknown_begin_in_bounds(%arg0: tensor<4xi32>, %begins: tensor<1xi64>) -> tensor<3xi32> {
%sizes = "tf.Const"() {value = dense<[4]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}

// -----

// Valid StridedSlice operation.
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>

@ -188,7 +188,7 @@ func @parallel_execute_single_region() {

// Check that a parallel_execute op with empty regions are not allowed.
func @parallel_execute_empty_region() {
"tf_device.parallel_execute"() ( {
// expected-error@-1 {{'tf_device.parallel_execute' op regions must not be empty. Found an empty region (0).}}
// expected-error@-1 {{'tf_device.parallel_execute' op region #0 ('regions') failed to verify constraint: region with 1 blocks}}
},
{
tf_device.return

@ -51,12 +51,12 @@ module attributes {tf_saved_model.semantics} {

// Test case: Delete recursively dead cycle.

// CHECK-NOT func @recursively_dead0
// CHECK-NOT: func @recursively_dead0
func @recursively_dead0() {
"some_dialect.call"() { callee = @recursively_dead1 } : () -> ()
return
}
// CHECK-NOT func @recursively_dead1
// CHECK-NOT: func @recursively_dead1
func @recursively_dead1() {
"some_dialect.call"() { callee = @recursively_dead0 } : () -> ()
return

@ -86,3 +86,21 @@ module attributes {tf_saved_model.semantics} {
}
}

// -----

module attributes {tf_saved_model.semantics} {

// CHECK-NOT: tf_saved_model.global_tensor
"tf_saved_model.global_tensor"() {sym_name = "v", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
"tf_saved_model.global_tensor"() {sym_name = "v2", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()

func @f(%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @"v"}, %arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @"v2"})
attributes {tf_saved_model.exported_names = ["f"]} {
// CHECK: "tf.Const"()
%0 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>

// CHECK: "tf.Const"()
%1 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
return
}
}