Merge branch 'master' into tanhsigmoid_16x8

Elena Zhelezina 2020-04-10 13:34:53 +01:00 committed by GitHub
commit 0f047cd174
957 changed files with 11397 additions and 6355 deletions

View File

@@ -2,6 +2,10 @@
<img src="https://www.tensorflow.org/images/tf_logo_social.png">
</div>
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
[![PyPI](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow)
**`Documentation`** |
------------------- |
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) |

View File

@@ -223,6 +223,7 @@ tf_cuda_cc_test(
":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@@ -371,6 +372,22 @@ tf_cuda_cc_test(
],
)
cc_library(
name = "custom_device_testutil",
testonly = True,
srcs = ["custom_device_testutil.cc"],
hdrs = ["custom_device_testutil.h"],
visibility = ["//tensorflow:internal"],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",
@@ -381,6 +398,7 @@ tf_cc_test(
":c_api",
":c_api_experimental",
":c_api_test_util",
":custom_device_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
@@ -448,6 +466,7 @@ filegroup(
],
exclude = [
"c_api_experimental.cc",
"*c_api_tfrt*",
"*test*",
"*dlpack*",
],

View File

@@ -21,6 +21,10 @@ limitations under the License.
#include <string>
#include <vector>
// clang-format off
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
@@ -31,12 +35,14 @@ limitations under the License.
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
#ifdef PLATFORM_GOOGLE
#include "tensorflow/c/eager/c_api_tfrt.h"
#endif
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
@@ -676,6 +682,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
status->status = tensorflow::Status::OK();
return new TFE_Context{new tfrt::ContextInterface()};
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
#endif
}
std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
@@ -1669,6 +1684,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
};
} // namespace
extern "C" {
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status) {
@@ -1679,3 +1696,5 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device));
}
} // extern "C"

View File

@@ -515,9 +515,11 @@ typedef struct TFE_CustomDevice {
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
TFE_CustomDevice device,
const char* device_name,
void* device_info,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
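The only change to the declaration above is the TF_CAPI_EXPORT annotation, which makes the experimental symbol visible to shared-library consumers; the registration pattern itself is unchanged. A minimal sketch of wiring up the struct, mirroring the four callbacks the logging test device fills in elsewhere in this commit; the My* callbacks are hypothetical and assumed to be defined by the caller:

// Hypothetical callbacks; signatures follow the TFE_CustomDevice fields used
// by the logging device in this commit.
TFE_TensorHandle* MyCopyToDevice(TFE_Context*, TFE_TensorHandle*, TF_Status*,
                                 void* device_info);
TFE_TensorHandle* MyCopyFromDevice(TFE_Context*, TFE_TensorHandle*,
                                   const char* target_device, TF_Status*,
                                   void* device_info);
void MyExecute(TFE_Context*, int num_inputs, TFE_TensorHandle**,
               const char* op_name, const TFE_OpAttrs*, int* num_outputs,
               TFE_TensorHandle**, TF_Status*, void* device_info);
void MyDeleteDevice(void* device_info);

void RegisterMyDevice(TFE_Context* ctx, void* device_info, TF_Status* status) {
  TFE_CustomDevice device;
  device.copy_tensor_to_device = &MyCopyToDevice;
  device.copy_tensor_from_device = &MyCopyFromDevice;
  device.delete_device = &MyDeleteDevice;
  device.execute = &MyExecute;
  TFE_RegisterCustomDevice(ctx, device,
                           "/job:localhost/replica:0/task:0/device:CUSTOM:0",
                           device_info, status);
}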

View File

@@ -19,6 +19,10 @@ limitations under the License.
#include <string>
// clang-format off
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -584,9 +588,10 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContext(ctx);
}
void ExecuteAdd(bool async, bool forward_input) {
void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, tfrt);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -651,7 +656,6 @@ void ExecuteAdd(bool async, bool forward_input) {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float result[100 * 100] = {0};
@@ -661,12 +665,42 @@ void ExecuteAdd(bool async, bool forward_input) {
for (int i = 0; i < 100 * 100; ++i) {
EXPECT_EQ(2.0f, result[i]);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
TEST(CAPI, ExecuteAdd) {
ExecuteAdd(
/*async=*/false,
/*forward_input*/ false,
/*tfrt*/ false);
}
TEST(CAPI, ExecuteAddAsync) {
ExecuteAdd(
/*async=*/true,
/*forward_input*/ false,
/*tfrt*/ false);
}
TEST(CAPI, ExecuteAddForward) {
ExecuteAdd(
/*async=*/false,
/*forward_input*/ true,
/*tfrt*/ false);
}
TEST(CAPI, ExecuteAddForwardAsync) {
ExecuteAdd(
/*async=*/true,
/*forward_input*/ true,
/*tfrt*/ false);
}
#ifdef PLATFORM_GOOGLE
// TODO(b/153349425): Add add forwarding tests for TFRT
TEST(CAPI, ExecuteAddTfrt) {
ExecuteAdd(
/*async=*/false,
/*forward_input*/ false,
/*tfrt*/ true);
}
#endif
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();

View File

@@ -20,129 +20,11 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
UnpackTensorHandle(var_value, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
@@ -394,5 +274,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}
} // namespace

View File

@@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
} // namespace
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status) {
return reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
->tensor;
}
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info) {
TFE_CustomDevice* custom_device = new TFE_CustomDevice;
custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device->delete_device = &DeleteLoggingDevice;
custom_device->execute = &LoggingDeviceExecute;
*device = custom_device;
LoggingDevice* logging_device = new LoggingDevice;
logging_device->arrived_flag = arrived_flag;
logging_device->executed_flag = executed_flag;
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
*device_info = reinterpret_cast<void*>(logging_device);
}

View File

@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status);
#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
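A minimal sketch of how a test can use these helpers, modeled on the call sites that remain in custom_device_test.cc after this refactor; the context/status setup is the usual eager C API boilerplate and is not part of the new header:

TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* context = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);

bool arrived = false;   // flipped when a tensor is copied onto the device
bool executed = false;  // flipped when an op runs on the device
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived, &executed, status);

// ... place ops on `name`; `executed` becomes true ...
// For a handle produced on the logging device, recover the wrapped handle:
// TFE_TensorHandle* unpacked = UnpackTensorHandle(logged_handle, status);

TFE_DeleteContext(context);
TF_DeleteStatus(status);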

View File

@@ -0,0 +1,38 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "parallel_device",
srcs = ["parallel_device.cc"],
hdrs = ["parallel_device.h"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)
tf_cc_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@@ -0,0 +1,597 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <memory>
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace eager {
namespace {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
public:
ParallelDevice(const std::string& name,
const std::vector<std::string>& devices);
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
// tensor, but other operations will implicitly broadcast non-parallel input
// tensors across the ParallelDevice's component devices.
//
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
// causes `Execute` to return non-parallel tensors.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK.
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Implements the parallel case for `Execute`, where all of the outputs of the
// operation are ParallelTensors, and all inputs are either ParallelTensors or
// should be implicitly broadcast. This means the operation is not
// TPUReplicatedInput or TPUReplicatedOutput.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ExecuteParallelOperation(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
const std::string& device_name() const { return device_name_; }
private:
// The name of the parallel device
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
const std::string device_name_;
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
};
// The internal representation of a TFE_TensorHandle placed on a
// ParallelDevice. Contains a tuple of tensors, one on each of the
// `underlying_devices_` of the ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
std::unique_ptr<ParallelTensor> t,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
ParallelDevice::ParallelDevice(const std::string& name,
const std::vector<std::string>& devices)
: device_name_(name),
underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
if (device_name_ == current_device) {
std::string message(absl::StrCat(
"Tried to copy a TensorHandle to its existing device: ", device_name_));
TF_SetStatus(status, TF_INTERNAL, message.c_str());
return nullptr;
}
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel
// tensor.
if (inputs.size() != underlying_devices_.size()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
std::vector<TensorHandlePtr> components;
components.reserve(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
std::string message(absl::StrCat(
"Expected all inputs to TPUReplicatedInput to be non-parallel "
"TensorHandles. The input ",
i,
" was a parallel tensor (already "
"placed on the parallel device)."));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
components.emplace_back(TFE_TensorHandleCopySharingTensor(
absl::get<TFE_TensorHandle*>(inputs[i]), status));
}
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
} else if (operation_name == std::string("TPUReplicatedOutput")) {
// Special-cased operation for un-packing one parallel tensor into
// per-device tensors.
OpPtr op(TFE_NewOp(context, operation_name, status));
TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result;
if (expected_outputs != underlying_devices_.size()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"Expected the input to "
"TPUReplicatedOutput to be a parallel tensor (placed on the "
"parallel device).");
return result;
}
ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
std::vector<MaybeParallelTensorOwned> outputs;
outputs.reserve(t->num_tensors());
for (int i = 0; i < t->num_tensors(); ++i) {
TensorHandlePtr this_output(
TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
outputs.emplace_back(std::move(this_output));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(outputs));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
ExecuteParallelOperation(context, std::move(inputs), operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
std::move(maybe_parallel_results.value()));
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(parallel_results.size());
for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
result_content.push_back(
MaybeParallelTensorOwned(std::move(parallel_result)));
}
result.emplace(std::move(result_content));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data);
}
TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_Context* context, std::unique_ptr<ParallelTensor> t,
TF_Status* status) {
// The resulting TensorHandle owns an opaque pointer to "device memory", which
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release();
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
context, t_released->device_.device_name().c_str(), t_released->dtype_,
t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
&ParallelTensorDeallocator, nullptr, status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
// registration.
//
// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle containing
// a ParallelTensor with one copy of `tensor` for each device in the
// ParallelDevice.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
std::unique_ptr<ParallelTensor> parallel_tensor(
dev->CopyToParallelDevice(context, tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
status)
.release();
}
// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
// registration.
//
// Currently this is an error, and un-packing ParallelTensors must be performed
// explicitly by running a TPUReplicatedOutput operation on the parallel device.
//
// TODO(allenl): There are some use-cases that are only supported by copying to
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
// need to return something here or address these use-cases one by one.
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a parallel device.");
return nullptr;
}
// For TFE_CustomDevice::execute in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
if (TF_GetCode(status) != TF_OK) return;
if (dev->device_name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
TFE_TensorHandleDevicePointer(inputs[i], status)));
if (TF_GetCode(status) != TF_OK) return;
} else {
typed_inputs.emplace_back(inputs[i]);
}
}
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
*num_outputs, status));
if (TF_GetCode(status) != TF_OK) return;
if (!maybe_typed_outputs.has_value()) {
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
return;
}
std::vector<MaybeParallelTensorOwned> typed_outputs(
std::move(maybe_typed_outputs.value()));
if (typed_outputs.size() > *num_outputs) {
TF_SetStatus(status, TF_INTERNAL,
"The allocated output buffer was too small.");
return;
}
for (int i = 0; i < typed_outputs.size(); ++i) {
MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
} else {
outputs[i] = ParallelTensor::AsTensorHandle(
context,
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
typed_output)),
status)
.release();
if (TF_GetCode(status) != TF_OK) return;
}
}
*num_outputs = typed_outputs.size();
}
// For TFE_CustomDevice::delete_device in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
delete reinterpret_cast<ParallelDevice*>(device_info);
}
} // namespace
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToParallelDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice;
custom_device.delete_device = &DeleteParallelDevice;
custom_device.execute = &ParallelDeviceExecute;
std::vector<std::string> underlying_devices_vector;
underlying_devices_vector.reserve(num_underlying_devices);
for (int device_index = 0; device_index < num_underlying_devices;
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
ParallelDevice* d =
new ParallelDevice(device_name, underlying_devices_vector);
TFE_RegisterCustomDevice(context, custom_device, device_name, d, status);
}
} // namespace eager
} // namespace tensorflow

View File

@@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
#include "tensorflow/c/eager/c_api.h"
namespace tensorflow {
namespace eager {
// Register a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
// on each underlying device.
//
// For example if `device_name` is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0"
// and `underlying_devices` is
// {"/job:localhost/replica:0/task:0/device:GPU:0",
// "/job:localhost/replica:0/task:0/device:GPU:1"}
// Then executing an operation on CUSTOM:0 will execute it on GPU:0 and GPU:1.
//
// Implicit copies onto `device_name` are allowed, replicating the value once
// per device in `underlying_devices`. Implicit copies off of the device throw
// an error.
//
// All component tensors must have the same dtype. Currently they must also have
// the same shape, although this requirement may be relaxed in the future.
//
// `device_name` must not name an existing physical or custom device (see
// the documentation for TFE_RegisterCustomDevice for more information).
//
// Tensors may be copied on or off the device explicitly using
// TPUReplicatedInput and TPUReplicatedOutput respectively. For example, with
// two component devices, running `x = TPUReplicatedInput(inputs=[a, b])` on the
// parallel device creates a parallel tensor `x` with `a` on the first of
// `underlying_devices` and `b` on the second. Running `a_unpacked, b_unpacked =
// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
// into its components.
//
// `context` owns the parallel device. `underlying_devices` must stay valid
// while the parallel device is in use.
void RegisterParallelDevice(TFE_Context* context, const char* device_name,
const char** underlying_devices,
int num_underlying_devices, TF_Status* status);
} // namespace eager
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
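A minimal sketch of registering a parallel device over two component devices, following the header comment above; the device names are illustrative, and it is assumed a second CPU device is available (as the test added in this commit arranges through context options) or that two GPUs are substituted:

TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* context = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);

const char* underlying[] = {"/job:localhost/replica:0/task:0/device:CPU:0",
                            "/job:localhost/replica:0/task:0/device:CPU:1"};
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
tensorflow::eager::RegisterParallelDevice(context, name, underlying,
                                          /*num_underlying_devices=*/2, status);
if (TF_GetCode(status) != TF_OK) { /* handle error */ }

// Values are packed onto `name` by running TPUReplicatedInput with one input
// per component device, and unpacked with TPUReplicatedOutput, as described
// in the comment above.

TFE_DeleteContext(context);
TF_DeleteStatus(status);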

View File

@@ -0,0 +1,917 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
public:
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
// indication of the dtype of the variable's value.
//
// Note that creating this resource-dtype handle can fail, so `Create` is a
// separate static method which returns a status.
Variable(TFE_TensorHandle* handle, TF_DataType type)
: handle_(handle), type_(type) {}
// Helper for constructing a resource handle and wrapping it in a `Variable`
// object.
static Variable* Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status);
// Dereferences the backing buffer for the variable. Note that since this can
// fail (it runs operations), it must be called explicitly and the resulting
// `status` checked.
void Destroy(TFE_Context* context, TF_Status* status);
// Reads from the variable.
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
// Assigns a new value to the variable.
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
// Adds `value` to the existing value of the variable.
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status);
private:
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
// AssignSub, ...).
void GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status);
// A handle for the resource-dtype tensor pointing to the variable's
// buffer.
TFE_TensorHandle* handle_;
// The dtype of the variable's buffer (input dtype for assignments, output
// dtype of read operations).
TF_DataType type_;
};
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type);
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
TFE_OpSetAttrString(op.get(), "container", "", 0);
// Use the special GUID for no buffer sharing
//
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
// only reasonable way to make variables with no aliasing using the eager C
// API.
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
no_sharing.length());
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return new Variable(var_handle, type);
}
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
// Free the backing buffer for the variable.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
// Delete the variable handle itself.
TFE_DeleteTensorHandle(handle_);
}
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type_);
int num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(var_value);
}
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrType(op.get(), "dtype", type_);
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), value, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
}
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignAddVariableOp", context, value, status);
}
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignVariableOp", context, value, status);
}
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status) {
const int num_bytes = v.size() * sizeof(float);
float* values = new float[v.size()];
memcpy(values, v.data(), num_bytes);
int64_t dims = v.size();
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
&FloatDeallocator, nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_TensorHandle* result_handles[num_replicas];
int num_retvals = num_replicas;
TFE_Execute(op.get(), result_handles, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
for (int i = 0; i < num_replicas; ++i) {
(*components)[i].reset(result_handles[i]);
}
}
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
for (int i = 0; i < num_replicas; ++i) {
TFE_OpAddInput(op.get(), components[i], status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
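// A condensed sketch (illustrative, mirroring the tests below) of how these
// two helpers pair up, assuming `context`, `status`, and a parallel device
// registered as "/job:localhost/replica:0/task:0/device:CUSTOM:0":
//   TensorHandlePtr first(FloatTensorHandle(1., status));
//   TensorHandlePtr second(FloatTensorHandle(2., status));
//   std::array<TFE_TensorHandle*, 2> in{first.get(), second.get()};
//   TensorHandlePtr packed = CreatePerDeviceValues(
//       context, in, "/job:localhost/replica:0/task:0/device:CUSTOM:0", status);
//   std::array<TensorHandlePtr, 2> out;
//   ExtractPerDeviceValues(context, packed.get(), &out, status);
//   // out[0] and out[1] now hold the per-device components (1. and 2.).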
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), second, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* first_device = TFE_TensorHandleDeviceName(first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), first_device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
// Assert that `handle` is equal to `expected_value`.
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(expected_value,
*static_cast<float*>(TF_TensorData(value_zero.get())));
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device) {
// Register the custom device
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
tensorflow::eager::RegisterParallelDevice(
context, device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
// device.
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
to_delete->Destroy(context, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
delete to_delete;
};
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
status.get()),
variable_deleter);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Assign an initial value to the variable, implicitly mirroring it to each
// component device.
{
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
variable->Assign(context, initial_value.get(), status.get());
}
// Read from the variable and verify that we have a parallel tensor.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 20.);
AssertScalarFloatEq(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Add a parallel tensor with different values on each device to the variable.
{
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value =
CreatePerDeviceValues(context, components, device_name, status.get());
variable->AssignAdd(context, combined_value.get(), status.get());
}
// Read the variable and verify that each component has the right modified
// value.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 23.);
AssertScalarFloatEq(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
BasicTestsForTwoDevices(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1");
}
TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
BasicTestsForTwoDevices(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:0");
}
TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Skip the test if no TPU is available.
std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool has_tpu = false;
for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
++device_index) {
std::string device_type =
TF_DeviceListType(devices.get(), device_index, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
if (device_type == "TPU") {
has_tpu = true;
break;
}
}
if (has_tpu) {
BasicTestsForTwoDevices(context.get(),
"/job:localhost/replica:0/task:0/device:TPU:0",
"/job:localhost/replica:0/task:0/device:TPU:0");
}
}
TEST(PARALLEL_DEVICE, TestExplicitCopies) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CPU:0";
underlying_devices.push_back(first_device_name);
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CPU:1";
underlying_devices.push_back(second_device_name);
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Copying on to a parallel device is OK.
TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice(
cpu_value.get(), context.get(), device_name, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* backing_device =
TFE_TensorHandleBackingDeviceName(device_value.get(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(std::string(device_name), backing_device);
// Un-pack the parallel tensor to verify that the copy was successful.
{
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context.get(), device_value.get(), &components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// The value of the original tensor is replicated on each device.
AssertScalarFloatEq(components[0].get(), 3.);
AssertScalarFloatEq(components[1].get(), 3.);
// Verify that the mirrors are placed on the component devices.
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
  // Copying a parallel tensor back to a single device without un-packing it
  // is expected to fail; components must be extracted explicitly instead.
TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
device_value.get(), context.get(), first_device_name, status.get()));
ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
}
TEST(PARALLEL_DEVICE, TestDifferentShapes) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create two vectors with different lengths
std::vector<float> size_two_value{1., 2.};
std::vector<float> size_three_value{1., 2., 3.};
TensorHandlePtr size_two(
VectorFloatTensorHandle(size_two_value, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr size_three(
VectorFloatTensorHandle(size_three_value, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Try to combine these values into a single parallel tensor.
std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED)
<< TF_Message(status.get());
}
TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
3),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a parallel device with two CPUs
const char* first_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> first_underlying_devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
tensorflow::eager::RegisterParallelDevice(
context.get(), first_device_name, first_underlying_devices.data(),
first_underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a second parallel device with the first parallel device and one
// additional CPU.
const char* second_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:1";
std::vector<const char*> second_underlying_devices{
"/job:localhost/replica:0/task:0/device:CUSTOM:0",
"/job:localhost/replica:0/task:0/device:CPU:2"};
tensorflow::eager::RegisterParallelDevice(
context.get(), second_device_name, second_underlying_devices.data(),
second_underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the first parallel device
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
TensorHandlePtr first_combined_value = CreatePerDeviceValues(
context.get(), components, first_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Nest the first parallel tensor into a second
TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
components[0] = first_combined_value.get();
components[1] = value_three.get();
TensorHandlePtr second_combined_value = CreatePerDeviceValues(
context.get(), components, second_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr multiplier(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr multiply_result(Multiply(context.get(),
                                           second_combined_value.get(),
                                           multiplier.get(), status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Un-pack the parallel tensor to verify that the operation was
// successful. The resulting structure should be:
// second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
std::array<TensorHandlePtr, 2> second_components;
ExtractPerDeviceValues(context.get(), multiply_result.get(),
&second_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(second_components[1].get(), 9.);
// Verify that the mirrors are placed on the component devices.
std::string first_device = TFE_TensorHandleBackingDeviceName(
second_components[0].get(), status.get());
ASSERT_EQ(second_underlying_devices[0], first_device);
std::string second_device = TFE_TensorHandleBackingDeviceName(
second_components[1].get(), status.get());
ASSERT_EQ(second_underlying_devices[1], second_device);
// Un-pack the first parallel device's tensor too
std::array<TensorHandlePtr, 2> first_components;
ExtractPerDeviceValues(context.get(), second_components[0].get(),
&first_components, status.get());
AssertScalarFloatEq(first_components[0].get(), 3.);
AssertScalarFloatEq(first_components[1].get(), 6.);
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
status.get());
ASSERT_EQ(first_underlying_devices[0], first_device);
second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
status.get());
ASSERT_EQ(first_underlying_devices[1], second_device);
}
TEST(PARALLEL_DEVICE, TestInvalidPacking) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
{
// Try to pack two TensorHandles onto a parallel device with a single
// component.
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
{
// Try to extract the wrong number of components from a parallel tensor
std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), correct_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> incorrect_components;
ExtractPerDeviceValues(context.get(), combined_value.get(),
&incorrect_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
{
// Try to pass a ParallelTensor to TPUReplicatedInput
std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), correct_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 1> incorrect_components{combined_value.get()};
TensorHandlePtr recombined_value = CreatePerDeviceValues(
context.get(), incorrect_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
{
// Try to pass a non-parallel tensor to TPUReplicatedOutput
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
TFE_DeleteOp);
if (TF_GetCode(status.get()) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
TFE_OpAddInput(op.get(), value_one.get(), status.get());
if (TF_GetCode(status.get()) != TF_OK) return;
TFE_OpSetDevice(op.get(), device_name, status.get());
if (TF_GetCode(status.get()) != TF_OK) return;
TFE_TensorHandle* result_handles;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handles, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
}
}
TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
int group_size, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
TFE_OpSetAttrInt(op.get(), "group_size", group_size);
TFE_OpSetAttrInt(op.get(), "group_key", 0);
TFE_OpSetAttrInt(op.get(), "instance_key", 0);
const std::string merge_op("Add");
TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
merge_op.length());
const std::string final_op("Id");
TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
final_op.length());
TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
TEST(PARALLEL_DEVICE, TestCollective) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a tensor on the parallel device
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
TensorHandlePtr parallel_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Run a collective sum, so each component should now be the same.
TensorHandlePtr reduced(
CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> result_components;
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 3.);
AssertScalarFloatEq(result_components[1].get(), 3.);
}
void RegisterCollectiveMulFunction(TFE_Context* context,
const char* function_name, int group_size,
TF_Status* status) {
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
TF_DeleteGraph);
TF_OperationDescription* placeholder_desc =
TF_NewOperation(body.get(), "Placeholder", "Placeholder");
TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
if (TF_GetCode(status) != TF_OK) return;
TF_Output x{placeholder_op, 0};
TF_OperationDescription* reduce_desc =
TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
TF_SetAttrInt(reduce_desc, "group_size", group_size);
TF_SetAttrInt(reduce_desc, "group_key", 0);
TF_SetAttrInt(reduce_desc, "instance_key", 0);
const std::string merge_op("Mul");
TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
merge_op.length());
const std::string final_op("Id");
TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
final_op.length());
TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
TF_AddInput(reduce_desc, x);
TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
if (TF_GetCode(status) != TF_OK) return;
TF_Operation* operations[]{placeholder_op, reduce_op};
TF_Output y{reduce_op, 0};
const char* output_name = "y";
std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
TF_GraphToFunction(
/* fn_body */ body.get(), /* fn_name */ function_name,
/* append_hash_to_fn_name */ 0, /* num_opers */ 2,
/* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
/* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
/* opts */ nullptr, /* description */ "", /* status */ status),
TF_DeleteFunction);
if (TF_GetCode(status) != TF_OK) return;
TFE_ContextAddFunction(context, function.get(), status);
}
TEST(PARALLEL_DEVICE, TestFunction) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::vector<const char*> underlying_devices;
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0");
underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1");
tensorflow::eager::RegisterParallelDevice(
context.get(), device_name, underlying_devices.data(),
underlying_devices.size(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* function_name = "test_reduce_mul";
RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
TensorHandlePtr parallel_value = CreatePerDeviceValues(
context.get(), components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* raw_result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr reduced(raw_result_handle);
std::array<TensorHandlePtr, 2> result_components;
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
std::string first_device = TFE_TensorHandleBackingDeviceName(
result_components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device = TFE_TensorHandleBackingDeviceName(
result_components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}

View File

@ -16,6 +16,12 @@ cc_library(
deps = ["//tensorflow/core:test_main"],
)
filegroup(
name = "quantize_header",
srcs = ["quantize.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "tfcompile_lib",
srcs = [
@ -27,6 +33,7 @@ cc_library(
"codegen.h",
"compile.h",
"flags.h",
"quantize.h",
],
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
visibility = ["//tensorflow/python:__pkg__"],
@ -37,7 +44,6 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "tensorflow/compiler/aot/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -46,6 +46,14 @@ limitations under the License.
namespace tensorflow {
namespace tfcompile {
static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;
bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
if (*quantize_xla) return false;
*quantize_xla = fn;
return true;
}
namespace {
// Compiles the XLA computation into executable code.
@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
} else {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
if (flags.experimental_quantize) {
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
if (flags.experimental_quantize && *quantize_xla) {
TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
}
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
computation.Snapshot());

View File

@ -13,21 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
#include <functional>
#include <iostream>
#include <ostream>
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
namespace mlir {
namespace xla_hlo {
namespace tensorflow {
namespace tfcompile {
// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation);
using QuantizeXlaFn = std::function<Status(const tf2xla::Config& config,
xla::XlaComputation* computation)>;
} // namespace xla_hlo
} // namespace mlir
// Sets the static quantization function to `fn` if it has not been set yet.
// Returns false if a quantization function has already been registered.
bool RegisterQuantizeFn(const QuantizeXlaFn& fn);
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
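For reference, a minimal sketch of how a backend could hook its quantizer into tfcompile through this header. The no-op quantizer body and the static-initialization registration are illustrative assumptions, not part of this change:

#include "tensorflow/compiler/aot/quantize.h"

namespace tensorflow {
namespace tfcompile {
namespace {

// Hypothetical quantizer; a real implementation would rewrite `computation`.
Status NoOpQuantize(const tf2xla::Config& config,
                    xla::XlaComputation* computation) {
  return Status::OK();
}

// RegisterQuantizeFn returns false if another function was registered first.
const bool kQuantizerRegistered = RegisterQuantizeFn(NoOpQuantize);

}  // namespace
}  // namespace tfcompile
}  // namespace tensorflow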

View File

@ -296,10 +296,10 @@ Status XlaCompilationCache::CompileSingleOp(
arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
}
GraphDebugInfo debug_info;
return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
compile_options.use_tuple_arg,
*options.flib_def, debug_info,
options.shape_representation_fn, result);
return CompileGraphToXlaHlo(
*graph, {arg_shapes.data(), arg_shapes.size()},
options.device_type.type_string(), compile_options.use_tuple_arg,
*options.flib_def, debug_info, options.shape_representation_fn, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,

View File

@ -58,6 +58,11 @@ cc_library(
"//tensorflow/python:__subpackages__",
],
deps = [
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:QuantOps",
# Link jit lib to link JIT devices required to run
# xla-legalize-tf-with-tf2xla pass.
"//tensorflow/compiler/jit",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
@ -65,7 +70,6 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
"//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
@ -90,8 +94,6 @@ cc_library(
"//tensorflow/compiler/mlir/xla:xla_lower",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
"//tensorflow/compiler/mlir/xla:xla_test_passes",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:QuantOps",
],
)

View File

@ -307,7 +307,7 @@ cc_library(
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/runtime_type_verify.cc",
"transforms/runtime_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/while_loop_outline.cc",
@ -537,6 +537,15 @@ tf_native_cc_binary(
],
)
tf_native_cc_binary(
name = "json_to_flatbuffer",
srcs = ["json_to_flatbuffer.cc"],
deps = [
"//tensorflow/lite/schema:schema_fbs",
"@flatbuffers",
],
)
cc_library(
name = "emit_error_reporter",
srcs = [

View File

@ -36,7 +36,8 @@ struct PassConfig {
form_clusters(false),
unfold_batch_matmul(true),
legalize_tf_while(true),
shape_inference(true) {}
shape_inference(true),
runtime_verification(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@ -65,6 +66,8 @@ struct PassConfig {
bool legalize_tf_while;
// Whether to do shape inference.
bool shape_inference;
// Whether to do TFLite runtime verification.
bool runtime_verification;
};
} // namespace TFL

View File

@ -441,7 +441,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
"failure_on_operand_type_mismatch) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
@ -466,6 +466,25 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
"operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result");
for (auto &trait : op.getTraits()) {
if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
continue;
}
if (trait.getDef().getValueAsString("trait") !=
"OpTrait::TFLRuntimeOpTrait") {
continue;
}
auto *val = trait.getDef().getValue("tflRuntimePredicate");
if (!val) continue;
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
os << tgfmt(
" if (!($0)) {\n "
" return ::mlir::LogicalResult::Failure;\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
}
os << " return top.verify();\n}\n";
}
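Roughly, for an op carrying a TFL_RuntimePredOpTrait (for example the TFL_AddOp definition later in this change, which uses TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>), the emitted verifier ends up shaped like the following illustrative reconstruction, not the literal generated code:

::mlir::LogicalResult AddOp::VerifyTflRuntimeConstraints(
    ::mlir::Operation *op, bool failure_on_operand_type_mismatch) {
  auto top = cast<AddOp>(op); (void)top;
  // ... per-operand and per-result type checks ...
  if (!(TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape(
          top.getOperand(0).getType(), top.getOperand(1).getType(), 5))) {
    return ::mlir::LogicalResult::Failure;
  }
  return top.verify();
}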

View File

@ -16,6 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
// tfl.abs
template <>
class TFLiteCostEstimator<AbsOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
@ -149,6 +162,19 @@ class TFLiteCostEstimator<HardSwishOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.log
template <>
class TFLiteCostEstimator<LogOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.logistic
template <>
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
@ -240,6 +266,32 @@ class TFLiteCostEstimator<PadOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.pow
template <>
class TFLiteCostEstimator<PowOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.prelu
template <>
class TFLiteCostEstimator<PReluOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.relu
template <>
class TFLiteCostEstimator<ReluOp, hardware::GPU> {

View File

@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
"LogicalResult", "VerifyTflRuntimeTypes",
"LogicalResult", "VerifyTflRuntimeConstraints",
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
>,
];

View File

@ -46,6 +46,30 @@ namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {
// Returns true if the two given types have the same shape or broadcastable
// shapes within the given rank. If either shape is non-static, this method
// returns true.
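// For example (illustrative shapes, not part of the original change):
// tensor<1x4xf32> vs tensor<3x4xf32> broadcast to 3x4 and both ranks are
// within max_bcast_rank = 5, so this returns true, while tensor<2x3xf32> vs
// tensor<4x3xf32> cannot be broadcast together and returns false.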
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
int max_bcast_rank) {
// Ignore shape checking on the non-static shapes for model compatibility.
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
return true;
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
rhs_shaped_type.getShape(),
result_shape)) {
return false;
}
return lhs_shaped_type.getRank() <= max_bcast_rank &&
rhs_shaped_type.getRank() <= max_bcast_rank;
}
//===----------------------------------------------------------------------===//
// TensorFlowLiteDialect
//===----------------------------------------------------------------------===//
@ -316,7 +340,7 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
const int num_elements = result_shape_type.getNumElements();
new_values.reserve(num_elements);
for (APFloat old_value : dense_elements.getValues<APFloat>()) {
for (const APFloat &old_value : dense_elements.getValues<APFloat>()) {
new_values.push_back(calculate(old_value));
}
@ -844,7 +868,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
if (!shape_elements) return nullptr;
SmallVector<int64_t, 4> shape_data;
for (auto it : shape_elements.getValues<APInt>()) {
for (const auto &it : shape_elements.getValues<APInt>()) {
shape_data.push_back(it.getSExtValue());
}
result_type =
@ -1798,7 +1822,7 @@ static LogicalResult Verify(TransposeOp op) {
int index = 0;
llvm::SmallVector<int64_t, 4> axes;
for (auto axis_int : perm.getValues<APInt>()) {
for (const auto &axis_int : perm.getValues<APInt>()) {
const int64_t axis = axis_int.getSExtValue();
if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) {
return op.emitOpError(

View File

@ -106,6 +106,22 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
class DerivedTFLiteTypeAttr<code body> :
DerivedAttr<"tflite::TensorType", body>;
// TFL Runtime op trait predicate.
class TFL_RuntimePredOpTrait<string desc, Pred pred> :
GenInternalOpTrait<"TFLRuntimeOpTrait"> {
Pred tflRuntimePredicate = pred;
string tflRuntimeDescription = desc;
}
class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
int i, int j, int max_bcast_rank> :
TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
" have the same shape or broadcastable shapes within the rank " #
max_bcast_rank,
CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
"$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
").getType(), " # max_bcast_rank # ")">>;
// These additional types/type constraints here are used to decouple the ops
// from runtime support for the ops. Prefer to use these types when defining
// new TF_Ops for uniformity.
@ -344,7 +360,10 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
// TFL op definitions.
//===----------------------------------------------------------------------===//
def TFL_AbsOp : TFL_Op<"abs", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Absolute value operator";
let description = [{
@ -360,10 +379,9 @@ an output element, this operation computes \\(y = |x|\\).
let hasFolder = 1;
}
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
NoSideEffect,
Commutative,
TFL_GpuTargetOp]> {
def TFL_AddOp : TFL_Op<"add", [
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
let summary = "Addition operator";
let description = [{
@ -371,11 +389,11 @@ def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
}];
let arguments = (
ins AnyTensor:$lhs,
AnyTensor:$rhs,
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
TFL_AFAttr:$fused_activation_function);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
let hasFolder = 1;
@ -1527,7 +1545,10 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
}
def TFL_LogOp: TFL_Op<"log", [
NoSideEffect, SameOperandsAndResultType, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Natural logarithm operator";
let description = [{
@ -2072,7 +2093,10 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
let hasOptions = 1;
}
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Power operator";
let description = [{
@ -2092,7 +2116,7 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, NoSideEffect, NoQuanti
let builders = [TFL_BroadcastableBinaryBuilder];
}
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect]> {
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Parameterized Relu operator";
let description = [{

View File

@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <string>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "flatbuffers/idl.h" // from @flatbuffers
#include "flatbuffers/util.h" // from @flatbuffers
#include "tensorflow/lite/schema/schema_generated.h"
int main(int argc, char** argv) {
  // Load the FlatBuffer schema (.fbs) and JSON from disk.
  if (argc < 3) {
    std::cerr << "Missing input arguments. Usage:\n"
              << argv[0] << " <schema file> <json file>\n\n";
    return 1;
  }
const char* schema_path = argv[1];
const char* json_path = argv[2];
std::string schema;
std::string json;
const bool status =
flatbuffers::LoadFile(schema_path, /*binary=*/false, &schema) &&
flatbuffers::LoadFile(json_path, /*binary=*/false, &json);
if (!status) {
std::cerr << "couldn't load files!\n";
return 1;
}
  // Parse the schema first so it can be used to parse the JSON data.
flatbuffers::Parser parser;
const bool schema_parse_result =
parser.Parse(schema.c_str()) && parser.Parse(json.c_str());
if (!schema_parse_result) {
std::cerr << "Parse error.\n";
return 1;
}
const size_t length = parser.builder_.GetSize();
const size_t n =
std::fwrite(parser.builder_.GetBufferPointer(), 1, length, stdout);
if (n != length) {
std::cerr << "print to stdout filed.\n";
return 1;
}
return 0;
}
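// Usage sketch (paths are illustrative, not from the original change):
//   json_to_flatbuffer tensorflow/lite/schema/schema.fbs model.json > model.tflite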

View File

@ -88,7 +88,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
pass_config.shape_inference = false;
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
pass_config, result);

View File

@ -16,9 +16,12 @@ limitations under the License.
#include <utility>
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
@ -41,6 +44,77 @@ limitations under the License.
namespace tensorflow {
Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags,
mlir::OwningModuleRef* module) {
mlir::FuncOp entry_function = nullptr;
for (auto func : module->get().getOps<mlir::FuncOp>()) {
if (auto tf_attrs =
func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
// TODO(jaesung): There could be multiple entry functions. Let's handle
// such cases if there are any needs for that.
if (entry_function != nullptr) {
return errors::InvalidArgument(
"There should be only one tf.entry_function");
}
entry_function = func;
}
}
if (entry_function == nullptr) {
return errors::InvalidArgument("no tf.entry_function found");
}
// Get the list of input Op names from the function attribute.
mlir::DictionaryAttr tf_attrs =
entry_function.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
llvm::SmallVector<llvm::StringRef, 4> function_input_names;
function_input_names.reserve(model_flags.input_arrays().size());
auto input_attr = tf_attrs.get("inputs");
if (!input_attr) {
return errors::InvalidArgument("no inputs attribute found");
}
auto input_names = input_attr.cast<mlir::StringAttr>().getValue();
input_names.split(function_input_names, ",");
if (function_input_names.size() != model_flags.input_arrays().size()) {
return errors::InvalidArgument(
"input array size mismatch: got ", function_input_names.size(),
", expected: ", model_flags.input_arrays().size());
}
llvm::StringSet<> function_input_names_set;
function_input_names_set.insert(function_input_names.begin(),
function_input_names.end());
for (const auto& input_array : model_flags.input_arrays()) {
if (function_input_names_set.count(input_array.name()) == 0) {
return errors::InvalidArgument("input array name (", input_array.name(),
") does not exist in the given graph");
}
}
// Get the list of output Op names from the function attribute.
llvm::SmallVector<llvm::StringRef, 4> function_output_names;
function_output_names.reserve(model_flags.output_arrays().size());
auto output_attr = tf_attrs.get("outputs");
if (!output_attr) {
return errors::InvalidArgument("no outputs attribute found");
}
auto output_names = output_attr.cast<mlir::StringAttr>().getValue();
output_names.split(function_output_names, ",");
if (function_output_names.size() != model_flags.output_arrays().size()) {
return errors::InvalidArgument(
"output array size mismatch: got ", function_output_names.size(),
", expected: ", model_flags.output_arrays().size());
}
llvm::StringSet<> function_output_names_set;
function_output_names_set.insert(function_output_names.begin(),
function_output_names.end());
for (const auto& output_array : model_flags.output_arrays()) {
if (function_output_names_set.count(output_array) == 0) {
return errors::InvalidArgument("output array name (", output_array,
") does not exist in the given graph");
}
}
return Status::OK();
}
Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
string* result) {
@ -77,11 +151,15 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
model_flags.saved_model_version(), tags,
exported_names, &context));
if (!model_flags.input_arrays().empty() ||
!model_flags.output_arrays().empty()) {
TF_RETURN_IF_ERROR(HandleInputOutputArraysWithModule(model_flags, &module));
}
mlir::TFL::PassConfig pass_config(quant_specs);
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
pass_config.shape_inference = true;
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, result);

View File

@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,

View File

@ -14,7 +14,10 @@ package(
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = ["//tensorflow/compiler/mlir/..."],
packages = [
"//learning/brain/experimental/mlir/quantization/...",
"//tensorflow/compiler/mlir/...",
],
)
exports_files([

View File

@ -55,7 +55,8 @@ namespace quant {
using QuantParamsEntry = QuantizationInfo::QuantParams;
namespace {
class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
class ImportQuantStatsPass
: public PassWrapper<ImportQuantStatsPass, FunctionPass> {
public:
explicit ImportQuantStatsPass(OperationToName op_to_name)
: op_to_name_(op_to_name) {}
@ -193,7 +194,7 @@ void ImportQuantStatsPass::runOnFunction() {
}
// Creates an instance of the default quant parameters pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
OperationToName op_to_name, const std::string &stats_str) {
auto pass = absl::make_unique<ImportQuantStatsPass>(op_to_name);
if (pass->ParseQuantStats(stats_str)) return nullptr;
@ -203,7 +204,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) {
auto get_name_func = [](Operation *op) {
Location loc = op->getLoc();

View File

@ -27,13 +27,13 @@ using OperationToName = std::function<llvm::StringRef(Operation* op)>;
// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
std::unique_ptr<OpPassBase<FuncOp>> CreateImportQuantStatsPass(
std::unique_ptr<OperationPass<FuncOp>> CreateImportQuantStatsPass(
OperationToName op_to_name, const std::string& stats_str);
// Creates an instance pass to import quantization stats to the operations in
// the function. A custom method to get the name from the op is used because
// different dialect ops might have different ways to assign the name.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str);
} // namespace quant

View File

@ -79,7 +79,7 @@ TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
SmallVector<double, 4> new_scales;
new_scales.reserve(scales.size());
auto scales_iter = scales.begin();
for (auto f : factor_values) {
for (const auto& f : factor_values) {
new_scales.push_back(*(scales_iter++) *
std::fabs(FloatAttr::getValueAsDouble(f)));
}

View File

@ -25,7 +25,7 @@ namespace mlir {
namespace TF {
// Legalize the tf ops to the quant ops, so the quantization passes can work.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass();
} // namespace TF
} // namespace mlir

View File

@ -27,7 +27,7 @@ namespace TF {
namespace {
// Legalize TF quantization emulation ops to that in Quant ops dialect.
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
struct LegalizeTFToQuant : public PassWrapper<LegalizeTFToQuant, FunctionPass> {
explicit LegalizeTFToQuant() = default;
LegalizeTFToQuant(const LegalizeTFToQuant &) {}
@ -151,7 +151,7 @@ void LegalizeTFToQuant::runOnFunction() {
} // namespace
// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFToQuantPass() {
return std::make_unique<LegalizeTFToQuant>();
}

View File

@ -1,112 +0,0 @@
load(
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
package(
default_visibility = [
":friends",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//tensorflow/compiler/aot/...",
"//tensorflow/compiler/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
],
)
cc_library(
name = "hlo_xla_quantization_passes",
srcs = [
"cpu_kernel_fusion.cc",
"generated_cpu_kernel_fusion.inc",
"materialize.cc",
"op_quant_spec.inc",
"propagate.cc",
],
hdrs = [
"passes.h",
],
deps = [
":cpu_device_target",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_context",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/xla/client/lib:quantize",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
)
cc_library(
name = "cpu_device_target",
srcs = [
"cpu_device_target.cc",
],
hdrs = [
"cpu_device_target.h",
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:device_target",
"//tensorflow/compiler/mlir/lite/quantization:quantization_context",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "quantize",
srcs = [
"quantize.cc",
],
hdrs = [
"quantize.h",
],
deps = [
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core/platform:status",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
)
gentbl(
name = "cpu_kernel_fusion_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"generated_cpu_kernel_fusion.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "cpu_kernel_fusion.td",
td_srcs = [
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/xla:hlo_ops_td_files",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
)

View File

@ -1,67 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
namespace mlir {
namespace xla_hlo {
namespace ph = std::placeholders;
CpuDeviceTarget::CpuDeviceTarget(MLIRContext* ctx) : DeviceTarget(ctx) {
RegisterKernel("generic.concat", {qi8_, qi8_, qi8_},
quant::ScaleConstraintType::OutputInputSameScale);
// TODO(fengliuai): All the combinations need to be listed explicitly. We need
// to improve this.
RegisterKernel("generic.reshape", {qi8_, any_},
quant::ScaleConstraintType::OutputInputSameScale);
RegisterKernel("generic.reshape", {any_, qi8_},
quant::ScaleConstraintType::OutputInputSameScale);
RegisterKernel("generic.mul", {qi8_, qi8_, qi8_},
quant::ScaleConstraintType::OutputInputFreeScale);
RegisterKernel("generic.mul_add", {qi8_, qi8n_, any_, qi8_},
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
this, ph::_1, ph::_2, ph::_3, ph::_4));
RegisterKernel("generic.matmul_add", {qi8_, qi8n_, any_, qi8_},
std::bind(&CpuDeviceTarget::HandleMultiplyAccumulateScale,
this, ph::_1, ph::_2, ph::_3, ph::_4));
}
LogicalResult CpuDeviceTarget::HandleMultiplyAccumulateScale(
quant::QuantizeContext* ctx, Operation* op,
quant::AdjacentOperations* new_items, bool* changed) {
auto bias_params = ctx->GetOperandParams(op, 2);
if (!EmptyParams(bias_params)) {
return success();
}
std::vector<quant::QuantParams> op_types{ctx->GetOperandParams(op, 0),
ctx->GetOperandParams(op, 1)};
auto bias_scale = GetUniformQuantizedTypeForBias(op_types);
if (bias_scale && ctx->SetOperandParams(op, 2, bias_scale)) {
*changed = true;
new_items->push_back(op->getOperand(2).getDefiningOp());
}
return success();
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,40 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h"
namespace mlir {
namespace xla_hlo {
// Target specs for cpu kernels
class CpuDeviceTarget : public quant::DeviceTarget {
public:
explicit CpuDeviceTarget(MLIRContext* ctx);
private:
LogicalResult HandleMultiplyAccumulateScale(
quant::QuantizeContext* ctx, Operation* op,
quant::AdjacentOperations* new_items, bool* changed);
};
} // namespace xla_hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_CPU_DEVICE_TARGET_H_

View File

@ -1,346 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>
#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#define DEBUG_TYPE "quant-kernel-fusion"
constexpr int kFakeQuantOperandsNum = 5;
constexpr int kFakeQuantPerChannelOperandsNum = 6;
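// Operand layout of the fake_quant custom call matched below: input, min,
// max, bit_width, narrow_range, and, for the per-channel form only, the
// quantization dimension.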
namespace mlir {
namespace xla_hlo {
namespace {
TypeAttr GetQuantSpec(Operation* op) {
auto fake_quant = llvm::dyn_cast_or_null<CustomCallOp>(op);
if (!fake_quant || fake_quant.getNumOperands() < kFakeQuantOperandsNum ||
fake_quant.getNumOperands() > kFakeQuantPerChannelOperandsNum ||
fake_quant.call_target_name() != "fake_quant_with_min_max_vars")
return {};
DenseFPElementsAttr min, max;
DenseIntElementsAttr bit_width, narrow_range, quant_dim;
if (!matchPattern(fake_quant.getOperand(1), m_Constant(&min)) ||
!matchPattern(fake_quant.getOperand(2), m_Constant(&max)) ||
!matchPattern(fake_quant.getOperand(3), m_Constant(&bit_width)) ||
!matchPattern(fake_quant.getOperand(4), m_Constant(&narrow_range)))
return {};
auto bit_width_val = (*bit_width.attr_value_begin()).cast<IntegerAttr>();
auto narrow_range_val = (*narrow_range.int_value_begin()).getSExtValue();
int quant_dim_val = -1;
if (fake_quant.getNumOperands() == kFakeQuantPerChannelOperandsNum &&
matchPattern(fake_quant.getOperand(kFakeQuantPerChannelOperandsNum - 1),
m_Constant(&quant_dim))) {
quant_dim_val = (*quant_dim.int_value_begin()).getSExtValue();
}
OpBuilder builder(op);
Type input_type =
fake_quant.getOperand(0).getType().cast<ShapedType>().getElementType();
return quant::GetQuantizedTypeAttr(
builder, input_type, min, max, quant_dim_val, bit_width_val,
builder.getBoolAttr(narrow_range_val), /*is_signed=*/true);
}
// Collects the input values of 'ops' that are defined outside of 'ops'.
void CollectInputs(llvm::ArrayRef<Operation*> ops,
llvm::SmallVectorImpl<Value>* inputs,
llvm::SmallVectorImpl<Attribute>* input_specs) {
for (Operation* op : ops) {
for (Value operand : op->getOperands()) {
if (std::find(inputs->begin(), inputs->end(), operand) != inputs->end()) {
continue;
}
if (Operation* def_op = operand.getDefiningOp()) {
if (std::find(ops.begin(), ops.end(), def_op) == ops.end()) {
inputs->push_back(operand);
}
} else { // argument value
inputs->push_back(operand);
}
}
}
for (Value input : *inputs) {
ShapedType input_type = input.getType().cast<ShapedType>();
if (TypeAttr spec = GetQuantSpec(input.getDefiningOp())) {
input_specs->push_back(spec);
} else {
input_specs->push_back(TypeAttr::get(input_type.getElementType()));
}
}
}
// Collects values that are produced by 'ops' and have uses outside of 'ops'.
// TODO(fengliuai): if it is a single user and QDQ, write that to the specs.
void CollectRets(llvm::ArrayRef<Operation*> ops,
llvm::SmallVectorImpl<Value>* rets,
llvm::SmallVectorImpl<Type>* ret_types,
llvm::SmallVectorImpl<Attribute>* ret_specs) {
for (Operation* op : ops) {
// The constant will not be shared outside the region.
if (llvm::isa<ConstantOp>(op)) continue;
for (Value result : op->getResults()) {
for (Operation* user : result.getUsers()) {
// If there is any user outside of 'ops',
if (std::find(ops.begin(), ops.end(), user) == ops.end()) {
ShapedType ret_type = result.getType().cast<ShapedType>();
rets->push_back(result);
ret_types->push_back(ret_type);
if (TypeAttr spec = GetQuantSpec(user)) {
ret_specs->push_back(spec);
} else {
ret_specs->push_back(TypeAttr::get(ret_type.getElementType()));
}
break;
}
}
}
}
}
enum FusedActivationFunc { NONE, RELU, RELU1, RELU6 };
#define FLOAT_EQ(value, x) (fabs((value) - (x)) <= 1e-6)
// If the op is max(in, 0.0), we treat it as coming from Relu, so both this op
// and the constant 0.0 will be fused.
// If the op is clamp(0.0, in, 1.0) or clamp(0.0, in, 6.0), we treat it as
// coming from Relu1 or Relu6, so all the constants and this op will be fused.
// Returns the fused activation function type.
FusedActivationFunc FuseReluX(Operation* op,
llvm::SmallVectorImpl<Operation*>* fused) {
if (auto max = llvm::dyn_cast<xla_hlo::MaxOp>(op)) {
Value min_val = max.rhs();
llvm::SmallVector<Operation*, 4> broadcast_ops;
if (auto broadcast = llvm::dyn_cast_or_null<xla_hlo::BroadcastInDimOp>(
min_val.getDefiningOp())) {
min_val = broadcast.operand();
broadcast_ops.push_back(broadcast);
}
DenseFPElementsAttr min;
if (!matchPattern(min_val, m_Constant(&min))) {
// In case the min value is lhs.
min_val = max.lhs();
broadcast_ops.clear();
if (auto broadcast = llvm::dyn_cast_or_null<xla_hlo::BroadcastInDimOp>(
min_val.getDefiningOp())) {
min_val = broadcast.operand();
broadcast_ops.push_back(broadcast);
}
if (!matchPattern(min_val, m_Constant(&min))) {
return NONE;
}
}
if (!min.isSplat() ||
!(FLOAT_EQ(min.getSplatValue().cast<FloatAttr>().getValueAsDouble(),
0.0))) {
return NONE;
}
// Include the constant 0.0 as well, so it doesn't get quantized.
fused->push_back(min_val.getDefiningOp());
fused->append(broadcast_ops.begin(), broadcast_ops.end());
fused->push_back(max);
return RELU;
}
if (auto clamp = llvm::dyn_cast<xla_hlo::ClampOp>(op)) {
DenseFPElementsAttr lower, upper;
if (!matchPattern(clamp.min(), m_Constant(&lower)) ||
!matchPattern(clamp.max(), m_Constant(&upper)) || !lower.isSplat() ||
!upper.isSplat() ||
!(FLOAT_EQ(lower.getSplatValue().cast<FloatAttr>().getValueAsDouble(),
0.0))) {
return NONE;
}
double upper_value =
upper.getSplatValue().cast<FloatAttr>().getValueAsDouble();
if (FLOAT_EQ(upper_value, 1.0) || FLOAT_EQ(upper_value, 6.0)) {
fused->push_back(clamp.min().getDefiningOp());
fused->push_back(clamp.max().getDefiningOp());
fused->push_back(op);
return (FLOAT_EQ(upper_value, 1.0) ? RELU1 : RELU6);
}
}
return NONE;
}
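// Wraps the defining ops of 'results' (plus a trailing Relu/Relu1/Relu6, if
// present) into a single quant.region op whose logical kernel name is derived
// from 'kernel', and returns placeholder values for the root op's results.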
llvm::SmallVector<Value, 0> FuseOps(PatternRewriter* rewriter,
const std::initializer_list<Value>& results,
StringRef kernel) {
// Collect all the operations to be fused.
llvm::SmallVector<Operation*, 4> fused;
llvm::SmallVector<Location, 4> locs;
fused.reserve(results.size());
locs.reserve(results.size());
for (auto value : results) {
Operation* op = value.getDefiningOp();
fused.push_back(op);
locs.push_back(op->getLoc());
}
Operation* root = fused.back();
FusedActivationFunc act_func = FusedActivationFunc::NONE;
// If there is Relu, Relu1 or Relu6, fuse it as well.
if (results.size() > 0 && std::rbegin(results)->hasOneUse()) {
act_func = FuseReluX(*std::rbegin(results)->user_begin(), &fused);
}
// Collect inputs that come from outside the fused ops.
llvm::SmallVector<Value, 4> inputs;
llvm::SmallVector<Attribute, 4> input_specs;
CollectInputs(fused, &inputs, &input_specs);
// Collect outputs of the fused ops that are used outside.
llvm::SmallVector<Value, 4> rets;
llvm::SmallVector<Type, 4> ret_types;
llvm::SmallVector<Attribute, 4> ret_specs;
CollectRets(fused, &rets, &ret_types, &ret_specs);
// TODO(fengliuai): make activation function an attribute.
std::string kernel_name;
switch (act_func) {
case RELU:
kernel_name = llvm::Twine(kernel, "_relu").str();
break;
case RELU1:
kernel_name = llvm::Twine(kernel, "_relu1").str();
break;
case RELU6:
kernel_name = llvm::Twine(kernel, "_relu6").str();
break;
default:
kernel_name = kernel.str();
}
// Create the region op with the return.
auto region = rewriter->create<quant::QuantizeRegionOp>(
rewriter->getFusedLoc(locs), ret_types, inputs,
rewriter->getArrayAttr(input_specs), rewriter->getArrayAttr(ret_specs),
kernel_name);
auto* body = new Block();
region.body().push_back(body);
OpBuilder builder = OpBuilder::atBlockEnd(body);
BlockAndValueMapping mapping;
// Create block arguments and add them to the block value mapping.
for (Value input : inputs) {
mapping.map(input, body->addArgument(input.getType()));
}
// Clone the operations 'ops' to the region.
for (Operation* op : fused) {
builder.clone(*op, mapping);
}
llvm::SmallVector<Value, 4> new_rets;
new_rets.reserve(rets.size());
for (auto ret : llvm::enumerate(rets)) {
Value new_ret = mapping.lookupOrNull(ret.value());
assert(new_ret && "couldn't find return value.");
new_rets.push_back(new_ret);
ret.value().replaceAllUsesWith(region.getResult(ret.index()));
}
builder.create<quant::ReturnOp>(builder.getUnknownLoc(), new_rets);
LLVM_DEBUG({
assert(region.verify().Success && "failed to create quant region.");
llvm::dbgs() << "\ncreated region: ";
region.print(llvm::dbgs());
llvm::dbgs() << "\n\n\n";
});
// All uses of the fused ops are replaced, so the values in this vector
// will not be used.
SmallVector<Value, 0> new_values(root->getNumResults(), region.getResult(0));
return new_values;
}
struct CpuKernelFusionPass : public FunctionPass<CpuKernelFusionPass> {
explicit CpuKernelFusionPass() = default;
CpuKernelFusionPass(const CpuKernelFusionPass&) {}
void runOnFunction() override;
};
#include "tensorflow/compiler/mlir/lite/quantization/xla/generated_cpu_kernel_fusion.inc"
void CpuKernelFusionPass::runOnFunction() {
Operation* op = getOperation();
MLIRContext* ctx = op->getContext();
OwningRewritePatternList patterns;
populateWithGenerated(ctx, &patterns);
applyPatternsGreedily(op->getRegions(), patterns);
}
} // namespace
// Creates an instance of the xla_hlo cpu kernel fusion pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateCpuKernelFusionPass() {
return std::make_unique<CpuKernelFusionPass>();
}
static PassRegistration<CpuKernelFusionPass> pass(
"xla-hlo-cpu-fusion", "Fuse xla hlo ops into cpu kernels");
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,65 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
class Fused1Ops<string kernel> : NativeCodeCall<
"FuseOps(&$_builder, {$0}, \"" # kernel # "\")">;
class Fused2Ops<string kernel> : NativeCodeCall<
"FuseOps(&$_builder, {$0, $1}, \"" # kernel # "\")">;
class Fused3Ops<string kernel> : NativeCodeCall<
"FuseOps(&$_builder, {$0, $1, $2}, \"" # kernel # "\")">;
class Fused4Ops<string kernel> : NativeCodeCall<
"FuseOps(&$_builder, {$0, $1, $2, $3}, \"" # kernel # "\")">;
// We shouldn't revisit those ops which have been fused. This constraint is
// required because the greedy pattern rewriter will visit and match any new
// ops. So when the source patterns are matched and wrapped by the quant region
// op, these ops will be matched again. To prevent this, this constraint is
// added to bypass any ops inside a quant region.
def NeedsToBeFused : Constraint<CPred<
"!$0.getDefiningOp()->getParentOfType<quant::QuantizeRegionOp>()">>;
// dummy example
def : Pat<(HLO_AddOp:$add (HLO_MulOp:$mul $_, $_, $_), $_, $_),
(Fused2Ops<"generic.mul_add"> $mul, $add),
[(NeedsToBeFused $add)]>;
// reduce_window: maxpool, avgpool
def : Pat<(HLO_ReduceWindowOp:$reduce $_, $_, $_, $_, $_, $_, $_),
(Fused1Ops<"generic.reduce_window"> $reduce),
[(NeedsToBeFused $reduce)]>;
// reshape
def : Pat<(HLO_ReshapeOp:$reshape $_), (Fused1Ops<"generic.reshape"> $reshape),
[(NeedsToBeFused $reshape)]>;
// broadcast
def : Pat<(HLO_BroadcastInDimOp:$broadcast $_, $_),
(Fused1Ops<"generic.broadcast"> $broadcast),
[(NeedsToBeFused $broadcast)]>;
// dot -> add
def : Pat<(HLO_AddOp:$add (HLO_DotOp:$dot $_, $_, $_), $_, $_),
(Fused2Ops<"generic.biased_dot"> $dot, $add),
[(NeedsToBeFused $add)]>;
// conv -> add
def : Pat<(HLO_AddOp:$add
(HLO_ConvOp:$conv $_, $_, $_, $_, $_, $_, $_, $_, $_, $_), $_, $_),
(Fused2Ops<"generic.biased_conv"> $conv, $add),
[(NeedsToBeFused $add)]>;

View File

@ -1,174 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass quantizes the constants and rewrites the
// quantization ops using xla_hlo primitive ops.
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
//===----------------------------------------------------------------------===//
// The pass to materialize the quantization results by xla primitive ops.
//
namespace mlir {
namespace xla_hlo {
namespace {
// This pattern matches the "constant->qcast->dcast" pattern and replaces it
// with "quantized constant->xla_hlo.dequantize". If it only matches the
// "non-constant->qcast->dcast" pattern, it removes the "qcast->dcast" pair.
// We chain the pattern as a whole to bypass the type checks of the normal
// xla_hlo ops.
// TODO(fengliuai): make this pass work for bf16 input.
class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
public:
explicit RewriteDequantize(int64_t size, MLIRContext *context)
: OpRewritePattern<quant::DequantizeCastOp>(context), size_(size) {}
LogicalResult matchAndRewrite(quant::DequantizeCastOp op,
PatternRewriter &rewriter) const override {
// quant.dcast
// xla_hlo dequantize only takes min/max, so let's recover them from
// the quantization parameters.
Value dcast = op.arg();
auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType());
if (!type || !type.isa<quant::UniformQuantizedType>()) {
return failure();
}
auto qtype = type.cast<quant::UniformQuantizedType>();
double scale = qtype.getScale();
int64_t zero_point = qtype.getZeroPoint();
float min = scale * (qtype.getStorageTypeMin() - zero_point);
float max = scale * (qtype.getStorageTypeMax() - zero_point);
// quant.qcast
auto qcast =
llvm::dyn_cast_or_null<quant::QuantizeCastOp>(dcast.getDefiningOp());
if (!qcast) return failure();
// constant
DenseFPElementsAttr attr;
// If it isn't a floating-point constant or the size is too small, let's
// remove the quantization. Also the last dimension size should be a
// multiple of 4, so the shape isn't broken during packing and unpacking.
if (!matchPattern(qcast.arg(), m_Constant(&attr)) ||
attr.getNumElements() <= size_ ||
attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) {
op.getResult().replaceAllUsesWith(qcast.arg());
return success();
}
// TODO(fengliuai): implement transpose if it has high dimension.
// Create the quantized result
auto quantized_result =
quant::Quantize(attr, qtype).dyn_cast_or_null<DenseIntElementsAttr>();
if (!quantized_result) {
return failure();
}
// Pack the uint8 bits into uint32. The shape is changed from
// [n0, n1, ..., nk] to [n0, n1, ..., nk / 4].
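// For example, a [2, 8] uint8 tensor packs into a [2, 2] uint32 tensor.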
std::vector<uint8_t> raw_data;
for (auto d : quantized_result.getValues<uint8_t>()) {
raw_data.push_back(d);
}
// The packing might increase the data size because of padding.
auto packed_data = xla::PackToUint32<uint8_t>(raw_data);
auto packed_shape = attr.getType().getShape().vec();
int lower_dims = std::accumulate(
packed_shape.begin(),
std::next(packed_shape.begin(), packed_shape.size() - 1), 1,
std::multiplies<int>());
packed_shape[packed_shape.size() - 1] = packed_data.size() / lower_dims;
auto packed_type =
RankedTensorType::get(packed_shape, rewriter.getIntegerType(32));
auto packed_quantized_result =
DenseElementsAttr::get<uint32_t>(packed_type, packed_data);
auto quantized_constant =
rewriter.create<ConstantOp>(qcast.getLoc(), packed_quantized_result);
// Create the xla dequantize op with bf16 output
auto dequantized_type = RankedTensorType::get(attr.getType().getShape(),
rewriter.getBF16Type());
auto dequantize = rewriter.create<DequantizeOp>(
qcast.getLoc(), dequantized_type, quantized_constant,
rewriter.getF32FloatAttr(min), rewriter.getF32FloatAttr(max),
rewriter.getStringAttr("MIN_COMBINED"), rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
// Convert bf16 output back to f32
rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getResult().getType(),
dequantize);
return success();
}
private:
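// Constants with at most this many elements are left unquantized.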
int64_t size_;
};
// Materialize the quantization results by hlo primitive ops.
struct MaterializeToXlaPass : public FunctionPass<MaterializeToXlaPass> {
explicit MaterializeToXlaPass() = default;
MaterializeToXlaPass(const MaterializeToXlaPass &) {}
void runOnFunction() override;
};
void MaterializeToXlaPass::runOnFunction() {
FuncOp func = getFunction();
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
// TODO(fengliuai): make the size 6 configurable.
patterns.insert<RewriteDequantize>(6, ctx);
applyPatternsGreedily(func, patterns);
}
} // namespace
// Creates an instance of the xla_hlo dialect quantization materialization pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass() {
return std::make_unique<MaterializeToXlaPass>();
}
static PassRegistration<MaterializeToXlaPass> pass(
"xla-hlo-materialize-quant",
"Materialize the quantization results by xla primitve ops");
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,7 +0,0 @@
// TODO(fengliuai): automatically generate this file
// TODO(fengliuai): add all the xla_hlo ops
static std::unique_ptr<quant::OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
auto spec = absl::make_unique<quant::OpQuantSpec>();
return spec;
}

View File

@ -1,37 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
#include <memory>
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {
namespace xla_hlo {
// Propagate the quantization information to all the tensors according to the
// op quant spec.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass();
// Rewrite the graph and quantize the constant.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass();
} // namespace xla_hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_

View File

@ -1,107 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass applies quantization propagation on xla_hlo dialect.
#include <iterator>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_context.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/cpu_device_target.h"
// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
"xla-disable-per-channel", llvm::cl::value_desc("bool"),
llvm::cl::desc("Whether disable per-channel quantized weights."),
llvm::cl::init(false));
//===----------------------------------------------------------------------===//
// The quantization propagation Pass.
//
namespace mlir {
namespace xla_hlo {
namespace {
// Applies quantization propagation on the input function. During the
// propagation, two constraints are respected:
// - The quantization types (params) already assigned to ops in the function
// - The quantization specs for the ops
// The propagation result should assign quantization types to all the tensors
// while respecting these two constraints.
struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
explicit PropagateQuantPass() = default;
PropagateQuantPass(const PropagateQuantPass &) {}
void runOnFunction() override;
};
#include "tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc"
void PropagateQuantPass::runOnFunction() {
FuncOp func = getFunction();
// TODO(fengliuai): deprecate this old code generation path.
// XLA only supports uint8/uint16 quantization for now.
ApplyQuantizationParamsPropagation(func, /*is_signed*/ false,
disable_per_channel, GetOpQuantSpec);
CpuDeviceTarget spec(&getContext());
quant::QuantizeContext ctx(func, spec);
std::vector<quant::QuantizeRegionOp> work_list = ctx.GetAllOps();
bool changed = false;
while (!work_list.empty()) {
quant::QuantizeRegionOp op = work_list.back();
work_list.pop_back();
llvm::SmallVector<Operation *, 4> new_items;
if (failed(ctx.Handle(op, &new_items, &changed))) {
// The IR is valid, so handling should not fail; if it does, signal failure.
signalPassFailure();
}
for (auto item : new_items) {
if (auto reg = llvm::dyn_cast_or_null<quant::QuantizeRegionOp>(item))
work_list.push_back(reg);
}
}
if (!changed) return;
if (failed(ctx.Finalize())) {
signalPassFailure();
}
}
} // namespace
// Creates an instance of the xla_hlo dialect quantization propagation pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass() {
return std::make_unique<PropagateQuantPass>();
}
static PassRegistration<PropagateQuantPass> pass(
"xla-hlo-propagate-quant", "Propagate quantization information");
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,74 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
namespace mlir {
namespace xla_hlo {
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
return true;
}();
(void)init_once;
}
// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
computation->Snapshot());
RegisterDialects();
MLIRContext context;
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
auto status = xla::ConvertHloToMlirHlo(
module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
if (!status.ok()) {
LOG(ERROR) << "Hlo module import failed: " << status;
return status;
}
PassManager pm(&context);
pm.addPass(createCanonicalizerPass());
pm.addPass(createInlinerPass());
pm.addPass(createSymbolDCEPass());
pm.addNestedPass<FuncOp>(createCSEPass());
mlir::StatusScopedDiagnosticHandler diag_handler(&context);
LogicalResult result = pm.run(module.get());
(void)result;
module->dump();
return tensorflow::Status::OK();
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -1,35 +0,0 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [
":graph_config_files",
":test_utilities",
],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
},
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/aot:tfcompile",
"//tensorflow/compiler/mlir:tf-opt",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)
# Bundle together all the graph files that are used by the tests.
filegroup(
name = "graph_config_files",
srcs = glob(
["**/*.pbtxt"],
),
)

View File

@ -1,199 +0,0 @@
// RUN: tf-opt -xla-hlo-cpu-fusion %s | FileCheck %s
// CHECK-LABEL: @mul_add_source
func @mul_add_source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
// CHECK: %[[region:.*]] = "quant.region"(%arg0, %arg1, %arg2) ( {
// CHECK: ^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<4xf32>) -> ()
// CHECK: }) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[region]] : tensor<4xf32>
}
// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) {
%cst = constant dense<0.0> : tensor<f32>
%cst_0 = constant dense<255.0> : tensor<f32>
%cst_1 = constant dense<8> : tensor<i32>
%cst_2 = constant dense<false> : tensor<i1>
%qin = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.1"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
%qw = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.2"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
%0 = "xla_hlo.multiply"(%qin, %qw) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
%1 = "xla_hlo.add"(%0, %arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
%r = "xla_hlo.custom_call"(%1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars",
has_side_effect = false, name = "custom-call.3"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
return %r : tensor<2x4xf32>
// CHECK: %[[region:.*]] = "quant.region"
// CHECK: ^bb0(%arg3: tensor<2x4xf32>, %arg4: tensor<2x4xf32>, %arg5: tensor<2x4xf32>): // no predecessors
// CHECK: %[[mul:.*]] = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: %[[add:.*]] = xla_hlo.add %[[mul]], %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<2x4xf32>
// CHECK: "quant.return"(%[[add]]) : (tensor<2x4xf32>) -> ()
// CHECK: }) {input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32],
// CHECK-SAME: logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]} :
// CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[r:.*]] = "xla_hlo.custom_call"(%[[region]]
// CHECK: return %[[r]] : tensor<2x4xf32>
}
// CHECK-LABEL: @reduce_window
func @reduce_window(%arg0: tensor<1x28x28x32xf32>, %arg1: tensor<f32>) -> (tensor<1x14x14x32xf32>) {
%0 = "xla_hlo.reduce_window"(%arg0, %arg1) ({
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%1 = xla_hlo.maximum %arg2, %arg3 : tensor<f32>
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<1x28x28x32xf32>, tensor<f32>) -> tensor<1x14x14x32xf32>
return %0 : tensor<1x14x14x32xf32>
// CHECK: "quant.region"(%arg0, %arg1) ( {
// CHECK: ^bb0(%arg2: tensor<1x28x28x32xf32>, %arg3: tensor<f32>): // no predecessors
// CHECK: %[[rw:.*]] = "xla_hlo.reduce_window"(%arg2, %arg3) ( {
// CHECK: ^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>): // no predecessors
// CHECK: %[[max:.*]] = xla_hlo.maximum %arg4, %arg5 : tensor<f32>
// CHECK: "xla_hlo.return"(%[[max]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK: "quant.return"(%[[rw]])
// CHECK: }) {input_specs = [f32, f32], logical_kernel = "generic.reduce_window", output_specs = [f32]}
}
// CHECK-LABEL: @reshape
func @reshape(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32>
return %0 : tensor<1x3136xf32>
// CHECK: "quant.region"(%arg0)
// CHECK: logical_kernel = "generic.reshape"
}
// CHECK-LABEL: @broadcast
func @broadcast(%arg0: tensor<64xf32>) -> (tensor<1x14x14x64xf32>) {
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<1x14x14x64xf32>
return %0 : tensor<1x14x14x64xf32>
// CHECK: "quant.region"(%arg0)
// CHECK: logical_kernel = "generic.broadcast"
}
// CHECK-LABEL: @biased_dot
func @biased_dot(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x10xf32>, %arg2: tensor<1x10xf32>) -> (tensor<1x10xf32>) {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1024xf32>, tensor<1024x10xf32>) -> tensor<1x10xf32>
%1 = xla_hlo.add %0, %arg2 : tensor<1x10xf32>
return %1 : tensor<1x10xf32>
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
// CHECK: xla_hlo.dot
// CHECK: xla_hlo.add
// CHECK: logical_kernel = "generic.biased_dot"
}
// CHECK-LABEL: @biased_conv
func @biased_conv(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>) {
%0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>,
padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
} : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32>
%1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32>
return %1 : tensor<1x14x14x64xf32>
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
// CHECK: xla_hlo.conv
// CHECK: xla_hlo.add
// CHECK: logical_kernel = "generic.biased_conv"
}
// CHECK-LABEL: @biased_dot_relu
func @biased_dot_relu(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x10xf32>, %arg2: tensor<1x10xf32>) -> (tensor<1x10xf32>) {
%cst = constant dense<0.0> : tensor<1x10xf32>
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1024xf32>, tensor<1024x10xf32>) -> tensor<1x10xf32>
%1 = xla_hlo.add %0, %arg2 : tensor<1x10xf32>
%2 = xla_hlo.maximum %1, %cst : tensor<1x10xf32>
return %2 : tensor<1x10xf32>
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
// CHECK: constant
// CHECK: xla_hlo.dot
// CHECK: xla_hlo.add
// CHECK: xla_hlo.maximum
// CHECK: logical_kernel = "generic.biased_dot_relu"
}
// CHECK-LABEL: @biased_conv_relu
func @biased_conv_relu(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>) {
%cst = constant dense<0.0> : tensor<1x14x14x64xf32>
%0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>,
padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
} : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32>
%1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32>
%2 = xla_hlo.maximum %1, %cst : tensor<1x14x14x64xf32>
return %2 : tensor<1x14x14x64xf32>
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
// CHECK: constant
// CHECK: xla_hlo.conv
// CHECK: xla_hlo.add
// CHECK: xla_hlo.maximum
// CHECK: logical_kernel = "generic.biased_conv_relu"
}
// CHECK-LABEL: @biased_conv_relu_shared
func @biased_conv_relu_shared(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>) {
%cst = constant dense<0.0> : tensor<1x14x14x64xf32>
%0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>,
padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
} : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32>
%1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32>
%2 = xla_hlo.maximum %1, %cst : tensor<1x14x14x64xf32>
return %cst, %2 : tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
// CHECK: constant
// CHECK: xla_hlo.conv
// CHECK: xla_hlo.add
// CHECK: %[[max:.*]] = xla_hlo.maximum
// CHECK: "quant.return"(%[[max]])
// CHECK: logical_kernel = "generic.biased_conv_relu"
}
// CHECK-LABEL: @biased_conv_relu6
func @biased_conv_relu6(%arg0: tensor<1x14x14x32xf32>, %arg1: tensor<5x5x32x64xf32>, %arg2: tensor<1x14x14x64xf32>) -> (tensor<1x14x14x64xf32>) {
%min = constant dense<0.0> : tensor<1x14x14x64xf32>
%max = constant dense<6.0> : tensor<1x14x14x64xf32>
%0 = "xla_hlo.conv"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilations = dense<1> : tensor<2xi64>,
padding = dense<2> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilations = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
} : (tensor<1x14x14x32xf32>, tensor<5x5x32x64xf32>) -> tensor<1x14x14x64xf32>
%1 = xla_hlo.add %0, %arg2 : tensor<1x14x14x64xf32>
%2 = "xla_hlo.clamp"(%min, %1, %max) : (tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>, tensor<1x14x14x64xf32>) -> tensor<1x14x14x64xf32>
return %2 : tensor<1x14x14x64xf32>
// CHECK: "quant.region"(%arg0, %arg1, %arg2)
// CHECK: constant
// CHECK: constant
// CHECK: xla_hlo.conv
// CHECK: xla_hlo.add
// CHECK: xla_hlo.clamp
// CHECK: logical_kernel = "generic.biased_conv_relu6"
}

View File

@ -1,15 +0,0 @@
# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --experimental_quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure
# TODO(fengliuai): update this file with the progress of the implementation
// CHECK: func @main
// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
// CHECK: %cst_1 = constant dense<8> : tensor<i32>
// CHECK: %cst_2 = constant dense<false> : tensor<i1>
// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
// CHECK: return %4 : tuple<tensor<2x4xf32>>
// CHECK: }

View File

@ -1,26 +0,0 @@
feed {
id { node_name: "input0" }
shape {
dim { size: 2 }
dim { size: 4 }
}
}
feed {
id { node_name: "input1" }
shape {
dim { size: 2 }
dim { size: 4 }
}
}
fetch {
id { node_name: "Add/FakeQuantWithMinMaxVars" }
shape {
dim { size: 2 }
dim { size: 4 }
}
}
conversion_options {
custom_fake_quant_op_calls: true
}

View File

@ -1,218 +0,0 @@
node: {
name: "Add/FakeQuantWithMinMaxVars"
op: "FakeQuantWithMinMaxVars"
input: "Add"
input: "Add/FakeQuantWithMinMaxVars/min"
input: "Add/FakeQuantWithMinMaxVars/max"
attr: {
key: "num_bits"
value: {
i: 8
}
}
attr: {
key: "narrow_range"
value: {
b: false
}
}
}
node: {
name: "Add/FakeQuantWithMinMaxVars/min"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 0.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node: {
name: "Add/FakeQuantWithMinMaxVars/max"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 127.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "Add"
op: "Add"
input: "input0/FakeQuantWithMinMaxVars"
input: "input1/FakeQuantWithMinMaxVars"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node: {
name: "input0/FakeQuantWithMinMaxVars"
op: "FakeQuantWithMinMaxVars"
input: "input0"
input: "input0/FakeQuantWithMinMaxVars/min"
input: "input0/FakeQuantWithMinMaxVars/max"
attr: {
key: "num_bits"
value: {
i: 8
}
}
attr: {
key: "narrow_range"
value: {
b: false
}
}
}
node: {
name: "input0/FakeQuantWithMinMaxVars/min"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 0.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node: {
name: "input0/FakeQuantWithMinMaxVars/max"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 127.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node: {
name: "input1/FakeQuantWithMinMaxVars"
op: "FakeQuantWithMinMaxVars"
input: "input1"
input: "input1/FakeQuantWithMinMaxVars/min"
input: "input1/FakeQuantWithMinMaxVars/max"
attr: {
key: "num_bits"
value: {
i: 8
}
}
attr: {
key: "narrow_range"
value: {
b: false
}
}
}
node: {
name: "input1/FakeQuantWithMinMaxVars/min"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 0.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node: {
name: "input1/FakeQuantWithMinMaxVars/max"
op: "Const"
attr: {
key: "value"
value: {
tensor: {
dtype: DT_FLOAT
tensor_shape: {
}
float_val: 127.0
}
}
}
attr: {
key: "dtype"
value: {
type: DT_FLOAT
}
}
}
node {
name: "input1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 27
}

View File

@ -1,54 +0,0 @@
// RUN: tf-opt -xla-hlo-materialize-quant %s | FileCheck %s
// CHECK-LABEL: func @quantize_rewrite
func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK: %[[qcst:.*]] = constant dense<{{\[\[}}21004416], [-1056997248]]> : tensor<2x1xi32>
// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32,
// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16>
// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
%w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32>
%q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}
// CHECK-LABEL: func @quantize_small
func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<1x4xf32>
%w = constant dense<1.0> : tensor<1x4xf32>
%q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<1x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<1x4xf32>
return %mul: tensor<1x4xf32>
}
// CHECK-LABEL: func @quantize_non_cst
func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
%q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}
// CHECK-LABEL: func @quantize_non_4x
func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32>
%w = constant dense<1.0> : tensor<2x5xf32>
%q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x5xf32>
%mul = xla_hlo.multiply %arg0, %dq : tensor<2x5xf32>
return %mul: tensor<2x5xf32>
}

View File

@ -1,69 +0,0 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s --dump-input-on-failure
// -----
// CHECK-LABEL: @mul_add_source_no_params
func @mul_add_source_no_params(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
"quant.return"(%add) : (tensor<4xf32>) -> ()
}) {input_specs = [f32, f32, f32], logical_kernel = "generic.mul_add", output_specs = [f32]} :
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %region : tensor<4xf32>
// CHECK: input_specs = [f32, f32, f32]
// CHECK-SAME: output_specs = [f32]
}
// -----
// CHECK-LABEL: @mul_add_annotated_no_narrow_range
func @mul_add_annotated_no_narrow_range(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
"quant.return"(%add) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8:f32, 1.0:-128>, f32],
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %region : tensor<4xf32>
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8:f32, 1.000000e+00:-128>, f32]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}
// -----
// CHECK-LABEL: @mul_add_annotated
func @mul_add_annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
%region = "quant.region"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<4xf32>, %arg4: tensor<4xf32>, %arg5: tensor<4xf32>): // no predecessors
%mul = xla_hlo.multiply %arg3, %arg4 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
%add = xla_hlo.add %mul, %arg5 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<4xf32>
"quant.return"(%add) : (tensor<4xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0:-128>, !quant.uniform<i8<-127:127>:f32, 1.0:-128>, f32],
logical_kernel = "generic.mul_add", output_specs = [!quant.uniform<i8:f32, 1.0:-128>]} :
(tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %region : tensor<4xf32>
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>, !quant.uniform<i8<-127:127>:f32, 1.000000e+00:-128>, !quant.uniform<i32:f32, 1.000000e+00>]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00:-128>]
}
// -----
// CHECK-LABEL: @same_scale_1_1
func @same_scale_1_1(%arg0: tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>) {
%region = "quant.region"(%arg0) ( {
^bb0(%arg1: tensor<1x7x7x64xf32>): // no predecessors
%r = "xla_hlo.reshape"(%arg1) : (tensor<1x7x7x64xf32>) -> (tensor<1x3136xf32>)
"quant.return"(%r) : (tensor<1x3136xf32>) -> ()
}) {input_specs = [!quant.uniform<i8:f32, 1.0>], logical_kernel = "generic.reshape", output_specs = [f32]} : (tensor<1x7x7x64xf32>) -> tensor<1x3136xf32>
return %region : tensor<1x3136xf32>
// CHECK: input_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
// CHECK-SAME: output_specs = [!quant.uniform<i8:f32, 1.000000e+00>]
}

View File

@ -1,25 +0,0 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s
// CHECK-LABEL: func @mul
func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32>
%w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32>
%mul = xla_hlo.multiply %arg0, %w : tensor<2x2xf32>
return %mul: tensor<2x2xf32>
}
// CHECK-LABEL: func @add
func @add(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[b:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[b]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>) -> tensor<2xf32>
// CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg0, %[[dq]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[add]] : tensor<2x2xf32>
%b = constant dense<1.0> : tensor<2xf32>
%add = "xla_hlo.add"(%arg0, %b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
return %add: tensor<2x2xf32>
}

View File

@ -39,7 +39,7 @@ versions {
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME inputs = "input0,input1"
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "output"
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>

View File

@ -12,6 +12,7 @@ glob_lit_tests(
test_file_exts = [
"mlir",
"cc",
"json",
],
)
@ -24,6 +25,8 @@ filegroup(
":importer_test_min_max",
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
"//tensorflow/compiler/mlir/lite:json_to_flatbuffer",
"//tensorflow/lite/schema:schema.fbs",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -0,0 +1,83 @@
// RUN: json_to_flatbuffer %p/../../../../../lite/schema/schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>
// CHECK: return %[[RES0]] : tensor<256x32x32x16xf32>
{
version: 3,
operator_codes: [
{
builtin_code: "CONV_2D",
}
],
subgraphs: [
{
tensors: [
{
shape: [
256,
32,
32,
3
],
name: "arg0",
quantization: {
}
},
{
shape: [
16,
3,
3,
3
],
name: "arg1",
quantization: {
}
},
{
shape: [
0
],
name: "cst"
},
{
shape: [
256,
32,
32,
16
],
name: "output",
quantization: {
}
},
],
inputs: [
0,
1
],
outputs: [
3
],
operators: [
{
inputs: [
0,
1,
-1
],
outputs: [
3
],
builtin_options_type: "Conv2DOptions",
builtin_options: {
}
}
],
name: "main"
}
],
description: "MLIR Converted."
}

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail
// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s --dump-input=fail
// Inline a function that contains only tfl ops.
func @func_with_tfl_ops(%arg0 : tensor<2xi32>) -> tensor<2xi32> {

View File

@ -1,5 +1,5 @@
// RUN: tf-opt --tfl-legalize-tf-while %s -o - | FileCheck %s --dump-input-on-failure
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline --mlir-disable-inline-simplify | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline="disable-simplify" | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline | FileCheck %s --dump-input-on-failure --check-prefix=CANON
func @while_main(%arg0: tensor<?x256x256xf32>) -> (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>) attributes {tf.entry_function = {inputs = "input", outputs = "Identity,Identity_1,Identity_2"}} {

View File

@ -9,6 +9,20 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: return
}
// CHECK-LABEL: testAddHighDimsHaveSameShape
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
}
// CHECK-LABEL: testAddTooHighBroadcastableDims
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32>
@ -1448,7 +1462,7 @@ func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK return [[MUL]] : tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
@ -1459,5 +1473,5 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK return [[MUL]] : tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}

View File

@ -29,7 +29,7 @@ limitations under the License.
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir
@ -134,6 +134,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
if (pass_config.shape_inference) {
// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
// Legalize while early to allow further constant folding.
// TODO(jpienaar): This may not actually matter as we do canonicalization
// after the legalize below, for now it needs to be below the above passes
@ -160,11 +164,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// constant ops.
pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
if (pass_config.shape_inference) {
// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
// The below passes only make sense if Builtin TFLite ops are enabled
// for emission.
if (pass_config.emit_builtin_tflite_ops) {
@ -173,7 +172,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
pass_manager->addPass(
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
// This pass operates on TensorFlow ops but is triggered after legalization
// so that it can target constants introduced once TensorFlow Identity ops
@ -255,7 +255,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
// TFLite dialect passes.
pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addPass(mlir::TFL::CreateLegalizeTFPass());
pm.addPass(
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
pm.addPass(mlir::TFL::CreateOptimizePass());
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
@ -268,7 +269,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
}
// Registers a pass pipeline for the standard TFL passes.

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iostream>
#include "absl/strings/str_split.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
@ -214,7 +216,7 @@ int main(int argc, char **argv) {
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
std::string result;
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(

View File

@ -40,7 +40,7 @@ limitations under the License.
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OpPassBase<FuncOp>>
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir

View File

@ -44,7 +44,8 @@ namespace TFL {
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
namespace {
class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
class DefaultQuantParamsPass
: public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
public:
explicit DefaultQuantParamsPass(double default_min, double default_max)
: default_min_(default_min), default_max_(default_max) {}
@ -220,7 +221,7 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
}
// Creates an instance of the default quant parameters pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max) {
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
}
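The base-class change above is the pattern this commit applies across the TFL passes: concrete passes now derive from PassWrapper instead of deriving FunctionPass<T> or OperationPass<T, ModuleOp> directly, and the factory functions return OperationPass<...> rather than OpPassBase<...>. A minimal sketch of the pattern, with an illustrative pass name and header paths approximated for this era of MLIR (not taken from the commit):
#include <memory>

#include "mlir/IR/Function.h"   // mlir::FuncOp (header path approximate)
#include "mlir/Pass/Pass.h"     // PassWrapper, FunctionPass, OperationPass

namespace {
// New style: the pass derives from PassWrapper<Derived, BaseKind> instead of
// deriving mlir::FunctionPass<Derived> (or OperationPass<Derived, ModuleOp>).
struct ExamplePass : public mlir::PassWrapper<ExamplePass, mlir::FunctionPass> {
  void runOnFunction() override {
    // The pass body is unchanged by the migration; only the base class moves.
  }
};
}  // namespace

// Factories now return OperationPass<FuncOp> (or OperationPass<ModuleOp>),
// matching the declarations updated in passes.h later in this diff.
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> CreateExamplePass() {
  return std::make_unique<ExamplePass>();
}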

View File

@ -29,7 +29,7 @@ namespace TFL {
namespace {
struct DenseToSparse : public FunctionPass<DenseToSparse> {
struct DenseToSparse : public PassWrapper<DenseToSparse, FunctionPass> {
void runOnFunction() override;
};
@ -63,7 +63,7 @@ void DenseToSparse::runOnFunction() {
} // namespace
// Creates an instance of the TensorFlow Lite dialect DenseToSparse pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass() {
std::unique_ptr<OperationPass<FuncOp>> CreateDenseToSparsePass() {
return absl::make_unique<DenseToSparse>();
}

View File

@ -18,7 +18,8 @@ namespace mlir {
namespace TFL {
namespace {
struct IdentifyDilatedConvPass : public FunctionPass<IdentifyDilatedConvPass> {
struct IdentifyDilatedConvPass
: public PassWrapper<IdentifyDilatedConvPass, FunctionPass> {
void runOnFunction() override;
};

View File

@ -679,7 +679,8 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
return success();
}
struct ExtractOphintPass : public OperationPass<ExtractOphintPass, ModuleOp> {
struct ExtractOphintPass
: public PassWrapper<ExtractOphintPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
void Verify();
@ -752,7 +753,7 @@ void ExtractOphintPass::Verify() {
/// Creates an instance of the TensorFlow Lite dialect ExtractOphintPass
/// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateExtractOphintPass() {
return std::make_unique<ExtractOphintPass>();
}

View File

@ -69,7 +69,7 @@ constexpr char kUnidirectionalSequenceLstm[] = "UnidirectionalSequenceLstm";
// |
// OutputOp1
struct LegalizeOphintFuncOpPass
: public OperationPass<LegalizeOphintFuncOpPass, ModuleOp> {
: public PassWrapper<LegalizeOphintFuncOpPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -284,7 +284,7 @@ void LegalizeOphintFuncOpPass::runOnOperation() {
/// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
/// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeOphintFuncOpPass() {
return std::make_unique<LegalizeOphintFuncOpPass>();
}

View File

@ -70,8 +70,21 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
// Legalize operations in functions.
struct LegalizeTF : public FunctionPass<LegalizeTF> {
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;
LegalizeTF(const LegalizeTF&) {}
explicit LegalizeTF(bool run_tfl_runtime_verification) {
run_tfl_runtime_verification_ = run_tfl_runtime_verification;
}
/// Performs the lowering to TFLite dialect.
void runOnFunction() override;
private:
Option<bool> run_tfl_runtime_verification_{
*this, "run-tfl-runtime-verification",
llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
};
// Returns true if all tensor value in `values` has static shape and same shape.
@ -314,7 +327,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
// can't do any padding. Instead we just return it.
return attribute;
}
for (auto idx : dense_elem_attr.getIntValues()) {
for (const auto& idx : dense_elem_attr.getIntValues()) {
padded_val.push_back(idx.getSExtValue());
}
auto attr_dim_count = ranked_attr_type.getShape()[0];
@ -440,7 +453,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
if (!matchPattern(tf_matrix_diag_v2_or_v3_op.padding_value(),
m_Constant(&padding_value)))
return false;
for (auto value : padding_value.getValues<APInt>()) {
for (const auto& value : padding_value.getValues<APInt>()) {
if (value != 0) return false;
}
@ -741,13 +754,19 @@ void LegalizeTF::runOnFunction() {
// graph.
target.addLegalOp<mlir::ConstantOp>();
target.addLegalOp<ConstOp>();
target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
if (!tfl_op) return false;
return succeeded(tfl_op.VerifyTflRuntimeTypes(
tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
}));
if (run_tfl_runtime_verification_) {
target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
Optional<ConversionTarget::DynamicLegalityCallbackFn>(
[](Operation* op) {
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
if (!tfl_op) return false;
return succeeded(tfl_op.VerifyTflRuntimeConstraints(
tfl_op.getOperation(),
/*failure_on_operand_type_mismatch=*/false));
}));
} else {
target.addLegalDialect<TensorFlowLiteDialect>();
}
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
// Look if there is a function that tries until it converge.
@ -763,8 +782,9 @@ void LegalizeTF::runOnFunction() {
} // namespace
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFPass() {
return std::make_unique<LegalizeTF>();
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
bool run_tfl_runtime_verification) {
return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
}
static PassRegistration<LegalizeTF> pass(
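For context on the new knob: the pass can be constructed from C++ with an explicit flag (as the pipeline code earlier in this diff does with /*run_tfl_runtime_verification=*/true), while instances built through the registration above pick up the declarative Option, whose default is true per the llvm::cl::init(true) declaration. A minimal sketch of the C++ route, with an illustrative helper name that is not from this commit:
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Adds TF->TFL legalization without TFL runtime verification, so ops are
// converted to TFL builtins even when the TFL runtime capability check would
// otherwise leave the incompatible TF ops unlegalized.
void AddLegalizeWithoutRuntimeChecks(mlir::OpPassManager& pm) {
  pm.addPass(mlir::TFL::CreateLegalizeTFPass(
      /*run_tfl_runtime_verification=*/false));
}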

View File

@ -31,7 +31,8 @@ namespace {
// Legalize TF While to TFL While with calls to the original functions from the
// cond and body regions.
struct LegalizeWhile : public OperationPass<LegalizeWhile, ModuleOp> {
struct LegalizeWhile
: public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
void RunOnFunction(FuncOp func);
void runOnOperation() override {
@ -76,7 +77,7 @@ void LegalizeWhile::RunOnFunction(FuncOp func) {
}
// Creates an instance of the TensorFlow While to TFLite While pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass() {
return std::make_unique<LegalizeWhile>();
}

View File

@ -42,7 +42,8 @@ namespace {
// AnyQuantizedType, thus bitwidth, narrow_range, etc are included. The op also
// defines the op quantization traits, which are used to propagate the
// quantization parameters by the following passes.
struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
struct LoadQuantizationRecipe
: public PassWrapper<LoadQuantizationRecipe, FunctionPass> {
void runOnFunction() override;
private:
@ -215,7 +216,7 @@ void LoadQuantizationRecipe::runOnFunction() {
// Creates an instance of the TensorFlow Lite dialect LoadQuantizationRecipe
// pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLoadQuantizationRecipePass() {
std::unique_ptr<OperationPass<FuncOp>> CreateLoadQuantizationRecipePass() {
return absl::make_unique<LoadQuantizationRecipe>();
}

View File

@ -82,7 +82,7 @@ class TensorListPatternRewriter : public PatternRewriter {
/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
: public OperationPass<LowerStaticTensorListPass, ModuleOp> {
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
// Apply type and op changes within a function.
@ -720,7 +720,7 @@ struct ConvertTensorListStack
RankedTensorType::get({-1}, rewriter.getIntegerType(32));
auto new_shape = rewriter.create<TF::ShapeOp>(loc, shape_type, input);
SmallVector<int64_t, 8> output_shape = {op.num_elements().getSExtValue()};
for (auto dim : dense_elem_attr.getIntValues())
for (const auto &dim : dense_elem_attr.getIntValues())
output_shape.push_back(dim.getSExtValue());
RankedTensorType result_type =
RankedTensorType::get(output_shape, getElementTypeOrSelf(input));
@ -906,7 +906,8 @@ void LowerStaticTensorListPass::runOnOperation() {
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass.
std::unique_ptr<OpPassBase<ModuleOp>> TFL::CreateLowerStaticTensorListPass() {
std::unique_ptr<OperationPass<ModuleOp>>
TFL::CreateLowerStaticTensorListPass() {
return std::make_unique<LowerStaticTensorListPass>();
}

View File

@ -74,7 +74,7 @@ bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
using ::llvm::cast;
// Optimize TFLite operations in functions.
struct Optimize : public FunctionPass<Optimize> {
struct Optimize : public PassWrapper<Optimize, FunctionPass> {
void runOnFunction() override;
};
@ -650,7 +650,7 @@ struct ConvertTrivialTransposeOpToReshapeOp
auto input_shape = input_type.getShape();
SmallVector<int64_t, 8> perm_values;
for (auto dim : perm_values_attr.getIntValues())
for (const auto &dim : perm_values_attr.getIntValues())
perm_values.push_back(dim.getSExtValue());
// This should never happen unless the input graph is malformed.
@ -725,7 +725,7 @@ void Optimize::runOnFunction() {
} // namespace
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateOptimizePass() {
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass() {
return std::make_unique<Optimize>();
}

View File

@ -36,7 +36,7 @@ using FuncSet = llvm::SmallSet<FuncOp, 4>;
// Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass
: public OperationPass<OptimizeFunctionalOpsPass, ModuleOp> {
: public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
@ -198,7 +198,7 @@ void OptimizeFunctionalOpsPass::runOnOperation() {
}
} // namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
return std::make_unique<OptimizeFunctionalOpsPass>();
}

View File

@ -24,75 +24,79 @@ namespace mlir {
class FuncOp;
class ModuleOp;
template <typename T>
class OpPassBase;
class OperationPass;
namespace TFL {
class QuantizationSpecs;
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFPass();
// When run_tfl_runtime_verification is true, each TFL builtin op is checked
// against the TFL runtime's capabilities, and TF ops that are incompatible
// with it are left in the graph without being legalized.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
bool run_tfl_runtime_verification);
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateOptimizePass();
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareTFPass(
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
bool unfold_batch_matmul);
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLowerStaticTensorListPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateLowerStaticTensorListPass();
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass();
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass();
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
const QuantizationSpecs& quant_specs);
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
bool emit_quant_adaptor_ops);
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
llvm::ArrayRef<std::string> trim_funcs_whitelist);
// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass();
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass();
// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateExtractOphintPass();
// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass
// pass. The composite op is created from the ophint extraction pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeOphintFuncOpPass();
// Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass.
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass();
std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass();
// Creates an instance of the TensorFlow Lite dialect OptimizeFunctionalOpsPass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();
// Creates an instance of the TensorFlow Lite dialect pass to add default
// quantization parameters.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max);
// Creates an instance of the TensorFlow Lite dialect pass to convert dense
// tensor to sparse format.
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass();
std::unique_ptr<OperationPass<FuncOp>> CreateDenseToSparsePass();
// Creates function pass to legalize TF While to TFL While.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeTFWhilePass();
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass();
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass();
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
// Verifies runtime supports types used.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass();
// Verifies runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
} // namespace TFL

View File

@ -30,7 +30,7 @@ namespace TFL {
namespace {
// Applies all the clean up steps after quantization.
class PostQuantizePass : public FunctionPass<PostQuantizePass> {
class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
public:
// Constructor used by the PassRegistration. This will remove the adaptor ops.
explicit PostQuantizePass() : emit_quant_adaptor_ops_(false) {}
@ -135,7 +135,7 @@ void PostQuantizePass::runOnFunction() {
} // namespace
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
bool emit_quant_adaptor_ops) {
return std::make_unique<PostQuantizePass>(emit_quant_adaptor_ops);
}

View File

@ -94,7 +94,8 @@ class ConvertEmbeddedLookupFunc {
// body with the corresponding fused TFLite op. The replacement need not always
// be a fused op, though that is the primary use case.
class PrepareCompositeFunctionsPass
: public OperationPass<PrepareCompositeFunctionsPass, ModuleOp> {
: public PassWrapper<PrepareCompositeFunctionsPass,
OperationPass<ModuleOp>> {
public:
explicit PrepareCompositeFunctionsPass() {}
@ -211,7 +212,7 @@ void PrepareCompositeFunctionsPass::runOnOperation() {
}
} // namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
return std::make_unique<PrepareCompositeFunctionsPass>();
}

View File

@ -66,7 +66,8 @@ namespace {
// across ops. This step is necessary for post-training quantization and also
// making the quantization rule for some operations in the quantization-aware
// training quantization simpler.
class PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
class PrepareQuantizePass
: public PassWrapper<PrepareQuantizePass, FunctionPass> {
public:
// Constructor used by the PassRegistration and enforce uint8 quantization.
explicit PrepareQuantizePass() {
@ -281,7 +282,7 @@ void PrepareQuantizePass::runOnFunction() {
} // namespace
// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareQuantizePass(
const QuantizationSpecs& quant_specs) {
return std::make_unique<PrepareQuantizePass>(quant_specs);
}

View File

@ -71,7 +71,7 @@ namespace TFL {
namespace {
// Prepare TF operations in functions for subsequent legalization.
class PrepareTFPass : public FunctionPass<PrepareTFPass> {
class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
public:
explicit PrepareTFPass() : unfold_batch_matmul_(true) {}
explicit PrepareTFPass(bool unfold_batch_matmul)
@ -652,7 +652,7 @@ void PrepareTFPass::runOnFunction() {
} // namespace
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareTFPass(
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
bool unfold_batch_matmul) {
return std::make_unique<PrepareTFPass>(unfold_batch_matmul);
}

View File

@ -75,7 +75,7 @@ struct TFLFullQuantization
};
// Applies quantization on the model in TFL dialect.
struct QuantizePass : public FunctionPass<QuantizePass> {
struct QuantizePass : public PassWrapper<QuantizePass, FunctionPass> {
void runOnFunction() override;
};
@ -93,7 +93,7 @@ void QuantizePass::runOnFunction() {
} // namespace
// Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass() {
std::unique_ptr<OperationPass<FuncOp>> CreateQuantizePass() {
return std::make_unique<QuantizePass>();
}

View File

@ -22,33 +22,32 @@ namespace mlir {
namespace TFL {
namespace {
// This pass verifies that the operands and results types are supported by
// TFLite runtime.
class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
// This pass verifies that the TFL ops meet the TFL runtime constraints.
class RuntimeVerifyPass
: public mlir::PassWrapper<RuntimeVerifyPass, FunctionPass> {
public:
explicit RuntimeTypeVerifyPass() {}
explicit RuntimeVerifyPass() {}
private:
void runOnFunction() override;
};
void RuntimeTypeVerifyPass::runOnFunction() {
void RuntimeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
if (failed(op.VerifyTflRuntimeTypes(
op.getOperation(),
/*failure_on_operand_type_mismatch=*/true)))
if (failed(op.VerifyTflRuntimeConstraints(
op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
signalPassFailure();
});
}
} // namespace
// Verifies runtime supports types used.
std::unique_ptr<OpPassBase<FuncOp>> CreateRuntimeTypeVerifyPass() {
return std::make_unique<RuntimeTypeVerifyPass>();
// Verifies TFL runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass() {
return std::make_unique<RuntimeVerifyPass>();
}
static PassRegistration<RuntimeTypeVerifyPass> pass(
"tfl-runtime-verify", "TFLite runtime verification");
static PassRegistration<RuntimeVerifyPass> pass("tfl-runtime-verify",
"TFLite runtime verification");
} // namespace TFL
} // namespace mlir

View File

@ -66,7 +66,8 @@ namespace mlir {
namespace TFL {
namespace {
struct SplitMergedOperandsPass : public FunctionPass<SplitMergedOperandsPass> {
struct SplitMergedOperandsPass
: public PassWrapper<SplitMergedOperandsPass, FunctionPass> {
void runOnFunction() override;
};
@ -119,7 +120,7 @@ void SplitMergedOperandsPass::runOnFunction() {
/// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands
/// pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass() {
std::unique_ptr<OperationPass<FuncOp>> CreateSplitMergedOperandsPass() {
return std::make_unique<SplitMergedOperandsPass>();
}

View File

@ -45,7 +45,7 @@ namespace {
// The pass to trim functions before we legalize to TFL
// dialect using the specified whitelist.
class TrimFunctionsPass
: public mlir::OperationPass<TrimFunctionsPass, ModuleOp> {
: public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
public:
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
@ -120,7 +120,7 @@ void TrimFunctionsPass::Verify() {
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
/// pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
llvm::ArrayRef<std::string> trim_funcs_whitelist) {
return std::make_unique<TrimFunctionsPass>(trim_funcs_whitelist);
}

View File

@ -38,7 +38,7 @@ namespace {
// This pass outlines the cond/body region of the TFL WhileOp into functions and
// replaces the regions with calls to these outlined functions.
class WhileOutlinePass
: public mlir::OperationPass<WhileOutlinePass, ModuleOp> {
: public mlir::PassWrapper<WhileOutlinePass, OperationPass<ModuleOp>> {
public:
explicit WhileOutlinePass() {}
@ -241,7 +241,7 @@ void WhileOutlinePass::runOnOperation() {
}
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
return std::make_unique<WhileOutlinePass>();
}

View File

@ -71,7 +71,7 @@ tool_dirs = config.mlir_tf_tools_dirs + [
tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
'mlir-tflite-runner', 'tfcompile'
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -1078,6 +1078,7 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
@ -1118,6 +1119,7 @@ tf_cc_test(
srcs = ["utils/compile_mlir_util_test.cc"],
deps = [
":compile_mlir_util",
"//tensorflow/compiler/jit",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:test",
@ -1329,6 +1331,7 @@ cc_library(
deps = [
":tensorflow",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",

View File

@ -108,8 +108,6 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
>();
addOperations<ParallelExecuteOp>();
addInterfaces<TFInlinerInterface>();
}
@ -161,22 +159,8 @@ LogicalResult Verify(ParallelExecuteOp op) {
int output_index = 0;
for (auto& region_and_index : llvm::enumerate(regions)) {
auto& region = region_and_index.value();
auto region_index = region_and_index.index();
// Each region must include a single block of ops and must not be empty.
if (region.empty()) {
return op.emitOpError()
<< "regions must not be empty. "
<< "Found an empty region (" << region_index << ").";
}
if (!has_single_element(region)) {
return op.emitOpError()
<< "regions must be composed of a single block of operations."
<< "Expected region (" << region_index << ") with 1 block.";
}
auto* region_terminator = region.front().getTerminator();
// Check that output types of regions match return operand types.
for (auto result_type : region_terminator->getOperandTypes()) {
if (result_type !=
@ -214,8 +198,6 @@ void ParallelExecuteOp::build(Builder* builder, OperationState& state,
state.addTypes(output_types);
}
LogicalResult ParallelExecuteOp::verify() { return Verify(*this); }
Block& ParallelExecuteOp::GetRegionBlockWithIndex(unsigned index) {
return getOperation()->getRegion(index).front();
}

View File

@ -43,47 +43,6 @@ class TensorFlowDeviceDialect : public Dialect {
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h.inc"
// TODO(b/148642767): Use tablegen to define tf_device.parallel_execute op once
// variadic regions can be expressed in tablegen.
//
// ParallelExecute op concurrently executes variadic number of regions. Regions
// must represent separate sets of instructions to execute concurrently. In
// order to represent concurrently executed regions with dependencies, multiple
// ParallelExecute ops can be used instead. As so, regions within
// ParallelExecute op must not have control/data dependencies. While explicit
// dependencies between regions are disallowed, ParallelExecute op does not
// prevent implicit communication between regions (e.g. communication via
// send/recvs). In this case, users of ParallelExecute op must provide correct
// control dependencies between regions to guarantee correctness. Regions in
// ParallelExecute may include Resource ops. In the case where different regions
// include ops access the same resource, the users of the ParallelExecute op
// must provide mechanism (via send/recvs or via control dependencies) to
// guarantee correct ordering. Sequential ordering of ops within a region is
// guaranteed. Also, sequential ordering of ops before/after ParallelExecute ops
// are guaranteed. That is, execution of regions inside ParallelExecute op is
// blocked until all inputs to all regions are materialized and ops following
// ParallelExecute op are blocked until all regions are executed.
class ParallelExecuteOp
: public Op<ParallelExecuteOp,
OpTrait::SingleBlockImplicitTerminator<ReturnOp>::Impl> {
public:
using Op::Op;
static void build(Builder* builder, OperationState& state, int num_regions,
llvm::ArrayRef<Type> output_types);
static StringRef getOperationName() { return "tf_device.parallel_execute"; }
LogicalResult verify();
Block& GetRegionBlockWithIndex(unsigned index);
Operation::result_range GetRegionOutputs(unsigned region_index);
// Checks if a tf_device.parallel_execute index'th region block wraps a single
// operation and the single operation results are perfectly forwarded to the
// region block's return.
bool RegionWrapsSingleOp(unsigned index);
};
} // namespace tf_device
} // namespace mlir

View File

@ -125,6 +125,55 @@ def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> {
}];
}
def TfDevice_ParallelExecuteOp : TfDevice_Op<"parallel_execute",
[SingleBlockImplicitTerminator<"ReturnOp">]> {
let description = [{
ParallelExecute op concurrently executes a variadic number of regions. Regions
must represent separate sets of instructions to execute concurrently. To
represent concurrently executed regions that do have dependencies, multiple
ParallelExecute ops can be chained instead; regions within a single
ParallelExecute op must not have control/data dependencies.
While explicit dependencies between regions are disallowed, ParallelExecute
op does not prevent implicit communication between regions (e.g.
communication via send/recvs). In this case, users of ParallelExecute op
must provide correct control dependencies between regions to guarantee
correctness. Regions in ParallelExecute may include Resource ops.
Where different regions include ops that access the same resource, the users
of the ParallelExecute op must provide a mechanism (via send/recvs or via
control dependencies) to guarantee correct ordering. Sequential ordering of
ops within a region is guaranteed, as is sequential ordering of ops
before/after a ParallelExecute op. That is, execution of the regions inside a
ParallelExecute op is blocked until all inputs to all regions are
materialized, and ops following the ParallelExecute op are blocked until all
regions have executed.
}];
let results = (outs
Variadic<AnyType>:$execute_outputs
);
let regions = (region VariadicRegion<SizedRegion<1>>:$regions);
let extraClassDeclaration = [{
Block& GetRegionBlockWithIndex(unsigned index);
Operation::result_range GetRegionOutputs(unsigned region_index);
// Checks if a tf_device.parallel_execute index'th region block wraps a
// single operation and the single operation results are perfectly forwarded
// to the region block's return.
bool RegionWrapsSingleOp(unsigned index);
}];
let builders = [
OpBuilder<"Builder* builder, OperationState& state, int num_regions,"
"llvm::ArrayRef<Type> output_types">,
];
let verifier = [{ return Verify(*this); }];
}
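With the op now defined in tablegen, the C++ entry points stay the same: the (num_regions, output_types) builder declared above plus GetRegionBlockWithIndex for filling in each region. A rough sketch of driving them, assuming the builder materializes one block per region as that accessor implies, with an illustrative helper name (not from this commit):
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"

// Creates a two-region tf_device.parallel_execute at the builder's insertion
// point; the caller populates each region and ends it with tf_device.return.
mlir::tf_device::ParallelExecuteOp BuildParallelExecute(
    mlir::OpBuilder& builder, mlir::Location loc,
    llvm::ArrayRef<mlir::Type> output_types) {
  auto op = builder.create<mlir::tf_device::ParallelExecuteOp>(
      loc, /*num_regions=*/2, output_types);
  for (unsigned i = 0; i < 2; ++i) {
    // Each region holds exactly one block (SizedRegion<1> above).
    mlir::Block& block = op.GetRegionBlockWithIndex(i);
    (void)block;  // Region bodies are left to the caller in this sketch.
  }
  return op;
}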
def TfDevice_ReplicateOp :
TfDevice_Op<"replicate", [SingleBlockImplicitTerminator<"ReturnOp">]> {
let summary = "Wraps an N-way replicated computation.";

View File

@ -208,7 +208,7 @@ static Type InferReductionOpType(Value input, Value reduction_indices,
int64_t num_reduce_dim = 0;
llvm::SmallVector<bool, 4> is_reduce_dim(rank, false);
for (APInt index : indices.getValues<APInt>()) {
for (const APInt &index : indices.getValues<APInt>()) {
int64_t dim = GetDimForAxis(index.getSExtValue(), rank);
// Invalid input.
if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty);
@ -404,11 +404,11 @@ static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
if (perm0.getNumElements() != perm1.getNumElements()) return false;
SmallVector<int64_t, 8> perm0_values;
for (auto value : perm0.getIntValues())
for (const auto &value : perm0.getIntValues())
perm0_values.push_back(value.getSExtValue());
SmallVector<int64_t, 8> perm1_values;
for (auto value : perm1.getIntValues())
for (const auto &value : perm1.getIntValues())
perm1_values.push_back(value.getSExtValue());
for (int i = 0; i < perm0_values.size(); ++i) {
@ -2548,12 +2548,15 @@ static LogicalResult Verify(SizeOp op) {
// SliceOp
//===----------------------------------------------------------------------===//
// Verifies that,
// Verifies that:
//
// - operands begin and size are 1D with the same number of elements.
// - if the input is a ranked tensor, the rank of the input equals the number
// of elements in operands begin and size.
// - if begin are constants, 0 <= begin[i] < input_ty.getShape()[i]
// - if begin are constants, that
// 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i]
// - if begins aren't constant but the input is a ranked tensor, that
// size[i] <= input_ty.getShape()[i]
//
static LogicalResult Verify(SliceOp op) {
RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
@ -2587,7 +2590,7 @@ static LogicalResult Verify(SliceOp op) {
bool constant_slice_sizes =
matchPattern(op.size(), m_Constant(&slice_sizes));
int dim = 0;
for (APInt raw_begin_index : begin_indices.getValues<APInt>()) {
for (const APInt &raw_begin_index : begin_indices.getValues<APInt>()) {
int64_t begin_index = raw_begin_index.getSExtValue();
int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
int64_t slice_size = constant_slice_sizes
@ -2603,6 +2606,20 @@ static LogicalResult Verify(SliceOp op) {
}
++dim;
}
} else if (input_ty) {
// If the inputs are ranked, we can do a few more sanity checks.
DenseIntElementsAttr slice_sizes;
if (matchPattern(op.size(), m_Constant(&slice_sizes))) {
auto input_shape = input_ty.getShape();
for (int64_t i = 0; i < input_ty.getRank(); ++i) {
int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
int64_t input_size = input_shape[i];
if (slice_size != -1 && input_size != -1 && slice_size > input_size) {
return op.emitOpError() << "requires size[i] <= Di, even if begin[i] "
"is unknown at compile time";
}
}
}
}
return success();
@ -3340,7 +3357,7 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x,
x_type.getDimSize((*attr_shape.begin()).getSExtValue()));
} else {
const_shape.reserve(attr_shape.getNumElements());
for (auto dim : attr_shape)
for (const auto &dim : attr_shape)
const_shape.push_back(x_type.getDimSize(dim.getSExtValue()));
}
return TransposeOp::build(

View File

@ -27,6 +27,7 @@ module {
// CHECK-LABEL: func @tpu0_func
// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
// CHECK-SAME: sym_visibility = "private"
// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]])
// CHECK: return %[[TPU0_FUNC_B_OUTPUT]]
}

View File

@ -93,7 +93,7 @@ library {
# CHECK: return
# CHECK: func @test_func_name0
# CHECK-SAME: tf.resource_arg_unique_id = 0
# CHECK-SAME tf.resource_arg_unique_id = 0
# CHECK-SAME: tf.resource_arg_unique_id = 0
# CHECK: tf_executor.graph
# CHECK: tf_executor.fetch
# CHECK: return

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -inline -mlir-disable-inline-simplify | FileCheck %s --dump-input=fail
// RUN: tf-opt %s -inline="disable-simplify" | FileCheck %s --dump-input=fail
// Test that simple TF operations can be inlined.

View File

@ -48,7 +48,7 @@ func @transpose_resnet_layer(%arg0: tensor<?x224x224x3xf32>, // input
} : (tensor<?x3x230x230xf32>, tensor<7x7x3x64xf32>) -> tensor<?x64x112x112xf32>
// CHECK: %[[CONV0:[0-9]*]] = "tf.Conv2D"
// CHECK-SAME %[[PAD]]
// CHECK-SAME: %[[PAD]]
// CHECK-SAME: data_format = "NHWC"
// CHECK-SAME: strides = [1, 2, 2, 1]

View File

@ -163,12 +163,12 @@ func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<
}
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.pow %arg0, %arg0 : tensor<2xf32>
%0 = xla_hlo.power %arg0, %arg0 : tensor<2xf32>
return %0 : tensor<2xf32>
}
func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = xla_hlo.pow %arg0, %arg0 : tensor<?xf32>
%0 = xla_hlo.power %arg0, %arg0 : tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -184,7 +184,7 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te
%8 = xla_hlo.constant dense<1> : tensor<3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<3xi32>
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.neg"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@ -203,7 +203,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
%8 = xla_hlo.constant dense<1> : tensor<2x3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32>
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.neg"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = xla_hlo.divide %11, %12 : tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@ -461,32 +461,32 @@ func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
}
func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "xla_hlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "xla_hlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -551,17 +551,17 @@ func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
}
func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "xla_hlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -577,17 +577,17 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
}
func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "xla_hlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -677,6 +677,11 @@ func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
return %0 : tensor<i64>
}
func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
return %0 : tensor<3xcomplex<f32>>
}
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @biasAdd_NHWC(
@ -1481,3 +1486,10 @@ func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
// CHECK: [[VAL_366:%.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
// CHECK: return [[VAL_366]] : tensor<i64>
// CHECK: }
// CHECK-LABEL: func @complex(
// CHECK-SAME: [[VAL_367:%.*]]: tensor<3xf32>, [[VAL_368:%.*]]: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
// CHECK: [[VAL_369:%.*]] = "tf.Complex"([[VAL_367]], [[VAL_368]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
// CHECK: return [[VAL_369]] : tensor<3xcomplex<f32>>
// CHECK: }

View File

@ -7,6 +7,7 @@ glob_lit_tests(
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"preserve-entry-func-names.mlir": ["nomac"], # TODO(b/148403706): flaky on Mac, to be fixed.
"tf_add.mlir": ["nomac"], # TODO(b/148403706): flaky on Mac, to be fixed.
},
test_file_exts = ["mlir"],
)

View File

@ -1,4 +1,4 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s --dump-input-on-failure
func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} {

View File

@ -116,8 +116,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-LABEL: func @shape_from_while_to_cond_body_functions
func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>, %arg1: tensor<!tf.resource<tensor<4xf32>>>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<4xf32> {
// CHECK "tf.While"
// CHECK-SAME (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>)
// CHECK: "tf.While"
// CHECK-SAME: (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>)
%0:3 = "tf.While"(%arg0, %arg1, %arg2) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>)
return %0#0 : tensor<4xf32>
}

View File

@ -1682,6 +1682,23 @@ func @testSlice_begin_out_of_bound(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// -----
func @testSlice_unknown_begin_out_of_bounds(%arg0: tensor<4xi32>, %begins: tensor<1xi64>) -> tensor<3xi32> {
%sizes = "tf.Const"() {value = dense<[5]> : tensor<1xi64>} : () -> (tensor<1xi64>)
// expected-error @+1 {{requires size[i] <= Di, even if begin[i] is unknown at compile time}}
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @testSlice_unknown_begin_in_bounds(%arg0: tensor<4xi32>, %begins: tensor<1xi64>) -> tensor<3xi32> {
%sizes = "tf.Const"() {value = dense<[4]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
// Valid StridedSlice operation.
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>

View File

@ -188,7 +188,7 @@ func @parallel_execute_single_region() {
// Check that a parallel_execute op with empty regions are not allowed.
func @parallel_execute_empty_region() {
"tf_device.parallel_execute"() ( {
// expected-error@-1 {{'tf_device.parallel_execute' op regions must not be empty. Found an empty region (0).}}
// expected-error@-1 {{'tf_device.parallel_execute' op region #0 ('regions') failed to verify constraint: region with 1 blocks}}
},
{
tf_device.return

View File

@ -51,12 +51,12 @@ module attributes {tf_saved_model.semantics} {
// Test case: Delete recursively dead cycle.
// CHECK-NOT func @recursively_dead0
// CHECK-NOT: func @recursively_dead0
func @recursively_dead0() {
"some_dialect.call"() { callee = @recursively_dead1 } : () -> ()
return
}
// CHECK-NOT func @recursively_dead1
// CHECK-NOT: func @recursively_dead1
func @recursively_dead1() {
"some_dialect.call"() { callee = @recursively_dead0 } : () -> ()
return

View File

@ -86,3 +86,21 @@ module attributes {tf_saved_model.semantics} {
}
}
// -----
module attributes {tf_saved_model.semantics} {
// CHECK-NOT: tf_saved_model.global_tensor
"tf_saved_model.global_tensor"() {sym_name = "v", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
"tf_saved_model.global_tensor"() {sym_name = "v2", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
func @f(%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @"v"}, %arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @"v2"})
attributes {tf_saved_model.exported_names = ["f"]} {
// CHECK: "tf.Const"()
%0 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: "tf.Const"()
%1 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
return
}
}

Some files were not shown because too many files have changed in this diff.