Merge branch 'master' of https://github.com/tensorflow/tensorflow into TF_GetName

This commit is contained in:
Daniel Nguyen 2020-07-21 19:54:17 +00:00
commit 7a113f40d9
342 changed files with 9271 additions and 11307 deletions

View File

@ -84,7 +84,8 @@
# release_gpu_common: Common options for GPU builds on Linux and Windows.
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux PU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
@ -570,3 +571,6 @@ build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_cpu_windows --config=release_common
build:release_cpu_windows --announce_rc

View File

@ -38,6 +38,7 @@
* Calling ops with python constants or numpy values is now consistent with
tf.convert_to_tensor behavior. This avoids operations like tf.reshape
truncating inputs such as from int64 to int32.
* Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensor` arguments.
* `tf.data`:
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.
@ -53,7 +54,8 @@
* `tf.function`/AutoGraph:
* <ADD RELEASE NOTES HERE>
* `tf.lite`:
* <ADD RELEASE NOTES HERE>
* Better support for ops with high-dimensional broadcasting inputs by adding
`BroadcastTo` ops when necessary.
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
@ -65,9 +67,9 @@
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* Other:
* We have replaced uses of "whitelist" with "allowlist" where possible.
Please see https://developers.google.com/style/word-list#blacklist for more
context.
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
https://developers.google.com/style/word-list#blacklist for more context.
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors

View File

@ -532,16 +532,14 @@ selects.config_setting_group(
package_group(
name = "internal",
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//third_party/swift/tensorflow_apis/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//third_party/swift/tensorflow_apis/...",
],
)

View File

@ -240,6 +240,8 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
@ -308,6 +310,8 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/util:abstract_stack_trace",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -33,6 +33,7 @@ limitations under the License.
using tensorflow::dyn_cast;
using tensorflow::string;
using tensorflow::gtl::ArraySlice;
namespace tensorflow {
namespace tracing {
@ -138,20 +139,23 @@ class GraphOperation : public TracingOperation {
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override {
return tensorflow::errors::Unimplemented(
"SetAttrString has not been implemented yet.");
tensorflow::StringPiece s(data, length);
op_->node_builder.Attr(attr_name, s);
return Status::OK();
}
Status SetAttrInt(const char* attr_name, int64_t value) override {
return tensorflow::errors::Unimplemented(
"SetAttrInt has not been implemented yet.");
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
op_->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
return Status::OK();
}
Status SetAttrFloat(const char* attr_name, float value) override {
return tensorflow::errors::Unimplemented(
"SetAttrFloat has not been implemented yet.");
op_->node_builder.Attr(attr_name, value);
return Status::OK();
}
Status SetAttrBool(const char* attr_name, bool value) override {
return tensorflow::errors::Unimplemented(
"SetAttrBool has not been implemented yet.");
op_->node_builder.Attr(attr_name, value);
return Status::OK();
}
Status SetAttrType(const char* const attr_name, DataType value) override {
if (!op_) {
@ -164,8 +168,15 @@ class GraphOperation : public TracingOperation {
}
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override {
return tensorflow::errors::Unimplemented(
"SetAttrShape has not been implemented yet.");
PartialTensorShape shape;
if (num_dims >= 0) {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
}
op_->node_builder.Attr(attr_name, shape);
return Status::OK();
}
Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override {
@ -174,8 +185,10 @@ class GraphOperation : public TracingOperation {
}
Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented yet.");
tensorflow::NameAttrList func_name;
func_name.set_name(string(value, value + length));
op_->node_builder.Attr(attr_name, func_name);
return Status::OK();
}
Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override {
@ -184,33 +197,71 @@ class GraphOperation : public TracingOperation {
}
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrStringList has not been implemented yet.");
if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
op_->colocation_constraints.clear();
for (int i = 0; i < num_values; ++i) {
op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
lengths[i]);
}
} else {
std::vector<tensorflow::StringPiece> v;
v.reserve(num_values);
for (int i = 0; i < num_values; ++i) {
v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
}
op_->node_builder.Attr(attr_name, v);
}
return Status::OK();
}
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrFloatList has not been implemented yet.");
op_->node_builder.Attr(attr_name,
ArraySlice<const float>(values, num_values));
return Status::OK();
}
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrIntList has not been implemented yet.");
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
op_->node_builder.Attr(
attr_name,
ArraySlice<const tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(values), num_values));
return Status::OK();
}
Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrTypeList has not been implemented yet.");
op_->node_builder.Attr(attr_name,
ArraySlice<const DataType>(values, num_values));
return Status::OK();
}
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrBoolList has not been implemented yet.");
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
op_->node_builder.Attr(attr_name,
ArraySlice<const bool>(b.get(), num_values));
return Status::OK();
}
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrShapeList has not been implemented yet.");
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_values);
for (int i = 0; i < num_values; ++i) {
if (num_dims[i] < 0) {
shapes.emplace_back();
} else {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
shapes.emplace_back(ArraySlice<tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
}
}
op_->node_builder.Attr(attr_name, shapes);
return Status::OK();
}
Status SetAttrFunctionList(
const char* attr_name,

View File

@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
@ -42,55 +44,10 @@ class CppGradients
}
};
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr identity_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(identity_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
return Status::OK();
}
// =================== Register gradients for Add ============================
class AddGradientFunction : public GradientFunction {
public:
explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id0"));
(*grad_outputs)[0] = identity_outputs[0];
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id1"));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK();
}
~AddGradientFunction() override {}
private:
AbstractContext* ctx_;
};
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction(op.ctx);
}
Status RegisterGradients(GradientRegistry* registry) {
return registry->Register("Add", AddRegisterer);
}
// =================== End gradient registrations ============================
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/abstract_stack_trace.h"
struct TFE_Op;
@ -44,6 +46,12 @@ class ImmediateExecutionOperation : public AbstractOperation {
// Experimental
virtual Status SetUseXla(bool enable) = 0;
// Set stack trace to be used for potential async error reporting.
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;
// Returns the stack trace set by `SetStackTrace` if exists.
virtual absl::optional<AbstractStackTrace> GetStackTrace() = 0;
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;

View File

@ -25,7 +25,9 @@ cc_library(
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":expiring_lru_cache",
":gcs_helper",
":ram_file_block_cache",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",

View File

@ -28,6 +28,27 @@ limitations under the License.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
// The environment variable that overrides the block size for aligned reads from
// GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes).
constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB";
constexpr size_t kDefaultBlockSize = 64 * 1024 * 1024;
// The environment variable that overrides the max size of the LRU cache of
// blocks read from GCS. Specified in MB.
constexpr char kMaxCacheSize[] = "GCS_READ_CACHE_MAX_SIZE_MB";
constexpr size_t kDefaultMaxCacheSize = 0;
// The environment variable that overrides the maximum staleness of cached file
// contents. Once any block of a file reaches this staleness, all cached blocks
// will be evicted on the next read.
constexpr char kMaxStaleness[] = "GCS_READ_CACHE_MAX_STALENESS";
constexpr uint64_t kDefaultMaxStaleness = 0;
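// The environment variable that overrides the maximum age of entries in the
// Stat cache.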
constexpr char kStatCacheMaxAge[] = "GCS_STAT_CACHE_MAX_AGE";
constexpr uint64_t kStatCacheDefaultMaxAge = 5;
// The environment variable that overrides the maximum number of entries in the
// Stat cache.
constexpr char kStatCacheMaxEntries[] = "GCS_STAT_CACHE_MAX_ENTRIES";
constexpr size_t kStatCacheDefaultMaxEntries = 1024;
// How to upload new data when Flush() is called multiple times.
// By default the entire file is reuploaded.
constexpr char kAppendMode[] = "GCS_APPEND_MODE";
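As a side note, these values are only read from the environment when the plugin constructs its GCSFile state (see the GCSFile constructor further down in this file). A minimal sketch of overriding them before initialization, assuming a POSIX environment; the helper name is hypothetical and the values are arbitrary examples, not defaults:
#include <stdlib.h>  // for setenv (POSIX)
// Hypothetical test helper: must run before the GCS plugin's Init() so the
// GCSFile constructor picks these values up via std::getenv.
void ConfigureGcsReadCacheForTesting() {
  setenv("GCS_READ_CACHE_BLOCK_SIZE_MB", "16", /*overwrite=*/1);   // 16 MiB blocks
  setenv("GCS_READ_CACHE_MAX_SIZE_MB", "256", /*overwrite=*/1);    // 256 MiB cache
  setenv("GCS_READ_CACHE_MAX_STALENESS", "300", /*overwrite=*/1);  // staleness bound
}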
@ -82,28 +103,15 @@ static void MaybeAppendSlash(std::string* name) {
name->push_back('/');
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
typedef struct GCSFile {
const std::string bucket;
const std::string object;
gcs::Client* gcs_client; // not owned
} GCSFile;
void Cleanup(TF_RandomAccessFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// TODO(vnvo2409): Adding cache.
// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
auto stream = gcs_file->gcs_client->ReadObject(
gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n));
// A helper function to actually read the data from GCS.
static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
size_t buffer_size, char* buffer,
gcs::Client* gcs_client, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return -1;
auto stream = gcs_client->ReadObject(
bucket, object, gcs::ReadRange(offset, offset + buffer_size));
TF_SetStatusFromGCSStatus(stream.status(), status);
if ((TF_GetCode(status) != TF_OK) &&
(TF_GetCode(status) != TF_OUT_OF_RANGE)) {
@ -115,11 +123,92 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
if (read != n) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
}
stream.read(buffer, read);
return read;
return stream.gcount();
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
using ReadFn =
std::function<int64_t(const std::string& path, uint64_t offset, size_t n,
char* buffer, TF_Status* status)>;
typedef struct GCSFile {
const std::string path;
const bool is_cache_enable;
const uint64_t buffer_size;
ReadFn read_fn;
absl::Mutex buffer_mutex;
uint64_t buffer_start ABSL_GUARDED_BY(buffer_mutex);
bool buffer_end_is_past_eof ABSL_GUARDED_BY(buffer_mutex);
std::string buffer ABSL_GUARDED_BY(buffer_mutex);
GCSFile(std::string path, bool is_cache_enable, uint64_t buffer_size,
ReadFn read_fn)
: path(path),
is_cache_enable(is_cache_enable),
buffer_size(buffer_size),
read_fn(std::move(read_fn)),
buffer_mutex(),
buffer_start(0),
buffer_end_is_past_eof(false),
buffer() {}
} GCSFile;
void Cleanup(TF_RandomAccessFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
if (gcs_file->is_cache_enable || n > gcs_file->buffer_size) {
return gcs_file->read_fn(gcs_file->path, offset, n, buffer, status);
} else {
absl::MutexLock l(&gcs_file->buffer_mutex);
size_t buffer_end = gcs_file->buffer_start + gcs_file->buffer.size();
size_t copy_size = 0;
if (offset < buffer_end && gcs_file->buffer_start) {
copy_size = (std::min)(n, static_cast<size_t>(buffer_end - offset));
memcpy(buffer,
gcs_file->buffer.data() + (offset - gcs_file->buffer_start),
copy_size);
}
bool consumed_buffer_to_eof =
offset + copy_size >= buffer_end && gcs_file->buffer_end_is_past_eof;
if (copy_size < n && !consumed_buffer_to_eof) {
gcs_file->buffer_start = offset + copy_size;
gcs_file->buffer.resize(gcs_file->buffer_size);
auto read_fill_buffer = gcs_file->read_fn(
gcs_file->path, gcs_file->buffer_start, gcs_file->buffer_size,
&(gcs_file->buffer[0]), status);
gcs_file->buffer_end_is_past_eof =
(TF_GetCode(status) == TF_OUT_OF_RANGE);
if (read_fill_buffer >= 0) gcs_file->buffer.resize(read_fill_buffer);
if (TF_GetCode(status) != TF_OK &&
TF_GetCode(status) != TF_OUT_OF_RANGE) {
// Empty the buffer to avoid caching bad reads.
gcs_file->buffer.resize(0);
return -1;
}
size_t remaining_copy =
(std::min)(n - copy_size, gcs_file->buffer.size());
memcpy(buffer + copy_size, gcs_file->buffer.data(), remaining_copy);
copy_size += remaining_copy;
if (copy_size < n) {
// Forget the end-of-file flag to allow for clients that poll on the
// same file.
gcs_file->buffer_end_is_past_eof = false;
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
return copy_size;
}
}
TF_SetStatus(status, TF_OK, "");
return copy_size;
}
}
} // namespace tf_random_access_file
@ -290,11 +379,53 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region) {
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
// TODO(vnvo2409): Use partial response for better performance.
// TODO(vnvo2409): We could do some cleanups like `return TF_SetStatus`.
// TODO(vnvo2409): Refactor the filesystem implementation when
// https://github.com/googleapis/google-cloud-cpp/issues/4482 is done.
GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
: gcs_client(gcs_client), block_cache_lock() {
const char* append_mode = std::getenv(kAppendMode);
compose = (append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));
uint64_t value;
block_size = kDefaultBlockSize;
size_t max_bytes = kDefaultMaxCacheSize;
uint64_t max_staleness = kDefaultMaxStaleness;
// Apply the overrides for the block size (MB), max bytes (MB), and max
// staleness (seconds) if provided.
if (absl::SimpleAtoi(std::getenv(kBlockSize), &value)) {
block_size = value * 1024 * 1024;
}
if (absl::SimpleAtoi(std::getenv(kMaxCacheSize), &value)) {
max_bytes = static_cast<size_t>(value * 1024 * 1024);
}
if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
max_staleness = value;
}
auto gcs_client_ptr = &this->gcs_client;
file_block_cache = std::make_unique<RamFileBlockCache>(
block_size, max_bytes, max_staleness,
[gcs_client_ptr](const std::string& filename, size_t offset,
size_t buffer_size, char* buffer, TF_Status* status) {
return LoadBufferFromGCS(filename, offset, buffer_size, buffer,
gcs_client_ptr, status);
});
uint64_t stat_cache_max_age = kStatCacheDefaultMaxAge;
size_t stat_cache_max_entries = kStatCacheDefaultMaxEntries;
if (absl::SimpleAtoi(std::getenv(kStatCacheMaxAge), &value)) {
stat_cache_max_age = value;
}
if (absl::SimpleAtoi(std::getenv(kStatCacheMaxEntries), &value)) {
stat_cache_max_entries = static_cast<size_t>(value);
}
stat_cache = std::make_unique<ExpiringLRUCache<GcsFileStat>>(
stat_cache_max_age, stat_cache_max_entries);
}
void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
@ -303,12 +434,7 @@ void Init(TF_Filesystem* filesystem, TF_Status* status) {
return;
}
const char* append_mode = std::getenv(kAppendMode);
bool compose =
(append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));
filesystem->plugin_filesystem =
new GCSFile({std::move(client.value()), compose});
filesystem->plugin_filesystem = new GCSFile(std::move(client.value()));
TF_SetStatus(status, TF_OK, "");
}
@ -325,8 +451,32 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
bool is_cache_enabled;
{
absl::MutexLock l(&gcs_file->block_cache_lock);
is_cache_enabled = gcs_file->file_block_cache->IsCacheEnabled();
}
auto read_fn = [gcs_file, is_cache_enabled](
const std::string& path, uint64_t offset, size_t n,
char* buffer, TF_Status* status) -> int64_t {
// TODO(vnvo2409): Check for `stat_cache`.
int64_t read = 0;
if (is_cache_enabled) {
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
} else {
read = LoadBufferFromGCS(path, offset, n, buffer, &gcs_file->gcs_client,
status);
}
if (TF_GetCode(status) != TF_OK) return -1;
if (read < n)
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
else
TF_SetStatus(status, TF_OK, "");
return read;
};
file->plugin_file = new tf_random_access_file::GCSFile(
{std::move(bucket), std::move(object), &gcs_file->gcs_client});
std::move(path), is_cache_enabled, gcs_file->block_size, read_fn);
TF_SetStatus(status, TF_OK, "");
}

View File

@ -17,6 +17,8 @@
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
#include "tensorflow/c/tf_status.h"
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
@ -45,10 +47,23 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region);
} // namespace tf_read_only_memory_region
namespace tf_gcs_filesystem {
typedef struct GcsFileStat {
TF_FileStatistics base;
int64_t generation_number;
} GcsFileStat;
typedef struct GCSFile {
google::cloud::storage::Client gcs_client; // owned
bool compose;
absl::Mutex block_cache_lock;
std::shared_ptr<RamFileBlockCache> file_block_cache
ABSL_GUARDED_BY(block_cache_lock);
uint64_t block_size; // Reads smaller than block_size will trigger a read
// of block_size.
std::unique_ptr<ExpiringLRUCache<GcsFileStat>> stat_cache;
GCSFile(google::cloud::storage::Client&& gcs_client);
} GCSFile;
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <aws/core/config/AWSProfileConfigLoader.h>
#include <aws/core/utils/FileSystemUtils.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/s3/model/CopyObjectRequest.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadBucketRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
@ -545,7 +546,10 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
GetS3Client(s3_file);
GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
// We need to delete `file->plugin_file` in case of errors.
// We need to delete `file->plugin_file` in case of errors. We set
// `file->plugin_file` to `nullptr` in order to avoid a segmentation fault when
// calling the deleter of the `unique_ptr`.
file->plugin_file = nullptr;
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile*)> writer(
file, [](TF_WritableFile* file) {
if (file != nullptr && file->plugin_file != nullptr) {
@ -561,10 +565,14 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
tf_random_access_file::Cleanup(file);
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
// We set `reader->plugin_file` to `nullptr` in order to avoid a segmentation
// fault when calling the deleter of the `unique_ptr`.
reader->plugin_file = nullptr;
NewRandomAccessFile(filesystem, path, reader.get(), status);
if (TF_GetCode(status) != TF_OK) return;
@ -678,7 +686,7 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, true, &bucket, &object, status);
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
@ -695,10 +703,14 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
tf_random_access_file::Cleanup(file);
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
// We set `reader->plugin_file` to `nullptr` in order to avoid a segmentation
// fault when calling the deleter of the `unique_ptr`.
reader->plugin_file = nullptr;
NewRandomAccessFile(filesystem, path, reader.get(), status);
if (TF_GetCode(status) != TF_OK) return;
auto read =
@ -710,6 +722,67 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
TF_SetStatus(status, TF_OK, "");
}
static void SimpleCopyFile(const Aws::String& source,
const Aws::String& bucket_dst,
const Aws::String& object_dst, S3File* s3_file,
TF_Status* status) {
Aws::S3::Model::CopyObjectRequest copy_object_request;
copy_object_request.WithCopySource(source)
.WithBucket(bucket_dst)
.WithKey(object_dst);
auto copy_object_outcome =
s3_file->s3_client->CopyObject(copy_object_request);
if (!copy_object_outcome.IsSuccess())
TF_SetStatusFromAWSError(copy_object_outcome.GetError(), status);
else
TF_SetStatus(status, TF_OK, "");
};
static void MultiPartCopy(const Aws::String& source,
const Aws::String& bucket_dst,
const Aws::String& object_dst, const size_t num_parts,
const uint64_t file_size, S3File* s3_file,
TF_Status* status){};
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_Status* status) {
auto file_size = GetFileSize(filesystem, src, status);
if (TF_GetCode(status) != TF_OK) return;
if (file_size == 0)
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Source is a directory or empty file");
Aws::String bucket_src, object_src;
ParseS3Path(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
Aws::String copy_src = bucket_src + "/" + object_src;
Aws::String bucket_dst, object_dst;
ParseS3Path(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
auto chunk_size =
s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];
size_t num_parts = 1;
if (file_size > chunk_size) num_parts = ceil((float)file_size / chunk_size);
if (num_parts == 1)
SimpleCopyFile(copy_src, bucket_dst, object_dst, s3_file, status);
else if (num_parts > 10000)
TF_SetStatus(
status, TF_UNIMPLEMENTED,
absl::StrCat("MultiPartCopy with number of parts more than 10000 is "
"not supported. Your object ",
src, " required ", num_parts,
" as multi_part_copy_part_size is set to ", chunk_size,
". You can control this part size using the environment "
"variable S3_MULTI_PART_COPY_PART_SIZE to increase it.")
.c_str());
else
MultiPartCopy(copy_src, bucket_dst, object_dst, num_parts, file_size,
s3_file, status);
}
// TODO(vnvo2409): Implement later
} // namespace tf_s3_filesystem

View File

@ -0,0 +1,23 @@
# Library of gradient functions.
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "math_grad",
srcs = ["math_grad.cc"],
hdrs = [
"math_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)

View File

@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
using tensorflow::ops::Identity;
namespace tensorflow {
namespace gradients {
namespace {
class AddGradientFunction : public GradientFunction {
public:
explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting.
TF_RETURN_IF_ERROR(ops::Identity(
ctx_, {grad_inputs[0]}, absl::MakeSpan(identity_outputs), "Identity0"));
(*grad_outputs)[0] = identity_outputs[0];
TF_RETURN_IF_ERROR(ops::Identity(
ctx_, {grad_inputs[0]}, absl::MakeSpan(identity_outputs), "Identity1"));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK();
}
~AddGradientFunction() override {}
private:
AbstractContext* ctx_;
};
} // namespace
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction(op.ctx);
}
} // namespace gradients
} // namespace tensorflow
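For orientation, this file provides the `AddRegisterer` that the unified-API gradients test above now consumes instead of its local copy. A minimal sketch of registering it, mirroring the `RegisterGradients` helper the test previously defined (the includes are the ones added in this commit; everything else is assumed):
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
// Registers the shared Add gradient under the "Add" op name, exactly as the
// removed in-test helper did.
tensorflow::Status RegisterGradients(
    tensorflow::gradients::GradientRegistry* registry) {
  return registry->Register("Add", tensorflow::gradients::AddRegisterer);
}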

View File

@ -12,34 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
#include <limits>
#include "absl/random/random.h"
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace {
static int64 overridden_node_id = -1;
} // namespace
namespace internal {
void OverrideNodeIdForTesting(const int64 node_id) {
overridden_node_id = node_id;
}
uint64 GetNodeId() {
if (overridden_node_id > -1) {
return overridden_node_id;
} else {
return absl::Uniform(absl::SharedBitGen(), uint64{0},
std::numeric_limits<uint64>::max());
}
}
} // namespace internal
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_

View File

@ -0,0 +1,24 @@
# Experimental ops. These will eventually be replaced by machine-generated versions.
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "array_ops",
srcs = [
"array_ops.cc",
],
hdrs = [
"array_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr identity_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(identity_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
int num_retvals = 1;
return identity_op->Execute(outputs, &num_retvals);
}
} // namespace ops
} // namespace tensorflow

View File

@ -12,27 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
#include "tensorflow/core/framework/types.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
// Implementation details of distributed_tpu_rewrite_pass.cc, please DO NOT
// depend on these.
namespace internal {
// When set to a value >= 0, overrides the node_id. Used for getting
// deterministic node_ids during testing.
void OverrideNodeIdForTesting(int64 node_id);
// Retrieves the node id, used to make some node names unique in the rewrite
// pass.
uint64 GetNodeId();
} // namespace internal
namespace ops {
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_

View File

@ -111,10 +111,6 @@ TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory(
TF_CAPI_EXPORT extern void TF_KernelBuilder_Priority(
TF_KernelBuilder* kernel_builder, int32_t priority_number);
typedef struct string_view string_view;
TF_CAPI_EXPORT extern string_view TF_GetName(TF_KernelBuilder* kernel_builder);
// Register the given kernel builder with the TensorFlow runtime. If
// registration fails, the given status will be populated.
//

View File

@ -53,7 +53,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include <iostream>
struct MyCustomKernel {
bool created;
bool compute_called;
@ -162,9 +161,6 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
ASSERT_TRUE(delete_called);
}
TEST(TestKernel, TestGetKernelName) {
}
class DummyDevice : public DeviceBase {
public:
explicit DummyDevice(Env* env) : DeviceBase(env) {}

View File

@ -1829,7 +1829,7 @@ TEST(XlaCompilationTest, XLALiteAllowlist) {
}
EXPECT_TRUE(unknow_op.empty())
<< "Someone added support for a new TF opeations inside XLA. They must "
"be included in the XLALite allowlist or blacklist:\n"
"be included in the XLALite allowlist or denylist:\n"
<< absl::StrJoin(unknow_op, "\n");
}
} // namespace

View File

@ -74,7 +74,6 @@ We have several choices on how to lower the host-side part from LHLO:
* (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as
TFRT ops are interpreted by C++ code.
* (Con) host side is under development and not tested.
* (Con) the JAX integration isn't clear from a runtime point of view
* Jitted CPU code
* (Pro) great lower-ability. Create a few loops and conditions and it's
done.
@ -84,8 +83,7 @@ We have several choices on how to lower the host-side part from LHLO:
dynamic loading, etc).
* Existing (interpreting) XLA runtime
Tentative conclusion: Use jitted CPU code during the transition, and optionally
adopt TFRT in the end.
Decision: adopt TFRT, but also support jitting CPU code in TFRT.
## Migrating Device LLVM IR (Task 3)
@ -114,7 +112,7 @@ end state of each XLA op:
* (Cost) Will be throw-away work if we want to ultimately migrate to
Standard.
* (Benefit) It is easy and mechanical. Can be done in a short period.
* (Benefit) It doesn't benefit more compared to a).
* (Benefit) It doesn't benefit more compared to (1).
1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops:
* (Cost) Lifting existing emitters to Standard introduces some challenges.
Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring
@ -134,6 +132,19 @@ end state of each XLA op:
* (Benefit) unified stack; community support; portability; more
optimization potentials.
Conclusions:
* Don't go for (2). (1) or (3) are just better than (2). (2) costs more than
(1), since it requires a lot of mechanical refactoring. With (1) we can
still achieve the goal of enabling XLA to pick up MLIR emitters. This is by
doing LHLO -> LLVM IR -> run legacy device emitters.
* ElementalIrEmitter ops go for (4), but not incrementally. There is no way to
do it op by op, because all elementally-emitted ops are connected into the
same graph. This work can also serve as a unification point of several
on-going forces (xla/service/mlir\_gpu, the kernel generator, Linalg).
* All other ops go for (1). As a stretch goal, they might be migrated to (3)
or (4).
## Prioritization
While all three tasks mentioned above are parallelizable, under limited
@ -210,26 +221,19 @@ The exact profiling can't be easily done for MLIR-generated ops, since:
### Step 3: (Task 2) Migrating Thunks
This step migrates all host ops and library calls. This step will eliminate most
of the thunks and produce serializable MLIR instead.
There are roughly three kinds of thunks:
As a note, there are roughly three kinds of thunks:
* KernelThunk, which launches a kernel.
* Control flow thunks, which have host control flow logic (conditional, while,
for, sequence) and launch body kernels.
* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.
The **bottom line** is to:
The plan is:
* Make Thunks (de)serializable.
* Help improve TFRT to a state where it can support these semantics.
* As the state improves, migrate individual thunks incrementally.
* Create a Thunk dialect that provides (de)serialize logic for all existing
C++-based Thunks.
* Change emitters to emit a graph of Thunk dialect.
**Optionally**, we can relieve some thunks of their C++ implementations. KernelThunk
can lower to the GPU LaunchKernelOp. Control flow thunks can leverage the CFG
Dialect for loops and conditions, combined with LaunchKernelOp. This optional
step requires profiling and stream support.
These action items are only partially ordered. The actual execution order /
engineering parallelism is to be evaluated as it goes.
### Step 4: (Task 3) Migrated ElementalIrEmitter

View File

@ -568,6 +568,26 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "legalize_gather_to_torch_index_select",
srcs = ["lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc"],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
],
deps = [
":hlo",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
cc_library(
name = "legalize_tanh_to_approximation",
srcs = ["lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc"],
@ -717,6 +737,7 @@ cc_library(
":hlo_dialect_registration",
":hlo_legalize_to_lhlo",
":legalize_control_flow",
":legalize_gather_to_torch_index_select",
":legalize_tanh_to_approximation",
":legalize_to_linalg",
":legalize_to_standard",

View File

@ -41,6 +41,10 @@ void PopulateComplexLoweringPatterns(MLIRContext *context,
void PopulateOptimizeMHLOPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
// Rewrite patterns for gather to equivalent torch index select legalization.
void PopulateGatherToTorchIndexSelectPatterns(
mlir::MLIRContext *context, OwningRewritePatternList *patterns);
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);

View File

@ -1468,7 +1468,7 @@ static LogicalResult Verify(PadOp op) {
static LogicalResult Verify(ReshapeOp op) {
// If the operand type is dynamically shaped there is nothing to verify.
auto operand_ty = op.operand().getType().cast<RankedTensorType>();
auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
if (!operand_ty || !operand_ty.hasStaticShape()) return success();
// If the operand type is statically shaped (not required) the number of

View File

@ -0,0 +1,152 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/memory/memory.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace mhlo {
namespace {
struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
using OpRewritePattern<GatherOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter &rewriter) const override {
auto start_indices = gather.start_indices();
auto start_indices_ty = start_indices.getType().cast<ShapedType>();
if (!start_indices_ty.hasRank()) {
return failure();
}
auto operand = gather.operand();
auto operand_ty = operand.getType().cast<ShapedType>();
if (!operand_ty.hasRank()) {
return failure();
}
int64_t index_vector_dim =
std::max<int64_t>(0, start_indices_ty.getRank() - 1);
// We can use torch_index_select if the last dimension represents the
// gather indices.
auto dimension_numbers = gather.dimension_numbers();
if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
index_vector_dim) {
return failure();
}
// Index select only works across a single dimension.
if (!start_indices_ty.getShape().empty() &&
start_indices_ty.getShape().back() != 1) {
return failure();
}
// Only support the default case for start_index_map.
if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
dimension_numbers.start_index_map()
.getValue(0)
.cast<IntegerAttr>()
.getValue() != 0) {
return failure();
}
auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
if (!result_ty) {
return failure();
}
// Offset dimensions should be the defaults.
if (dimension_numbers.offset_dims().getType().getNumElements() !=
result_ty.getRank() - index_vector_dim) {
return failure();
}
for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
if ((it.index() + index_vector_dim) != it.value()) {
return failure();
}
}
for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
// First shape value must be 1.
if (it.index() == 0) {
if (it.value().getSExtValue() != 1) {
return failure();
}
continue;
}
// The gather needs to index the entire slice for each other dimension.
if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
return failure();
}
}
llvm::SmallVector<int64_t, 4> index_select_shape =
llvm::to_vector<4>(start_indices_ty.getShape());
for (auto dim : operand_ty.getShape().drop_front()) {
index_select_shape.push_back(dim);
}
if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
1 ||
dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
return failure();
}
auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
gather.getLoc(),
RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
torch_index_select);
return success();
}
};
struct LegalizeGatherToTorchIndexSelect
: public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> {
/// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override {
OwningRewritePatternList patterns;
PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
} // namespace
void PopulateGatherToTorchIndexSelectPatterns(
mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<GatherIsTorchIndexSelect>(context);
}
static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass(
"mhlo-legalize-gather-to-torch-index-select",
"Legalizes gathers to a torch index select.");
} // namespace mhlo
} // namespace mlir

View File

@ -0,0 +1,41 @@
// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s
// CHECK-LABEL: @gather_to_index_select
func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> {
// CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
// CHECK-SAME: batch_dims = 0 : i64,
// CHECK-SAME: dim = 0 : i64
// CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
// CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>
// CHECK: return [[RES]]
return %0 : tensor<1x3x4xf32>
}
// CHECK-LABEL: @scalar_gather_to_index_select
func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<i32>) -> tensor<1x4xf32> {
// CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
// CHECK-SAME: batch_dims = 0 : i64,
// CHECK-SAME: dim = 0 : i64
// CHECK-SAME: } : (tensor<5x4xf32>, tensor<i32>) -> tensor<4xf32>
// CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<i32>) -> tensor<1x4xf32>
// CHECK: return [[RES]]
return %0 : tensor<1x4xf32>
}
// CHECK-LABEL: @gather_no_lowering_subslice
func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> {
// CHECK: "mhlo.gather"
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32>
return %0 : tensor<1x3x3xf32>
}
// CHECK-LABEL: @gather_no_lowering_multidim
func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> {
// CHECK: "mhlo.gather"
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32>
return %0 : tensor<1x3x4xf32>
}

View File

@ -220,18 +220,14 @@ cc_library(
],
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
@ -349,7 +345,6 @@ cc_library(
"transforms/passes.h",
],
deps = [
":common",
":lstm_utils",
":stateful_ops_utils",
":tensorflow_lite",
@ -369,7 +364,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:tensor_list",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
@ -400,7 +394,6 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -434,7 +427,6 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
@ -457,7 +449,6 @@ cc_library(
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -609,8 +600,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
@ -651,7 +640,6 @@ cc_library(
":flatbuffer_tflite_operator_lib",
":tensorflow_lite",
":tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
@ -724,7 +712,6 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirTranslateMain",
@ -858,10 +845,8 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -151,10 +151,13 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
return errors::Unimplemented("Only support a single exported name.");
}
tensorflow::GraphImportConfig specs;
specs.upgrade_legacy = true;
TF_ASSIGN_OR_RETURN(auto module,
ImportSavedModel(model_flags.saved_model_dir(),
model_flags.saved_model_version(), tags,
exported_names, &context));
exported_names, specs, &context));
if (!model_flags.input_arrays().empty() ||
!model_flags.output_arrays().empty()) {

View File

@ -81,7 +81,6 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",

View File

@ -144,6 +144,10 @@ int main(int argc, char **argv) {
StatusOr<mlir::OwningModuleRef> module;
tensorflow::GraphImportConfig specs;
specs.upgrade_legacy = upgrade_legacy;
specs.prune_unused_nodes = true;
// TODO(b/147435528): We need to test the e2e behavior once the graph freezing
// inside mlir is done.
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
@ -168,12 +172,10 @@ int main(int argc, char **argv) {
return kTrFailure;
}
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
tags, exported_names, &context);
module =
tensorflow::ImportSavedModel(input_file_name, saved_model_version, tags,
exported_names, specs, &context);
} else {
tensorflow::GraphImportConfig specs;
specs.upgrade_legacy = upgrade_legacy;
specs.prune_unused_nodes = true;
module = tensorflow::LoadFromGraphdefOrMlirSource(
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
specs, debug_info_file, input_arrays, input_dtypes, input_shapes,

View File

@ -186,7 +186,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
const std::string& input_filename, const int saved_model_version,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
absl::Span<std::string> exported_names, const GraphImportConfig& specs,
mlir::MLIRContext* context) {
if (saved_model_version == 2) {
auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
input_filename, tags, exported_names, context);
@ -194,7 +195,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
return module_or.ConsumeValueOrDie();
} else if (saved_model_version == 1) {
auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
input_filename, tags, exported_names, context);
input_filename, tags, exported_names, context, specs.upgrade_legacy);
if (!module_or.status().ok()) return module_or.status();
return module_or.ConsumeValueOrDie();

View File

@ -48,7 +48,8 @@ LoadFromGraphdefOrMlirSource(
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
const std::string& input_filename, const int saved_model_version,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
absl::Span<std::string> exported_names, const GraphImportConfig& specs,
mlir::MLIRContext* context);
// Taking a MLIR module in TF executor dialect and a set of parameters,
// applies a set of passes to convert the module to TF Lite dialect and
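A minimal call-site sketch of the updated `ImportSavedModel` signature follows. It only restates what the hunks above declare; the header path, model path, tags, and exported name are placeholders, not values taken from this change:

// Sketch only: illustrates threading GraphImportConfig through the new
// ImportSavedModel signature. Paths and names below are assumptions.
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/types/span.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"  // assumed header path

stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportForConversion(
    mlir::MLIRContext* context) {
  const std::unordered_set<std::string> tags = {"serve"};
  std::vector<std::string> exported_names = {"serving_default"};
  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = true;  // Now forwarded to the SavedModel V1 importer.
  return tensorflow::ImportSavedModel("/tmp/saved_model",
                                      /*saved_model_version=*/1, tags,
                                      absl::MakeSpan(exported_names), specs,
                                      context);
}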

View File

@ -37,22 +37,19 @@ class HasRankAtMost<int n> : Constraint<
// Multi-pattern consisting of matching stand-alone convolution op followed by
// activation op.
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
def : Pat<(ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias,
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w)),
(TFL_Conv2DOp $input, $filter, $bias,
$h_factor, $w_factor, ActFnAttr,
$padding, $stride_h, $stride_w),
[(HasOneUse $conv_out)]>;
def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias,
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
$multiplier)),
(TFL_DepthwiseConv2DOp $input, $filter, $bias,
$h_factor, $w_factor, ActFnAttr,
$padding, $stride_h, $stride_w,
$multiplier),
[(HasOneUse $conv_out)]>;
def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat<
(ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias, $h_factor,
$w_factor, TFL_AF_None, $padding, $stride_h, $stride_w)),
(TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr,
$padding, $stride_h, $stride_w),
[(HasOneUse $conv_out)]>;
def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat<
(ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias, $h_factor,
$w_factor, TFL_AF_None, $padding, $stride_h, $stride_w,
$multiplier)),
(TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor,
ActFnAttr, $padding, $stride_h, $stride_w, $multiplier),
[(HasOneUse $conv_out)]>;
}
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
@ -73,33 +70,29 @@ class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
// constant folding the bias and the binary op's constant operand. The following
// pattern restricts to float constant values for now.
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
def : Pat<(binaryOp (TFL_Conv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_Conv2DOp $input, $filter,
(binaryOp (ConstantOp $bias),
(ConstantOp $value), TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w),
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
(HasOneUse $output)]>;
def : Pat<(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
$multiplier),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_DepthwiseConv2DOp $input, $filter,
(binaryOp (ConstantOp $bias),
(ConstantOp $value),
TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w,
$multiplier),
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
(HasOneUse $output)]>;
def FuseBinaryOpWithConv#binaryOp : Pat<
(binaryOp (TFL_Conv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor,
TFL_AF_None, $padding, $stride_h, $stride_w),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_Conv2DOp $input, $filter,
(binaryOp (ConstantOp $bias),
(ConstantOp $value), TFL_AF_None),
$h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
(HasOneUse $output)]>;
def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
$stride_w, $multiplier),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_DepthwiseConv2DOp $input, $filter,
(binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
$h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
$multiplier),
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
(HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
defm : FuseBinaryOpToPrecedingAffine<binaryOp>;
@ -116,43 +109,43 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall<
// The following pattern restricts to float constant values for now.
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
$multiplier),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_DepthwiseConv2DOp $input,
(BinaryOp (ConstantOp $filter),
(ConstantOp
(ExpandTo4DForDepthwiseConv $value)),
TFL_AF_None),
(BinaryOp (ConstantOp $bias),
(ConstantOp $value),
TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w,
$multiplier),
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
(HasOneUse $output)]>;
def : Pat<(BinaryOp (TFL_Conv2DOp:$conv_output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_Conv2DOp $input,
(BinaryOp (ConstantOp $filter),
(ConstantOp (ExpandTo4DForConv $value)),
TFL_AF_None),
(BinaryOp (ConstantOp $bias),
(ConstantOp $value),
TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w),
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
(HasOneUse $conv_output)]>;
def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
(BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
$stride_w, $multiplier),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_DepthwiseConv2DOp $input,
(BinaryOp
(ConstantOp $filter),
(ConstantOp (ExpandTo4DForDepthwiseConv $value)),
TFL_AF_None),
(BinaryOp
(ConstantOp $bias),
(ConstantOp $value),
TFL_AF_None),
$h_factor, $w_factor, $act_fn, $padding, $stride_h,
$stride_w, $multiplier),
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
(HasOneUse $output)]>;
def FuseMulOrDivWithConv#BinaryOp : Pat<
(BinaryOp (TFL_Conv2DOp:$conv_output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w),
(ConstantOp F32ElementsAttr:$value), $act_fn),
(TFL_Conv2DOp $input,
(BinaryOp (ConstantOp $filter),
(ConstantOp (ExpandTo4DForConv $value)),
TFL_AF_None),
(BinaryOp (ConstantOp $bias),
(ConstantOp $value),
TFL_AF_None),
$h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
(HasOneUse $conv_output)]>;
}
foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in
@ -177,7 +170,7 @@ class OperandHasRank<int n> : Constraint<
CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;
// Matching HardSwish
def : Pat<
def MatchHardSwishPattern1 : Pat<
(TFL_MulOp
(TFL_MulOp
$x, (TFL_AddOp
@ -190,7 +183,7 @@ def : Pat<
(TFL_HardSwishOp $x),
[(EqualOperands $x, $y)]>;
def : Pat<
def MatchHardSwishPattern2 : Pat<
(TFL_MulOp
$x,
(TFL_MulOp
@ -207,7 +200,7 @@ def : Pat<
// Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to
// incorrect placement in the quantization aware training.
// TODO(b/149735743): We should make the placement automatic.
def : Pat<
def MatchHardSwishQuantized : Pat<
(TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
(TFL_MulOp
$x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
@ -238,7 +231,8 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
// This pattern constructs L2NormalizationOp from
// Mul->Rsqrt->Sum->Square Or
// Div->sqrt->Sum->Square
def : Pat<(FirstOp $operand1,
def L2NormalizePattern1#FirstOp#SecondOp : Pat<
(FirstOp $operand1,
(SecondOp
(TFL_SumOp
(TFL_SquareOp:$sq_op $square_operand),
@ -251,7 +245,8 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
// Below patterns for L2Normalize when there is an Add or Maximum
// adding or clamping to a small constant scalar.
def : Pat<(FirstOp $operand1,
def L2NormalizePattern2#FirstOp#SecondOp : Pat<
(FirstOp $operand1,
(SecondOp
(TFL_AddOp
(TFL_SumOp
@ -265,7 +260,8 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
(L2NormValidReduceIndex $sq_op, $axis),
(ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
def : Pat<(FirstOp $operand1,
def L2NormalizePattern3#FirstOp#SecondOp : Pat<
(FirstOp $operand1,
(SecondOp
(TFL_MaximumOp
(TFL_SumOp
@ -302,14 +298,16 @@ def HaveSameType : Constraint<CPred<"$0.getType(), $1.getType()">>;
// Pattern for skipping Tile if it is mainly for broadcasting and the
// Op already supports broadcasting.
multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
def : Pat<(BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
$operand, $act_func),
(BinaryOp $input, $operand, $act_func),
def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat<
(BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
$operand, $act_func),
(BinaryOp $input, $operand, $act_func),
[(OperandsBroadcastToOutputType $input, $operand, $result)]>;
def : Pat<(BinaryOp:$result $operand,
(TFL_TileOp $input, (ConstantOp $tile)), $act_func),
(BinaryOp $operand, $input, $act_func),
def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
(BinaryOp:$result $operand,
(TFL_TileOp $input, (ConstantOp $tile)), $act_func),
(BinaryOp $operand, $input, $act_func),
[(OperandsBroadcastToOutputType $operand, $input, $result)]>;
}
@ -318,9 +316,10 @@ multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_Relu1Op, TFL_AF_Relu1]] in {
def : Pat<(actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1]),
[(HasOneUse $binary_out)]>;
def FuseBinaryWithActivation#BinaryOp#actFnPair[0] : Pat<
(actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1]),
[(HasOneUse $binary_out)]>;
}
}
@ -340,21 +339,22 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
// transformation, the shape of the binary op result is [40x1600], which
// couldn't be reshaped to [1,40,40]. `IsTailOfShape` constraint is added to
// make sure $rhs is the tail shape of $lhs.
def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
(ConstantOp:$rhs $a), TFL_AF_None),
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
// The broadcasting of "BinaryOp" only happens in the lower
// dimensions, and the higher dimensions are same, so we know the
// result and input of the "BinaryOp" in the source pattern have
// the same shape, which is defined by `shape`.
[(IsTailOfShape $rhs, $lhs),
(HasOneUse $lhs),
// The result of the new "BinaryOp" will have the same shape as
// `input`. In other words, the shape of the `Reshape` op are not
// changed after the transformation.
(IsTailOfShape $rhs, $input),
(HasRankAtMost<5> $input),
(HasRankAtMost<5> $rhs)]>;
def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
(ConstantOp:$rhs $a), TFL_AF_None),
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
// The broadcasting of "BinaryOp" only happens in the lower
// dimensions, and the higher dimensions are the same, so we know the
// result and input of the "BinaryOp" in the source pattern have
// the same shape, which is defined by `shape`.
[(IsTailOfShape $rhs, $lhs),
(HasOneUse $lhs),
// The result of the new "BinaryOp" will have the same shape as
// `input`. In other words, the shape of the `Reshape` op is not
// changed after the transformation.
(IsTailOfShape $rhs, $input),
(HasRankAtMost<5> $input),
(HasRankAtMost<5> $rhs)]>;
}
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
@ -370,19 +370,20 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
// transformation, the shape of the binary op result is [40x1600], which
// couldn't be reshaped to [1,40,40]. `IsTailOfShape` constraint is added to
// make sure $rhs is the tail shape of $lhs.
def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
(ConstantOp:$rhs $a)),
(TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
// The broadcasting of "BinaryOp" only happens in the lower
// dimensions, and the higher dimensions are same, so we know the
// result and input of the "BinaryOp" in the source pattern have
// the same shape, which is defined by `shape`.
[(IsTailOfShape $rhs, $lhs),
(HasOneUse $lhs),
// The result of the new "BinaryOp" will have the same shape as
// `input`. In other words, the shape of the `Reshape` op are not
// changed after the transformation.
(IsTailOfShape $rhs, $input)]>;
def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
(ConstantOp:$rhs $a)),
(TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
// The broadcasting of "BinaryOp" only happens in the lower
// dimensions, and the higher dimensions are the same, so we know the
// result and input of the "BinaryOp" in the source pattern have
// the same shape, which is defined by `shape`.
[(IsTailOfShape $rhs, $lhs),
(HasOneUse $lhs),
// The result of the new "BinaryOp" will have the same shape as
// `input`. In other words, the shape of the `Reshape` op is not
// changed after the transformation.
(IsTailOfShape $rhs, $input)]>;
}
// Reorder the element-wise value operations and the element move operations,
@ -392,9 +393,10 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp] in {
foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
TFL_ReshapeOp, TFL_TransposeOp] in {
def : Pat<(ValueOp:$value (MoveOp:$move $input, $move_def)),
(MoveOp (ValueOp $input), $move_def),
[(HasOneUse $move)]>;
def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat<
(ValueOp:$value (MoveOp:$move $input, $move_def)),
(MoveOp (ValueOp $input), $move_def),
[(HasOneUse $move)]>;
}
}
@ -403,16 +405,16 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
def GetShape: NativeCodeCall<"GetShape($0)">;
// Convert squeeze to reshape if possible.
def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
(TFL_ReshapeOp $input,
(ConstantOp (GetShape $squeeze_op))),
[(AnyStaticShapeTensor $squeeze_op)]>;
def ConvertSqueezeToReshape : Pat<
(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
(TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))),
[(AnyStaticShapeTensor $squeeze_op)]>;
// Convert expand_dims to reshape if possible.
def : Pat<(TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
(TFL_ReshapeOp $input,
(ConstantOp (GetShape $expand_dims_op))),
[(AnyStaticShapeTensor $expand_dims_op)]>;
def ConvertExpandDimsToReshape : Pat<
(TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
(TFL_ReshapeOp $input, (ConstantOp (GetShape $expand_dims_op))),
[(AnyStaticShapeTensor $expand_dims_op)]>;
class FloatValueEquals<string val> : Constraint<CPred<
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
@ -420,25 +422,27 @@ class FloatValueEquals<string val> : Constraint<CPred<
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
// ReLU patterns
def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
(ConstantOp $NegOne)),
(ConstantOp $One)),
(TFL_Relu1Op $input),
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
def MatchRelu1Pattern1 : Pat<
(TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $NegOne)),
(ConstantOp $One)),
(TFL_Relu1Op $input),
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
(ConstantOp $One)),
(ConstantOp $NegOne)),
(TFL_Relu1Op $input),
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
def MatchRelu1Pattern2 : Pat<
(TFL_MaximumOp (TFL_MinimumOp $input, (ConstantOp $One)),
(ConstantOp $NegOne)),
(TFL_Relu1Op $input),
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
(ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
$input2),
(TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
[(ConstDoubleValueLessThan<"1"> $alpha),
(EqualOperands $input1, $input2),
(HasOneUse $mul_out)]>;
def MatchLeakyRelu : Pat<
(TFL_MaximumOp
(TFL_MulOp:$mul_out $input1,
(ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
$input2),
(TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
[(ConstDoubleValueLessThan<"1"> $alpha),
(EqualOperands $input1, $input2),
(HasOneUse $mul_out)]>;
def RemoveTrivialCast : Pat<(TFL_CastOp:$output $input),
(replaceWithValue $input),
@ -451,23 +455,25 @@ def PReluAlphaRankCheck : Constraint<
// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
def : Pat<(TFL_AddOp
(TFL_ReluOp:$relu_out $input1),
(TFL_MulOp:$mul_out
(TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
$neg_alpha,
TFL_AF_None),
TFL_AF_None),
(TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
[(EqualOperands $input1, $input2),
(PReluAlphaRankCheck $neg_alpha, $input1),
(HasOneUse $relu_out),
(HasOneUse $mul_out),
(HasOneUse $input_neg_out)]>;
def MatchPRelu : Pat<
(TFL_AddOp
(TFL_ReluOp:$relu_out $input1),
(TFL_MulOp:$mul_out
(TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
$neg_alpha,
TFL_AF_None),
TFL_AF_None),
(TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
[(EqualOperands $input1, $input2),
(PReluAlphaRankCheck $neg_alpha, $input1),
(HasOneUse $relu_out),
(HasOneUse $mul_out),
(HasOneUse $input_neg_out)]>;
// The constant folding in this pass might produce constants in the tf dialect.
// This rule is to legalize these constant to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
def LegalizeConstOp : Pat<
(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
// Reorders adds to allow constant folding.
// Add --> Add $input, $constantA
@ -476,13 +482,14 @@ def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
// Add --> $input
// \--> Add ($constantA, $constantB)
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
def : Pat<(TFL_AddOp
(TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
(ConstantOp $b), ActFun),
(TFL_AddOp $input,
(TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
ActFun),
[(HasOneUse $first_output)]>;
def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat<
(TFL_AddOp
(TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
(ConstantOp $b), ActFun),
(TFL_AddOp $input,
(TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
ActFun),
[(HasOneUse $first_output)]>;
}
// We can eliminate Relu from Relu(SquaredDifference(x, y)),

View File

@ -143,8 +143,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
type = FunctionType::get(types, result_types, &getContext());
}
auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type,
ArrayRef<NamedAttribute>{});
auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type);
outlined_func.getBody().takeBody(region);
Region& func_region = outlined_func.getBody();

View File

@ -301,8 +301,8 @@ bool IslandOp::WrapsSingleOp() {
namespace {
LogicalResult Verify(IslandOp island) {
if (island.GetBody().empty())
return island.emitOpError() << "expects a non-empty body";
if (!island.GetBody().args_empty())
return island.emitOpError() << "expects body without any arguments";
Operation &yield = island.GetBody().back();
if (!isa<YieldOp>(yield))

View File

@ -232,6 +232,20 @@ func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
// -----
// Check that an island body doesn't have any block arguments.
func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
tf_executor.graph {
"tf_executor.island"() ({
// expected-error@-1 {{expects body without any arguments}}
^entry(%arg: tensor<2xi32>):
tf_executor.yield
}) : () -> (!tf_executor.control)
}
return
}
// -----
// Check that an island body can't be empty.
func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
tf_executor.graph {

View File

@ -9,16 +9,15 @@
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>) -> () {
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_7]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
// CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {_xla_outside_compilation = "0", value = dense<"inference"> : tensor<!tf.string>} : () -> tensor<!tf.string>
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
%2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
return
}
@ -34,20 +33,19 @@ func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
func @check_enqueue_ops_update_for_training(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>) -> () {
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
%3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
"tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> ()
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_6]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
// CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {_xla_outside_compilation = "0", value = dense<"train"> : tensor<!tf.string>} : () -> tensor<!tf.string>
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
%4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
return
}

View File

@ -77,8 +77,7 @@ Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder,
ArrayRef<Type>{RankedTensorType::get(
{static_cast<int64_t>(buffer_type.getShape().size())},
getElementTypeOrSelf(index.getType()))},
ArrayRef<Value>{index, zeros_tensor, CreateScalarConst(0, builder, loc)},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{index, zeros_tensor, CreateScalarConst(0, builder, loc)});
}
Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc,
@ -95,15 +94,14 @@ Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc,
auto slice = builder.create<TF::SliceOp>(
loc, ArrayRef<Type>{slice_type},
ArrayRef<Value>{buffer, GetIndicesForElement(index, buffer, builder, loc),
size_const},
ArrayRef<NamedAttribute>{});
size_const});
if (keep_slice_shape) return slice;
auto element_type = RankedTensorType::get(buffer_type.getShape().drop_front(),
buffer_type.getElementType());
auto reshape = builder.create<TF::ReshapeOp>(
loc, ArrayRef<Type>{element_type},
ArrayRef<Value>{slice, GetR1Const(element_type.getShape(), builder, loc)},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{slice,
GetR1Const(element_type.getShape(), builder, loc)});
return reshape.output();
}
@ -120,15 +118,13 @@ Value SetElement(Value index, Value buffer, Value element, OpBuilder builder,
if (element.getType() != slice_type) {
update_slice = builder.create<TF::ReshapeOp>(
loc, ArrayRef<Type>{slice_type},
ArrayRef<Value>{element, GetR1Const(slice_shape, builder, loc)},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{element, GetR1Const(slice_shape, builder, loc)});
}
return builder
.create<TF::XlaDynamicUpdateSliceOp>(
loc, ArrayRef<Type>{buffer.getType()},
ArrayRef<Value>{buffer, update_slice,
GetIndicesForElement(index, buffer, builder, loc)},
ArrayRef<NamedAttribute>{})
GetIndicesForElement(index, buffer, builder, loc)})
.output();
}
@ -140,8 +136,7 @@ Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc) {
auto size_type = GetSizeType(builder);
return builder.create<TF::ReshapeOp>(
loc, ArrayRef<Type>{size_type},
ArrayRef<Value>{scalar, GetR1Const(size_type.getShape(), builder, loc)},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{scalar, GetR1Const(size_type.getShape(), builder, loc)});
}
LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
@ -171,13 +166,12 @@ LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
if (getElementTypeOrSelf(zero.getType()) != element_dtype) {
zero = builder.create<TF::CastOp>(
op->getLoc(), ArrayRef<Type>{RankedTensorType::get({}, element_dtype)},
ArrayRef<Value>{zero}, ArrayRef<NamedAttribute>{});
ArrayRef<Value>{zero});
}
auto buffer_type = RankedTensorType::get(buffer_shape, element_dtype);
auto broadcast = builder.create<TF::BroadcastToOp>(
op->getLoc(), ArrayRef<Type>{buffer_type},
ArrayRef<Value>{zero, GetR1Const(buffer_shape, builder, op->getLoc())},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{zero, GetR1Const(buffer_shape, builder, op->getLoc())});
*buffer = broadcast.output();
return success();
}
@ -241,27 +235,24 @@ Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) {
ArrayRef<Type>{getElementTypeOrSelf(local_var.getType())
.cast<TF::ResourceType>()
.getSubtypes()[0]},
ArrayRef<Value>{local_var}, ArrayRef<NamedAttribute>{})
ArrayRef<Value>{local_var})
.value();
}
// Creates an AssignVariableOp on a local variable.
TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value,
OpBuilder builder, Location loc) {
return builder.create<TF::AssignVariableOp>(loc, ArrayRef<Type>{},
ArrayRef<Value>{local_var, value},
ArrayRef<NamedAttribute>{});
return builder.create<TF::AssignVariableOp>(
loc, ArrayRef<Type>{}, ArrayRef<Value>{local_var, value});
}
Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc) {
if (getElementTypeOrSelf(a.getType()) == builder.getI1Type()) {
return builder.create<TF::LogicalOrOp>(loc, ArrayRef<Type>{a.getType()},
ArrayRef<Value>{a, b},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{a, b});
}
return builder.create<TF::AddV2Op>(loc, ArrayRef<Type>{a.getType()},
ArrayRef<Value>{a, b},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{a, b});
}
namespace {
@ -303,15 +294,13 @@ Value GatherElements(Value indices, Value buffer, OpBuilder builder,
return builder.create<TF::SliceOp>(
loc, ArrayRef<Type>{slice_type},
ArrayRef<Value>{buffer, GetR1Const(slice_starts, builder, loc),
GetR1Const(result_shape, builder, loc)},
ArrayRef<NamedAttribute>{});
GetR1Const(result_shape, builder, loc)});
}
auto result_type =
RankedTensorType::get(result_shape, buffer_type.getElementType());
return builder.create<TF::GatherV2Op>(
loc, ArrayRef<Type>{result_type},
ArrayRef<Value>{buffer, indices, CreateScalarConst(0, builder, loc)},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{buffer, indices, CreateScalarConst(0, builder, loc)});
}
Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
@ -334,8 +323,7 @@ Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
auto index = builder.create<TF::SliceOp>(
loc, ArrayRef<Type>{GetSizeType(builder)},
ArrayRef<Value>{indices, GetR1Const({i}, builder, loc),
GetR1Const({1}, builder, loc)},
ArrayRef<NamedAttribute>{});
GetR1Const({1}, builder, loc)});
auto old_slice =
GetElement(index, buffer, builder, loc, /*keep_slice_shape=*/true);
starts_in_update[0] = i;
@ -344,8 +332,7 @@ Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
builder
.create<TF::SliceOp>(
loc, ArrayRef<Type>{old_slice.getType()},
ArrayRef<Value>{updates, update_slice_starts, slice_sizes},
ArrayRef<NamedAttribute>{})
ArrayRef<Value>{updates, update_slice_starts, slice_sizes})
.output();
slice = AccumulateBuffers(old_slice, slice, builder, loc);
buffer = SetElement(index, buffer, slice, builder, loc);

View File

@ -185,8 +185,8 @@ IslandOp CreateNewIsland(IslandOp parent, IslandOp child,
Operation* old_island = insert_position == kParentIsland ? parent : child;
OpBuilder builder(old_island);
auto new_island = builder.create<IslandOp>(
old_island->getLoc(), result_types, operands, ArrayRef<NamedAttribute>{});
auto new_island =
builder.create<IslandOp>(old_island->getLoc(), result_types, operands);
new_island.body().push_back(new Block);
return new_island;
}

View File

@ -105,8 +105,8 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
// Create the outlined function
SmallString<32> name = kOutlinedFuncPrefix;
name += llvm::Twine(prefix_id++).str();
auto outlined_func = OpBuilder(ctx).create<FuncOp>(
island_op.getLoc(), name, func_type, ArrayRef<NamedAttribute>());
auto outlined_func =
OpBuilder(ctx).create<FuncOp>(island_op.getLoc(), name, func_type);
outlined_symbol_table.insert(outlined_func);
// We will "steal" the body of the island and replace it with a call to the

View File

@ -190,7 +190,7 @@ tf_executor::IslandOp CreateOutputBarrierIsland(
builder->setInsertionPoint(island_op);
auto island_output_sink = builder->create<tf_executor::IslandOp>(
island_op.getLoc(), llvm::to_vector<8>(island_op.getResultTypes()),
island_operands, llvm::ArrayRef<NamedAttribute>{});
island_operands);
island_output_sink.body().push_back(new Block);
return island_output_sink;
}

View File

@ -160,7 +160,7 @@ void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
builder.setInsertionPoint(user);
ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
user->getLoc(), ArrayRef<Type>{tensor_type},
ArrayRef<Value>{var_handle_op}, ArrayRef<NamedAttribute>{});
ArrayRef<Value>{var_handle_op});
user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
user->erase();
}

View File

@ -124,8 +124,7 @@ void ExtractSingleBlockRegion(Region& region, StringRef name,
auto type = FunctionType::get(input_types, return_types, region.getContext());
// Create new function and extract region body into the function.
auto outlined_func =
builder.create<FuncOp>(loc, name, type, ArrayRef<NamedAttribute>{});
auto outlined_func = builder.create<FuncOp>(loc, name, type);
Region& func_region = outlined_func.getBody();
func_region.takeBody(region);
Block& first_block = func_region.front();

View File

@ -558,15 +558,13 @@ void AddLoadsStoresOutsideControlFlowOp(
auto operand = caller->getOperand(index);
builder.setInsertionPoint(caller);
new_operands[index] = builder.create<TF::ReadVariableOp>(
caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand},
ArrayRef<NamedAttribute>{});
caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand});
caller->setOperand(index, new_operands[index]);
if (updated_index < 0) continue;
builder.setInsertionPointAfter(caller);
builder.create<TF::AssignVariableOp>(
caller->getLoc(), ArrayRef<Type>{},
ArrayRef<Value>{operand, caller->getResult(updated_index)},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{operand, caller->getResult(updated_index)});
}
}

View File

@ -409,11 +409,9 @@ LogicalResult HandleStackV2Op(
ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
stack.getContext()));
auto local_var = builder.create<TF::MlirLocalVarOp>(
stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{},
ArrayRef<NamedAttribute>{});
stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
auto local_size_var = builder.create<TF::MlirLocalVarOp>(
stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{},
ArrayRef<NamedAttribute>{});
stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{});
// Zero-initialize the local vars.
cutil::WriteLocalVariable(local_size_var,
cutil::GetR1Const({0LL}, builder, stack.getLoc()),
@ -446,8 +444,7 @@ LogicalResult HandleStackPushV2Op(
cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
index = builder.create<TF::AddV2Op>(
push.getLoc(), ArrayRef<Type>{index.getType()},
ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())});
cutil::WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
push.erase();
return success();
@ -467,8 +464,7 @@ LogicalResult HandleStackPopV2Op(
auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
auto new_size = builder.create<TF::SubOp>(
pop.getLoc(), ArrayRef<Type>{size.getType()},
ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())});
auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc());
pop.replaceAllUsesWith(pop_val);
// Update the size.

View File

@ -166,8 +166,7 @@ LogicalResult HandleTensorArrayV3Op(
ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
ta.getContext()));
auto local_var = builder.create<TF::MlirLocalVarOp>(
ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{},
ArrayRef<NamedAttribute>{});
ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc());
ta.handle().replaceAllUsesWith(local_var);
// The flow output is just a way for the front end to enforce ordering among
@ -227,8 +226,7 @@ LogicalResult HandleTensorArrayWriteV3Op(
elem = builder.create<TF::ReshapeOp>(
write.getLoc(), ArrayRef<Type>{slice_type},
ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
write.getLoc())},
ArrayRef<NamedAttribute>{});
write.getLoc())});
elem =
cutil::AccumulateBuffers(elem, original_elem, builder, write.getLoc());
}
@ -261,8 +259,7 @@ LogicalResult HandleTensorArrayConcatV3Op(
ArrayRef<Type>{
RankedTensorType::get(shape, buffer_type.getElementType())},
ArrayRef<Value>{buffer,
cutil::GetR1Const(shape, builder, concat.getLoc())},
ArrayRef<NamedAttribute>{});
cutil::GetR1Const(shape, builder, concat.getLoc())});
concat.value().replaceAllUsesWith(buffer);
// Create the lengths as a list of the same value (element size).
@ -302,8 +299,7 @@ LogicalResult HandleTensorArraySplitV3Op(
buffer_shape, elem_type.getElementType())},
ArrayRef<Value>{split.value(),
cutil::GetR1Const(buffer_shape, builder,
split.getLoc())},
ArrayRef<NamedAttribute>{})
split.getLoc())})
.output();
// Accumulate with the old buffer.
auto old_buffer =
@ -339,8 +335,7 @@ LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
Operation* op, Value* var) {
OpBuilder builder(op);
*var = builder.create<TF::MlirLocalVarOp>(
op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{},
ArrayRef<NamedAttribute>{});
op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{});
Value buffer;
auto buffer_type = getElementTypeOrSelf(local_var_type)
.cast<TF::ResourceType>()

View File

@ -438,7 +438,7 @@ LogicalResult HandleTensorListFromTensorOp(
OpBuilder builder(list);
Value buffer = builder.create<TF::IdentityOp>(
list.getLoc(), ArrayRef<Type>{list.tensor().getType()},
ArrayRef<Value>{list.tensor()}, ArrayRef<NamedAttribute>{});
ArrayRef<Value>{list.tensor()});
auto type = buffer.getType().cast<TensorType>();
if (!type.hasStaticShape()) {
return list.emitOpError("TensorListFromTensorOp input has unknown shape.");
@ -468,8 +468,7 @@ LogicalResult HandleTensorListPushBackOp(
cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc());
auto new_size = builder.create<TF::AddV2Op>(
push.getLoc(), ArrayRef<Type>{size.getType()},
ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())});
push.output_handle().replaceAllUsesWith(new_buffer);
(*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
push.erase();
@ -491,12 +490,10 @@ LogicalResult HandleTensorListPopBackOp(
auto size = it->getSecond().size;
OpBuilder builder(pop);
auto new_buffer = builder.create<TF::IdentityOp>(
pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer},
ArrayRef<NamedAttribute>{});
pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer});
auto new_size = builder.create<TF::SubOp>(
pop.getLoc(), ArrayRef<Type>{size.getType()},
ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())},
ArrayRef<NamedAttribute>{});
ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())});
auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc());
pop.output_handle().replaceAllUsesWith(new_buffer);
pop.tensor().replaceAllUsesWith(element);
@ -567,8 +564,7 @@ LogicalResult HandleTensorListLengthOp(
ArrayRef<Type>{RankedTensorType::get(
{}, getElementTypeOrSelf(current_size.getType()))},
ArrayRef<Value>{current_size,
cutil::GetR1Const({}, builder, length.getLoc())},
ArrayRef<NamedAttribute>{});
cutil::GetR1Const({}, builder, length.getLoc())});
length.length().replaceAllUsesWith(reshape);
}
length.erase();

View File

@ -154,8 +154,7 @@ TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch,
Value input, OpBuilder* builder) {
return builder->create<TF::TPUCopyWithLayoutOp>(
execute_launch.getLoc(), llvm::ArrayRef<Type>{input.getType()},
llvm::ArrayRef<Value>{input, get_layout.layout()},
llvm::ArrayRef<NamedAttribute>{});
llvm::ArrayRef<Value>{input, get_layout.layout()});
}
// Performs transformation for a non-replicated input.

View File

@ -206,8 +206,7 @@ TF::_HostComputeMlirOp CreateHostCompute(
device_output_types.push_back(output.getType());
SetHostComputeInsertion(builder, cluster_ops, inputs);
auto host_compute = builder->create<TF::_HostComputeMlirOp>(
tpu_cluster.getLoc(), device_output_types, inputs.getArrayRef(),
llvm::ArrayRef<NamedAttribute>{});
tpu_cluster.getLoc(), device_output_types, inputs.getArrayRef());
host_compute.setAttr(kAncestorsAttr, builder->getArrayAttr({}));
host_compute.setAttr(kShapesAttr, builder->getArrayAttr({}));
host_compute.setAttr(kKeyAttr, builder->getStringAttr(communication_key));

View File

@ -473,9 +473,8 @@ LogicalResult BuildExecuteOp(
if (failed(result)) return failure();
// TPUExecute has same output types as cluster_func.
*execute_op = builder->create<TF::TPUExecuteOp>(
cluster_func.getLoc(), output_types, inputs,
llvm::ArrayRef<NamedAttribute>{});
*execute_op = builder->create<TF::TPUExecuteOp>(cluster_func.getLoc(),
output_types, inputs);
return success();
}

View File

@ -13,24 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer";
struct TPUUpdateEmbeddingEnqueueOpInputs
@ -86,7 +91,8 @@ LogicalResult FindTPUEmbeddingOps(
LogicalResult UpdateEmbeddingEnqueueOpInput(
const llvm::StringMap<Operation*>& enqueue_op_map,
const llvm::StringMap<Operation*>& recv_activation_op_map,
const llvm::StringMap<Operation*>& send_gradient_op_map) {
const llvm::StringMap<Operation*>& send_gradient_op_map,
OpBuilder* builder) {
for (const auto& it : enqueue_op_map) {
const auto& embedding_attr = it.getKey();
Operation* embedding_op = it.second;
@ -96,21 +102,36 @@ LogicalResult UpdateEmbeddingEnqueueOpInput(
<< TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";
// TPU Embedding enqueue ops take different inputs depending on whether
// graph is in training mode or in eval/prediction mode. The inputs to the
// enqueue ops are present/listed as operands to SelectV2 op. Then branch
// operand of the SelectV2 op represents input to take during training
// and else branch operand represents input to take during
// prediction/evaluation. If SendTPUEmbeddingGradients op exists in the
// graph, then graph is in training mode, so correctly forward the input
// of SelectV2 op as operand to the TPU embedding enqueue op.
// graph is in training mode or in eval/prediction mode. During training,
// the mode parameter for TPUEmbeddingEnqueue op must be `train` and for
// evaluation or prediction, mode must be set to `inference`.
// If SendTPUEmbeddingGradients op exists in the graph, then graph is
// in training mode, so create a const op with value `train` and use the
// output value of the constant as an operand to the TPU embedding
// enqueue op.
bool is_training = send_gradient_op_map.count(embedding_attr);
for (auto enqueue_operand : embedding_op->getOperands()) {
if (auto select = llvm::dyn_cast_or_null<TF::SelectV2Op>(
enqueue_operand.getDefiningOp())) {
enqueue_operand.replaceAllUsesWith(is_training ? select.t()
: select.e());
}
}
// The last operand of TPUEmbeddingEnqueue ops is the mode which
// represents whether graph is in training mode or in evaluation mode.
auto& mode_enqueue_operand =
embedding_op->getOpOperand(embedding_op->getNumOperands() - 1);
llvm::SmallVector<StringRef, 1> mode_string_value;
mode_string_value.emplace_back(is_training ? "train" : "inference");
builder->setInsertionPoint(embedding_op);
auto enqueue_mode = builder->create<TF::ConstOp>(
embedding_op->getLoc(),
DenseStringElementsAttr::get(
RankedTensorType::get({}, builder->getType<TF::StringType>()),
mode_string_value));
auto outside_compilation_attr =
embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
if (outside_compilation_attr)
enqueue_mode.setAttr(kXlaOutsideCompilationAttr,
outside_compilation_attr);
mode_enqueue_operand.set(enqueue_mode);
}
return success();
@ -140,8 +161,9 @@ void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() {
return signalPassFailure();
}
if (failed(UpdateEmbeddingEnqueueOpInput(
enqueue_op_map, recv_activation_op_map, send_gradient_op_map)))
if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map,
recv_activation_op_map,
send_gradient_op_map, &builder)))
return signalPassFailure();
}

View File

@ -521,8 +521,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
replicate.GetNumReplicatedBlockArguments() - 1));
builder.setInsertionPoint(execute_launch);
auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands,
llvm::ArrayRef<NamedAttribute>{});
execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands);
WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
execute_launch.device());
@ -579,8 +578,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
default_state_key.getResult());
// Unformat op.
auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands,
llvm::ArrayRef<NamedAttribute>{});
while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands);
WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
execute_launch.device());
builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});

View File

@ -21,10 +21,6 @@ package_group(
includes = [
"//tensorflow/compiler/tf2xla:internal",
],
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/compiler/tests/...",
],
)
package_group(
@ -34,7 +30,6 @@ package_group(
],
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/compiler/tests/...",
"//platforms/xla/tests/neural_nets",
],
)

View File

@ -311,7 +311,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
if 'GPU' in self.device:
# TODO(b/32333178)
self.skipTest('Current implementation of RandomStandardNormal kernel '
'is very slow on GPU, and has been blacklisted.')
'is very slow on GPU, and has been denylisted.')
with self.test_scope():
data_format = 'channels_last'
conv = convolutional.Conv2D(

View File

@ -4946,7 +4946,18 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
node_def.name());
}
nvinfer1::ITensor* tensor = inputs.at(0).tensor();
if (!params->use_implicit_batch && tensor->getDimensions().d[1] == -1) {
// This check is to make sure that the channel dimension is known during
// conversion.
//
// We check this only in explicit batch mode and reject an op with unknown
// channel dimension during segmentation. In implicit batch mode we have
// known shapes during conversion even though the shapes may not be known
// during segmentation (see the actual argument for input_shapes when
// ConvertGraphDefToEngine is called from TRTEngineOp::BuildEngine).
return errors::InvalidArgument("Channel dimension must be static, at ",
node_def.name());
}
// Check parameter types
auto parameter_type = inputs.at(1).weights().TrtDType();
if ((parameter_type != nvinfer1::DataType::kFLOAT) &&

View File

@ -2011,6 +2011,142 @@ TEST_F(OpConverterTest, ConvertConst) {
TestConvertConst<DT_UINT64, uint64, int32>(this);
}
template <typename T>
NodeDef CreateFusedBatchNormOp(DataType tf_type, std::string data_format,
bool is_training, float epsilon) {
Scope s = Scope::NewRootScope();
auto x = ops::Placeholder(s.WithOpName("x"), tf_type);
auto scale = ops::Placeholder(s.WithOpName("scale"), tf_type);
auto offset = ops::Placeholder(s.WithOpName("offset"), tf_type);
auto mean = ops::Placeholder(s.WithOpName("mean"), tf_type);
auto variance = ops::Placeholder(s.WithOpName("variance"), tf_type);
typename T::Attrs attrs;
attrs.data_format_ = data_format;
attrs.is_training_ = is_training;
if (epsilon > 0) {
attrs.epsilon_ = epsilon;
} else {
EXPECT_GE(epsilon, 0);
}
return T(s.WithOpName("my_batchnorm"), x, scale, offset, mean, variance,
attrs)
.operation.node()
->def();
}
TEST_P(OpConverterTest1, ConvertFusedBatchNorm) {
using OpFunc = std::function<NodeDef(DataType, std::string, bool, float)>;
std::vector<OpFunc> get_node_def_vec{
CreateFusedBatchNormOp<ops::FusedBatchNorm>,
CreateFusedBatchNormOp<ops::FusedBatchNormV2>,
CreateFusedBatchNormOp<ops::FusedBatchNormV3>};
struct TestParam {
std::string data_format;
int tensor_input_idx; // Index of an input that will be provided as tensor.
bool is_training;
float epsilon;
Status conversion_status;
bool keep_channel_unknown;
};
struct NodeInput {
std::string name;
std::vector<int> dims;
std::vector<float> val;
};
std::vector<NodeInput> node_input{
{"x", {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}},
{"scale", {3}, {7, 8, 9}},
{"offset", {3}, {10, 20, 30}},
{"mean", {3}, {1, 2, 3}},
{"variance", {3}, {4, 5, 6}}};
std::vector<float> expected_output{10.0, 13.495633, 23.574135, 27.148273,
37.342354, 41.013527, 30.9738, 34.469433,
45.018955, 48.59309, 59.369415, 63.04059};
for (auto get_node_def : get_node_def_vec) {
NodeDef tmp_node_def = get_node_def(tf_type, "NCHW", true, 0);
std::string op_name = tmp_node_def.op();
std::vector<TestParam> test_param{
{"NHWC", 0, false, 0,
errors::Unimplemented(StrCat(
op_name, " only supports data_format=NCHW, at my_batchnorm"))},
{"NCHW", 0, true, 0,
errors::Unimplemented(StrCat(
op_name, " only supports is_training=false, at my_batchnorm"))},
{"NCHW", 1, false, 0,
errors::Unimplemented(StrCat("The input \"scale\" for ", op_name,
" must be a constant, at my_batchnorm"))},
{"NCHW", 2, false, 0,
errors::Unimplemented(StrCat("The input \"offset\" for ", op_name,
" must be a constant, at my_batchnorm"))},
{"NCHW", 3, false, 0,
errors::Unimplemented(StrCat("The input \"mean\" for ", op_name,
" must be a constant, at my_batchnorm"))},
{"NCHW", 4, false, 0,
errors::Unimplemented(StrCat("The input \"variance\" for ", op_name,
" must be a constant, at my_batchnorm"))},
{"NCHW", 0, false, 0.01}}; // The last one is the only test that runs.
if (trt_mode == TrtTestMode::kDynamicShape) {
test_param.push_back(
{"NCHW", 0, false, 0.01,
errors::InvalidArgument(
"Channel dimension must be static, at my_batchnorm"),
true});
}
for (auto p : test_param) {
Reset();
NodeDef node_def =
get_node_def(tf_type, p.data_format, p.is_training, p.epsilon);
for (int i = 0; i < node_input.size(); i++) {
if (i == 0 || i == p.tensor_input_idx) {
// The first input (x) is always added as a tensor, and it has shape
// NCHW. The other inputs are per-channel values (1D, size C).
//
// In implicit batch mode, it is not possible to add any of the 1D
// inputs as a tensor: the first dim is always treated as batch dim in
// implicit batch mode, and that has to agree for all tensors. We have
// two input tensors with shapes NCHW and C and in general N != C.
// The converter already picked up N from the first input, and reports
// an error when we try to add any other tensors with a non-matching
// first dim.
//
// This restriction does not apply in explicit batch mode: the tensors
// can have different first dim. The converter still expects that only
// the first arg is a tensor. TODO(tfeher) Check if one can relax this
// restriction.
Status expected_status =
(i != 0 && trt_mode == TrtTestMode::kImplicitBatch)
? errors::InvalidArgument(
StrCat("Batch size doesn't match for tensor ",
node_input[i].name,
": Provided batch size does not match "
"converter batch size: 3 vs 2"))
: Status::OK();
std::vector<int> partial_input_shape;
if (i == 0 && trt_mode == TrtTestMode::kDynamicShape &&
!p.keep_channel_unknown) {
// keep channel dim static (known)
partial_input_shape.resize(4, -1);
partial_input_shape[1] = node_input[i].dims[1];
}
AddTestTensor(node_input[i].name, node_input[i].dims, tf_type,
node_input[i].val, partial_input_shape,
expected_status);
} else {
AddTestWeights(node_input[i].name, node_input[i].dims,
node_input[i].val, tf_type);
}
}
TestOpConverter("my_batchnorm", node_def, node_input[0].dims,
p.conversion_status, Status::OK(),
ArrayFloatNear(expected_output));
}
}
}
TEST_P(OpConverterTest1, ConvertTranspose) {
// Get the NodeDef for Transpose.
Scope s = Scope::NewRootScope();

View File

@ -711,15 +711,15 @@ Status SegmentGraph(const Graph* tf_graph,
std::unordered_set<string> unsupported_ops;
int num_unsupported_ops = 0;
// Getting the operations blacklisted for conversion
string tftrt_op_blacklist_str;
// Getting the operations denylisted for conversion
string tftrt_op_denylist_str;
TF_CHECK_OK(
ReadStringFromEnvVar("TF_TRT_OP_BLACKLIST", "", &tftrt_op_blacklist_str));
ReadStringFromEnvVar("TF_TRT_OP_DENYLIST", "", &tftrt_op_denylist_str));
auto tftrt_op_blacklist = gtl::FlatSet<string>{}; // non-absl ok
auto tftrt_op_denylist = gtl::FlatSet<string>{}; // non-absl ok
for (const auto& x : str_util::Split(tftrt_op_blacklist_str, ",")) {
tftrt_op_blacklist.insert(x);
for (const auto& x : str_util::Split(tftrt_op_denylist_str, ",")) {
tftrt_op_denylist.insert(x);
}
// Parsing each node of the graph
@ -761,13 +761,13 @@ Status SegmentGraph(const Graph* tf_graph,
const Status status = candidate_fn(node->tf_node());
if (!status.ok()) {
exclude_node(status.error_message());
} else if (tftrt_op_blacklist.count(node->tf_node()->type_string())) {
} else if (tftrt_op_denylist.count(node->tf_node()->type_string())) {
// WARNING verbosity since the user explicitly requests this behavior.
LOG_WARNING_WITH_PREFIX
<< "Blacklisted as TF-TRT candidate, "
<< "Denylisted as TF-TRT candidate, "
<< "(Op type: " << node->tf_node()->type_string() << "), "
<< "(Op name: " << node->name() << ")";
exclude_node("Blacklisted with the env var TF_TRT_OP_BLACKLIST");
exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST");
} else {
VLOG(2) << "Accepted as a TF-TRT candidate, "
<< "(Op type: " << node->tf_node()->type_string() << "), "

View File

@ -535,10 +535,10 @@ static void AllocateFlags() {
flag_values->xla_gpu_force_conv_nchw(),
"For cuDNN convolutions, always NCHW layouts."));
flag_objects->push_back(tensorflow::Flag(
"xla_gpu_algorithm_blacklist_path",
string_setter_for(&DebugOptions::set_xla_gpu_algorithm_blacklist_path),
flag_values->xla_gpu_algorithm_blacklist_path(),
"An AlgorithmBlacklist text proto file as a blacklist of convolutions to "
"xla_gpu_algorithm_denylist_path",
string_setter_for(&DebugOptions::set_xla_gpu_algorithm_denylist_path),
flag_values->xla_gpu_algorithm_denylist_path(),
"An AlgorithmDenylist text proto file as a denylist of convolutions to "
"avoid to use."));
flag_objects->push_back(tensorflow::Flag(
"xla_gpu_deterministic_reductions",

View File

@ -688,9 +688,7 @@ StatusOr<bool> RewriteDynamicConcat(
dynamic_size));
}
}
for (HloInstruction* user : prev_users) {
TF_RETURN_IF_ERROR(concat->ReplaceUseWith(user, rewritten_concat));
}
TF_RETURN_IF_ERROR(concat->ReplaceUsesWith(prev_users, rewritten_concat));
TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
concat, rewritten_concat, {}));
return true;

View File

@ -83,8 +83,8 @@ class DynamicPadderTest : public HloTestBase {
return module;
}
StatusOr<bool> RunPadder() {
DynamicPadder padder(/*slice_dynamic_output=*/true,
StatusOr<bool> RunPadder(bool slice_dynamic_output = false) {
DynamicPadder padder(/*slice_dynamic_output=*/slice_dynamic_output,
CustomCallDynamicDimensionInference,
OpHasDynamismSupport);
return padder.Run(module_.get());
@ -162,7 +162,7 @@ ENTRY main {
module_ = GetHloModule(hlo_text);
TF_ASSERT_OK(RunPadder().status());
TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
// After rewrite, we should have :
//
// param
@ -218,7 +218,7 @@ ENTRY main {
module_ = GetHloModule(hlo_text);
TF_ASSERT_OK(RunPadder().status());
TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
// After rewrite, we should have :
//
// param
@ -654,26 +654,16 @@ XLA_TEST_F(ExecutionTest, DynamicConcat) {
const string hlo_text = R"(
HloModule DynamicConcat
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(lhs, rhs)
}
ENTRY main {
param_0 = s32[3] parameter(0)
param_1 = s32[3] parameter(1)
param_2 = s32[3] parameter(2)
size = s32[] constant(2)
param_padded_0 = s32[3] set-dimension-size(param_0, size), dimensions={0}
param_padded_2 = s32[3] set-dimension-size(param_2, size), dimensions={0}
%concatenate = s32[9]
concatenate(s32[3] param_padded_0, s32[3] param_1, s32[3] param_padded_2),
param_padded_0 = s32[<=3] set-dimension-size(param_0, size), dimensions={0}
param_padded_2 = s32[<=3] set-dimension-size(param_2, size), dimensions={0}
ROOT %concatenate = s32[9]
concatenate(s32[<=3] param_padded_0, s32[<=3] param_1, s32[<=3] param_padded_2),
dimensions={0}
init = s32[] constant(0)
ROOT reduce = s32[] reduce(concatenate, init),
dimensions={0},
to_apply=update_s32
}
)";
@ -686,10 +676,10 @@ ENTRY main {
LiteralUtil::CreateR1<int32>({6, 7, -1}); // Dynamic operand.
auto module = GetHloModule(hlo_text);
Literal result =
PadAndExecute(std::move(module), {&operand_0, &operand_1, &operand_2});
Literal expected = LiteralUtil::CreateR0<int32>(28);
Literal result = PadAndExecute(std::move(module),
{&operand_0, &operand_1, &operand_2}, false);
result.SetDynamicSize(0, 7);
Literal expected = LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6, 7});
EXPECT_EQ(result, expected);
}

View File

@ -170,7 +170,6 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
@ -1676,7 +1675,7 @@ cc_library(
tf_cc_test(
name = "hlo_algorithm_blacklist_test",
srcs = ["hlo_algorithm_blacklist_test.cc"],
data = ["data/hlo_algorithm_blacklist.pbtxt"],
data = ["data/hlo_algorithm_denylist.pbtxt"],
tags = ["no_pip"],
deps = [
":hlo_algorithm_blacklist",

View File

@ -15,19 +15,19 @@ message ConvInstructionLog {
repeated uint64 operand_addresses = 4;
}
message BlacklistedAlgorithm {
message DenylistedAlgorithm {
int64 id = 1;
bool tensor_ops = 2;
}
message AlgorithmBlacklistEntry {
message AlgorithmDenylistEntry {
string hlo = 1;
tensorflow.ComputeCapability cc = 2;
tensorflow.CudnnVersion cudnn_version = 3;
string blas_version = 5;
repeated BlacklistedAlgorithm algos = 4;
repeated DenylistedAlgorithm algos = 4;
}
message AlgorithmBlacklist {
repeated AlgorithmBlacklistEntry entries = 1;
message AlgorithmDenylist {
repeated AlgorithmDenylistEntry entries = 1;
}
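For illustration, a single AlgorithmDenylistEntry in text-proto form (all values below are invented, not taken from any real denylist) could be embedded the same way the compiled-in default list is, as a C++ raw string:

// Hypothetical example entry; the hlo string, versions and algorithm ids
// are made up for illustration.
constexpr char kExampleDenylist[] = R"pb(
  entries {
    hlo: "(f32[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call(...)"
    cc { major: 7 minor: 0 }
    cudnn_version { major: 7 minor: 6 patch: 2 }
    blas_version: "9000"
    algos { id: 1 tensor_ops: true }
  }
)pb";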

View File

@ -438,10 +438,9 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
(void)blas->GetVersion(&blas_version);
}
absl::Span<const AlgorithmDesc> blacklisted_algos =
GetBlacklistedConvAlgorithms(GetComputeCapability(stream_exec_),
GetCudnnVersion(stream_exec_), blas_version,
canonical_hlo);
absl::Span<const AlgorithmDesc> disabled_algos = GetDisabledConvAlgorithms(
GetComputeCapability(stream_exec_), GetCudnnVersion(stream_exec_),
blas_version, canonical_hlo);
for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
XLA_SCOPED_LOGGING_TIMER_LEVEL(
@ -449,7 +448,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
AlgorithmToString(alg)),
2);
if (absl::c_linear_search(blacklisted_algos, alg)) {
if (absl::c_linear_search(disabled_algos, alg)) {
LOG(INFO) << "Omitted potentially buggy algorithm "
<< AlgorithmToString(alg) << " for conv " << instr->ToString();
continue;
@ -503,7 +502,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
if (!input_output_allocator_redzone_clear ||
!scratch_allocator_redzone_clear) {
AlgorithmBlacklist proto;
AlgorithmDenylist proto;
auto entry = proto.add_entries();
entry->set_hlo(canonical_hlo);
*entry->mutable_cc() = GetComputeCapability(stream_exec_);
@ -513,13 +512,12 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
algo->set_id(alg.algo_id());
algo->set_tensor_ops(alg.tensor_ops_enabled());
LOG(ERROR)
<< "To blacklist this algorithm for this convolution, "
"copy-paste the following "
"proto to the blacklist file pointed by XLA_FLAGS "
"--xla_gpu_algorithm_blacklist_path="
<< GetDebugOptionsFromFlags().xla_gpu_algorithm_blacklist_path()
<< " : " << proto.ShortDebugString();
LOG(ERROR) << "To denylist this algorithm for this convolution, "
"copy-paste the following "
"proto to the denylist file pointed by XLA_FLAGS "
"--xla_gpu_algorithm_denylist_path="
<< GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
<< " : " << proto.ShortDebugString();
continue;
}

View File

@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
namespace gpu {
constexpr char kDefaultBlacklist[] = R"pb(
constexpr char kDefaultDenylist[] = R"pb(
entries {
hlo: "(f32[4,32,32,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[4,32,32,32]{2,1,3,0}, f32[5,5,32,32]{1,0,2,3}), window={size=5x5 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\""
cc { major: 7 }
@ -41,28 +41,26 @@ constexpr char kDefaultBlacklist[] = R"pb(
}
)pb";
absl::Span<const stream_executor::dnn::AlgorithmDesc>
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
tensorflow::CudnnVersion cudnn_version,
const std::string& blas_version,
const std::string& hlo) {
absl::Span<const stream_executor::dnn::AlgorithmDesc> GetDisabledConvAlgorithms(
tensorflow::ComputeCapability cc, tensorflow::CudnnVersion cudnn_version,
const std::string& blas_version, const std::string& hlo) {
// Key is the tuple of canonicalized hlo, compute capability major/minor,
// cudnn version major/minor/patch, blas version.
using MapType = absl::flat_hash_map<
std::tuple<std::string, int, int, int, int, int, std::string>,
std::vector<stream_executor::dnn::AlgorithmDesc>>;
static MapType* blacklist = [] {
static MapType* denylist = [] {
MapType* list = new MapType();
AlgorithmBlacklist proto;
AlgorithmDenylist proto;
std::string file_path =
GetDebugOptionsFromFlags().xla_gpu_algorithm_blacklist_path();
GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path();
if (!file_path.empty()) {
TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
file_path, &proto));
} else {
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
std::string(kDefaultBlacklist), &proto));
std::string(kDefaultDenylist), &proto));
}
for (const auto& entry : proto.entries()) {
for (const auto& algo : entry.algos()) {
@ -77,10 +75,10 @@ GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
return list;
}();
auto iter = blacklist->find(std::make_tuple(
auto iter = denylist->find(std::make_tuple(
hlo, cc.major(), cc.minor(), cudnn_version.major(), cudnn_version.minor(),
cudnn_version.patch(), std::string(blas_version)));
if (iter != blacklist->end()) {
if (iter != denylist->end()) {
return iter->second;
}
return {};
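The lookup above follows a lazily initialized, tuple-keyed table pattern; a simplified standalone sketch (plain standard-library types, not the XLA ones) is:

#include <map>
#include <string>
#include <tuple>
#include <vector>

// Simplified stand-in for the denylist lookup: key is (hlo, cc_major,
// cc_minor); the table contents here are illustrative only.
using Key = std::tuple<std::string, int, int>;

const std::vector<int>& LookupDisabledAlgorithms(const Key& key) {
  static const std::map<Key, std::vector<int>>* table = [] {
    auto* m = new std::map<Key, std::vector<int>>();
    (*m)[Key("example-hlo", 7, 0)] = {1, 3};
    return m;
  }();
  static const std::vector<int> kEmpty;
  auto it = table->find(key);
  return it != table->end() ? it->second : kEmpty;
}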

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_
#include <vector>
@ -24,13 +24,11 @@ limitations under the License.
namespace xla {
namespace gpu {
absl::Span<const stream_executor::dnn::AlgorithmDesc>
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
tensorflow::CudnnVersion cudnn_version,
const std::string& blas_version,
const std::string& hlo);
absl::Span<const stream_executor::dnn::AlgorithmDesc> GetDisabledConvAlgorithms(
tensorflow::ComputeCapability cc, tensorflow::CudnnVersion cudnn_version,
const std::string& blas_version, const std::string& hlo);
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_

View File

@ -26,22 +26,22 @@ namespace xla {
namespace gpu {
namespace {
class BlacklistTest : public testing::Test {
class DenylistTest : public testing::Test {
protected:
BlacklistTest() {
DenylistTest() {
tensorflow::setenv(
"XLA_FLAGS",
absl::StrCat(
"--xla_gpu_algorithm_blacklist_path=",
"--xla_gpu_algorithm_denylist_path=",
tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
"tensorflow", "compiler", "xla", "service", "gpu", "data",
"hlo_algorithm_blacklist.pbtxt")))
"hlo_algorithm_denylist.pbtxt")))
.data(),
0);
}
};
TEST_F(BlacklistTest, DefaultTest) {
TEST_F(DenylistTest, DefaultTest) {
tensorflow::ComputeCapability cc;
cc.set_major(7);
cc.set_minor(0);
@ -49,7 +49,7 @@ TEST_F(BlacklistTest, DefaultTest) {
cudnn_version.set_major(7);
cudnn_version.set_minor(6);
cudnn_version.set_patch(2);
auto list = GetBlacklistedConvAlgorithms(
auto list = GetDisabledConvAlgorithms(
cc, cudnn_version, /*blas_version=*/"9000",
R"((f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}")");
ASSERT_EQ(4, list.size());
@ -59,7 +59,7 @@ TEST_F(BlacklistTest, DefaultTest) {
EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(1, true), list[3]);
}
TEST_F(BlacklistTest, NegativeTest) {
TEST_F(DenylistTest, NegativeTest) {
tensorflow::ComputeCapability cc;
cc.set_major(7);
cc.set_minor(0);
@ -68,7 +68,7 @@ TEST_F(BlacklistTest, NegativeTest) {
cudnn_version.set_minor(6);
cudnn_version.set_patch(2);
auto list =
GetBlacklistedConvAlgorithms(cc, cudnn_version, "9000", R"(invalid hlo)");
GetDisabledConvAlgorithms(cc, cudnn_version, "9000", R"(invalid hlo)");
ASSERT_EQ(0, list.size());
}

View File

@ -231,7 +231,6 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
<< " of " << hlo.ToString();
llvm_ir::IrArray ir_array(base_ptr,
ShapeUtil::GetSubshape(hlo.shape(), shape_index));
alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index);
// The GPU backend emits one kernel per top-level HLO, and LLVM views
// execution of one kernel as the "whole program" executed on the GPU.

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
namespace xla {
@ -42,8 +41,7 @@ class HloToIrBindings {
: buffer_assignment_(buffer_assignment),
is_nested_(is_nested),
b_(b),
module_(llvm_module),
alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {}
module_(llvm_module) {}
void EmitBasePointersForHlos(
absl::Span<const HloInstruction* const> io_hlos,
@ -116,8 +114,6 @@ class HloToIrBindings {
// The address of the memory block that contains all temporary buffers.
llvm::Value* temp_buffer_base_ = nullptr;
llvm_ir::AliasAnalysis alias_analysis_;
};
} // namespace gpu

View File

@ -1747,6 +1747,25 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
auto buffers_it = non_constant_buffers.begin();
for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
kernel_args[*buffers_it] = arg_it;
// Annotate all allocations with LLVM's `noalias`.
// There are three kinds of allocations:
// * Read-only allocations, aka input parameters that are not aliased with
// outputs.
// * Read-write allocations, including all output buffers, some of which
// may alias with input HLO parameters, but aliased HLO buffers are always
// assigned with the same allocation.
// * The temp buffer.
//
// Read-only allocations may overlap with each other, but since they are
// not mutated, they can always be annotated with `noalias` per LLVM
// semantics.
//
// Read-write allocations and the temp buffer don't overlap with any
// allocations, therefore they can also be annotated with `noalias`.
kernel->addParamAttr(
arg_it->getArgNo(),
llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias));
}
}

View File

@ -45,7 +45,7 @@ ENTRY main {
)";
CompileAndVerifyIr(hlo_string, R"(
CHECK: @fusion(i8* align 64 dereferenceable(600) %alloc0, i8* align 16 dereferenceable(400) %alloc1, i8* align 64 dereferenceable(864) %temp_buf)
CHECK: @fusion(i8* noalias align 64 dereferenceable(600) %alloc0, i8* noalias align 16 dereferenceable(400) %alloc1, i8* noalias align 64 dereferenceable(864) %temp_buf)
)");
}

View File

@ -51,16 +51,9 @@ TEST_F(GpuNoAliasTest, Concat) {
hlo_module->AddEntryComputation(std::move(computation));
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK: %[[x_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %x{{.*}}, i32 0
; CHECK: load float, float* %[[x_gep]], {{.*}}, !noalias ![[param_noalias:.*]]
; CHECK: %[[y_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %y{{.*}}, i32 0
; CHECK: load float, float* %[[y_gep]], {{.*}}, !noalias ![[param_noalias]]
; CHECK: %[[result_ptr:.*]] = bitcast [2 x [6 x float]]* %fusion{{.*}} to float*
; CHECK: %[[result_gep:.*]] = getelementptr inbounds float, float* %[[result_ptr]]
; CHECK: store float {{.*}}, float* %[[result_gep]], align 4, !alias.scope ![[param_noalias]]
; CHECK: ![[param_noalias]] = !{![[retval_buffer:.*]]}
)",
R"(CHECK-LABEL: define void @fusion
CHECK-SAME: i8* noalias align {{[0-9]*}} dereferenceable({{[0-9]*}}) %[[OUTPUT_ALLOC:[a-z0-9]*]]
CHECK: %fusion.raw = {{.*}} %[[OUTPUT_ALLOC]])",
/*match_optimized_ir=*/false);
}

View File

@ -1,6 +1,6 @@
// RUN: hlo_to_llvm_ir %s | FileCheck %s
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) {
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
// CHECK: entry:
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
@ -26,7 +26,7 @@
// CHECK: ret void
// CHECK: scatter_TensorFlowScatterV1.in_bounds-true: ; preds = %[[VAL_24]]
// CHECK: %[[VAL_25:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_8]], i32 0, i32 %[[VAL_19]]
// CHECK: %[[VAL_26:.*]] = load i32, i32* %[[VAL_25]], align 4, !invariant.load !4, !noalias !5
// CHECK: %[[VAL_26:.*]] = load i32, i32* %[[VAL_25]], align 4, !invariant.load !4
// CHECK: %[[VAL_27:.*]] = add i32 0, %[[VAL_26]]
// CHECK: %[[VAL_28:.*]] = icmp ult i32 %[[VAL_26]], 3
// CHECK: %[[VAL_29:.*]] = and i1 true, %[[VAL_28]]
@ -37,7 +37,7 @@
// CHECK: %[[VAL_31:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_18]]
// CHECK: %[[VAL_33:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_11]] to i32*
// CHECK: %[[VAL_34:.*]] = getelementptr inbounds i32, i32* %[[VAL_33]], i32 %[[VAL_15]]
// CHECK: %[[VAL_35:.*]] = load i32, i32* %[[VAL_34]], align 4, !invariant.load !4, !noalias !5
// CHECK: %[[VAL_35:.*]] = load i32, i32* %[[VAL_34]], align 4, !invariant.load !4
// CHECK: store i32 %[[VAL_35]], i32* %[[VAL_32]], align 4
// CHECK: %[[VAL_36:.*]] = load i32, i32* %[[VAL_32]], align 4
// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4
@ -48,9 +48,6 @@
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{i32 0, i32 6}
// CHECK: !4 = !{}
// CHECK: !5 = !{!6}
// CHECK: !6 = !{!"buffer: {index:0, offset:0, size:36}", !7}
// CHECK: !7 = !{!"XLA global AA domain"}
HloModule TensorFlowScatterV1
@ -75,7 +72,7 @@ ENTRY main {
// -----
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* align 64 dereferenceable(4) %alloc0, i8* align 16 dereferenceable(4) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 %alloc3) {
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) {
// CHECK: entry:
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
@ -101,7 +98,7 @@ ENTRY main {
// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_59]], %[[VAL_55]]
// CHECK: br label %[[VAL_56]]
// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_55]]
// CHECK: %[[VAL_61:.*]] = load i32, i32* %[[VAL_48]], align 4, !invariant.load !3, !noalias !4
// CHECK: %[[VAL_61:.*]] = load i32, i32* %[[VAL_48]], align 4, !invariant.load !3
// CHECK: store i32 %[[VAL_61]], i32* %[[VAL_60]], align 4
// CHECK: %[[VAL_62:.*]] = load i32, i32* %[[VAL_60]], align 4
// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4
@ -111,9 +108,6 @@ ENTRY main {
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{}
// CHECK: !4 = !{!5}
// CHECK: !5 = !{!"buffer: {index:0, offset:0, size:4}", !6}
// CHECK: !6 = !{!"XLA global AA domain"}
HloModule ScatterIntoScalar
@ -137,7 +131,7 @@ ENTRY main {
// -----
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) {
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
@ -164,7 +158,7 @@ ENTRY main {
// CHECK: ret void
// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-true: ; preds = %[[VAL_89]]
// CHECK: %[[VAL_90:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_73]], i32 0, i32 %[[VAL_84]]
// CHECK: %[[VAL_91:.*]] = load i32, i32* %[[VAL_90]], align 4, !invariant.load !4, !noalias !5
// CHECK: %[[VAL_91:.*]] = load i32, i32* %[[VAL_90]], align 4, !invariant.load !4
// CHECK: %[[VAL_92:.*]] = add i32 0, %[[VAL_91]]
// CHECK: %[[VAL_93:.*]] = icmp ult i32 %[[VAL_91]], 3
// CHECK: %[[VAL_94:.*]] = and i1 true, %[[VAL_93]]
@ -175,7 +169,7 @@ ENTRY main {
// CHECK: %[[VAL_97:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_67]], i32 0, i32 %[[VAL_92]], i32 %[[VAL_83]]
// CHECK: %[[VAL_99:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_76]] to i32*
// CHECK: %[[VAL_100:.*]] = getelementptr inbounds i32, i32* %[[VAL_99]], i32 %[[VAL_80]]
// CHECK: %[[VAL_101:.*]] = load i32, i32* %[[VAL_100]], align 4, !invariant.load !4, !noalias !5
// CHECK: %[[VAL_101:.*]] = load i32, i32* %[[VAL_100]], align 4, !invariant.load !4
// CHECK: store i32 %[[VAL_101]], i32* %[[VAL_98]], align 4
// CHECK: %[[VAL_102:.*]] = load i32, i32* %[[VAL_98]], align 4
// CHECK: %[[VAL_103:.*]] = load i32, i32* %[[VAL_97]], align 4
@ -199,15 +193,6 @@ ENTRY main {
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{i32 0, i32 6}
// CHECK: !4 = !{}
// CHECK: !5 = !{!6}
// CHECK: !6 = !{!"buffer: {index:0, offset:0, size:36}", !7}
// CHECK: !7 = !{!"XLA global AA domain"}
// CHECK: !8 = !{!9}
// CHECK: !9 = !{!"buffer: {index:4, offset:0, size:4}", !7}
// CHECK: !10 = !{!11}
// CHECK: !11 = !{!"buffer: {index:6, offset:0, size:4}", !7}
// CHECK: !12 = !{!13}
// CHECK: !13 = !{!"buffer: {index:5, offset:0, size:4}", !7}
HloModule TensorFlowScatter_Mul
@ -231,7 +216,7 @@ ENTRY main {
// -----
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* align 64 dereferenceable(16) %alloc0, i8* align 16 dereferenceable(16) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 dereferenceable(4) %alloc3) {
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) {
// CHECK: entry:
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
@ -253,7 +238,7 @@ ENTRY main {
// CHECK: scatter_ScalarUpdate.in_bounds-after: ; preds = %[[VAL_138:.*]], %[[VAL_139:.*]]
// CHECK: ret void
// CHECK: scatter_ScalarUpdate.in_bounds-true: ; preds = %[[VAL_139]]
// CHECK: %[[VAL_140:.*]] = load i32, i32* %[[VAL_126]], align 4, !invariant.load !3, !noalias !4
// CHECK: %[[VAL_140:.*]] = load i32, i32* %[[VAL_126]], align 4, !invariant.load !3
// CHECK: %[[VAL_141:.*]] = add i32 0, %[[VAL_140]]
// CHECK: %[[VAL_142:.*]] = icmp ult i32 %[[VAL_140]], 4
// CHECK: %[[VAL_143:.*]] = and i1 true, %[[VAL_142]]
@ -262,7 +247,7 @@ ENTRY main {
// CHECK: br label %[[VAL_137]]
// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_136]]
// CHECK: %[[VAL_145:.*]] = getelementptr inbounds [4 x i32], [4 x i32]* %[[VAL_120]], i32 0, i32 %[[VAL_141]]
// CHECK: %[[VAL_147:.*]] = load i32, i32* %[[VAL_129]], align 4, !invariant.load !3, !noalias !4
// CHECK: %[[VAL_147:.*]] = load i32, i32* %[[VAL_129]], align 4, !invariant.load !3
// CHECK: store i32 %[[VAL_147]], i32* %[[VAL_146]], align 4
// CHECK: %[[VAL_148:.*]] = load i32, i32* %[[VAL_146]], align 4
// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4
@ -272,9 +257,6 @@ ENTRY main {
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{}
// CHECK: !4 = !{!5}
// CHECK: !5 = !{!"buffer: {index:0, offset:0, size:16}", !6}
// CHECK: !6 = !{!"XLA global AA domain"}
HloModule ScalarUpdate

View File

@ -2189,6 +2189,27 @@ Status HloInstruction::ReplaceOperandWithDifferentShape(
return Status::OK();
}
Status HloInstruction::ReplaceUsesWith(absl::Span<HloInstruction* const> users,
HloInstruction* new_producer) {
TF_RET_CHECK(
ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
<< shape() << " is not compatible with " << new_producer->shape();
return ReplaceAllUsesWithDifferentShape(users, new_producer);
}
Status HloInstruction::ReplaceAllUsesWithDifferentShape(
absl::Span<HloInstruction* const> users, HloInstruction* new_producer) {
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(ReplaceUseWith(user, new_producer));
}
if (parent_ && parent_->root_instruction() == this) {
parent_->set_root_instruction(new_producer,
/*accept_different_shape=*/true);
}
return Status::OK();
}
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
TF_RET_CHECK(
ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))

View File

@ -1201,6 +1201,12 @@ class HloInstruction {
// Same as ReplaceAllUsesWith, but new_producer can have a different shape.
Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);
// Same as ReplaceAllUsesWith, but only replace given set of users.
Status ReplaceUsesWith(absl::Span<HloInstruction* const> users,
HloInstruction* new_producer);
Status ReplaceAllUsesWithDifferentShape(
absl::Span<HloInstruction* const> users, HloInstruction* new_producer);
// Performs a postorder DFS visit using this node as the root. If
// call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
// complete. If ignore_control_predecessors is true, instructions only

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include <algorithm>
#include <iterator>
#include <set>
#include <sstream>
@ -650,30 +651,28 @@ bool CompareComputationsByContent(HloComputation* a, HloComputation* b) {
} // anonymous namespace
std::vector<HloComputation*> HloModule::MakeComputationSorted() const {
std::vector<HloComputation*> result;
result.reserve(computations_.size());
for (const auto& computation : computations_) {
result.push_back(computation.get());
std::vector<HloComputation*> result = MakeComputationPostOrder();
if (config().content_aware_computation_sorting()) {
absl::c_sort(result, CompareComputationsByContent);
}
std::sort(result.begin(), result.end(), CompareComputationsByContent);
return result;
}
std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
std::vector<HloComputation*> result;
for (auto* c : computations()) {
if (c->IsFusionComputation()) {
continue;
}
result.push_back(c);
}
std::vector<HloComputation*> result = MakeComputationPostOrder();
result.erase(std::remove_if(
result.begin(), result.end(),
[](HloComputation* c) { return c->IsFusionComputation(); }),
result.end());
return result;
}
std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted()
const {
auto result = MakeNonfusionComputations();
std::sort(result.begin(), result.end(), CompareComputationsByContent);
if (config().content_aware_computation_sorting()) {
absl::c_sort(result, CompareComputationsByContent);
}
return result;
}
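The refactor above combines the erase–remove idiom with an optional content-aware sort; a self-contained sketch of the same pattern on plain strings (not HloComputation) is:

#include <algorithm>
#include <string>
#include <vector>

// Illustrative only: drop empty entries, then optionally sort by content.
std::vector<std::string> FilterAndMaybeSort(std::vector<std::string> items,
                                            bool content_aware_sort) {
  items.erase(std::remove_if(items.begin(), items.end(),
                             [](const std::string& s) { return s.empty(); }),
              items.end());
  if (content_aware_sort) {
    std::sort(items.begin(), items.end());
  }
  return items;
}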

View File

@ -188,6 +188,14 @@ class HloModuleConfig {
alias_passthrough_params_ = alias_passthrough_params;
}
bool content_aware_computation_sorting() const {
return content_aware_computation_sorting_;
}
void set_content_aware_computation_sorting(
bool content_aware_computation_sorting) {
content_aware_computation_sorting_ = content_aware_computation_sorting;
}
FusionConfigCollection fusion_config_collection() const {
return fusion_config_collection_;
}
@ -251,6 +259,8 @@ class HloModuleConfig {
bool alias_passthrough_params_ = false;
bool content_aware_computation_sorting_ = false;
FusionConfigCollection fusion_config_collection_ =
FusionConfigCollection::kOff;

View File

@ -121,9 +121,9 @@ struct Item {
bool placed = false;
// To avoid an infinite loop rematerializing the same set of
// instructions ad infinitum, keep a blacklist of instructions
// instructions ad infinitum, keep a denylist of instructions
// which should not be rematerialized.
bool blacklisted = false;
bool denylisted = false;
// The buffers defined by this instruction.
BufferIdList buffers_defined;
@ -292,8 +292,8 @@ class InstructionList {
InsertBeforeInstructions(to_insert, {max_position_item->next});
}
void Blacklist(const HloInstruction* inst) {
GetItem(inst)->blacklisted = true;
void Denylist(const HloInstruction* inst) {
GetItem(inst)->denylisted = true;
}
private:
@ -1158,13 +1158,13 @@ std::vector<Item*> GetInitialBlock(const InstructionList& instruction_list,
return item_block;
}
// Returns whether any instruction in 'block' is blacklisted or
// Returns whether any instruction in 'block' is denylisted or
// non-rematerializable.
bool AnyBlacklistedOrNonRematerializable(
bool AnyDenylistedOrNonRematerializable(
const std::vector<Item*>& block,
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
for (auto* item : block) {
if (item->blacklisted) {
if (item->denylisted) {
return true;
}
if (!CanBeRematerialized(item->instruction, rematerializable_map)) {
@ -1195,10 +1195,10 @@ MemoryUsageTracker::PickRematerializationCandidates(
// instructions.
break;
}
// If any item in the starting block are blacklisted or non-rematable, then
// If any item in the starting block is denylisted or non-rematable, then
// break and move on to next start_item (we can actually move to the last
// invalid item in this block, but let's ignore that optimization for now).
if (AnyBlacklistedOrNonRematerializable(block, rematerializable_map)) {
if (AnyDenylistedOrNonRematerializable(block, rematerializable_map)) {
continue;
}
while (block.size() <= max_block_size) {
@ -1289,8 +1289,8 @@ MemoryUsageTracker::PickRematerializationCandidates(
// Time to update the block to include the next instruction.
auto* last_item = block[block.size() - 1];
auto* next_item = instruction_list.next(last_item);
if (next_item == nullptr || next_item->blacklisted ||
!next_item->placed || next_item == in_progress_item_ ||
if (next_item == nullptr || next_item->denylisted || !next_item->placed ||
next_item == in_progress_item_ ||
!CanBeRematerialized(next_item->instruction, rematerializable_map)) {
break;
}
@ -1404,7 +1404,7 @@ StatusOr<int64> RematerializeInstructions(
// instruction it was a copying of. Now 'remat' is a rematerialization
// of 'best' and kills 'best'. Stop rematerializing this instruction
// to avoid an infinite loop.
instruction_list->Blacklist(remat);
instruction_list->Denylist(remat);
}
remat_move_instructions->insert(remat);
} else {
@ -1460,8 +1460,8 @@ StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
place_before.push_back(instruction_list->GetItem(user));
}
instruction_list->Blacklist(compressed_item->instruction);
instruction_list->Blacklist(uncompressed_item->instruction);
instruction_list->Denylist(compressed_item->instruction);
instruction_list->Denylist(uncompressed_item->instruction);
instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
@ -1583,7 +1583,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// rematerialization is added to 'remat_move_instructions' (the
// rematerialization is essentially a move). If the next rematerialization of
// the instruction is also a move then the rematerialization is added to the
// blacklist.
// denylist.
absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
// The map from instructions to their rematerializable status.

View File

@ -17,6 +17,8 @@ package_group(
cc_library(
name = "spmd_partitioner",
srcs = [
"convolution_handler.cc",
"dot_handler.cc",
"spmd_partitioner.cc",
"spmd_partitioner_util.cc",
],

View File

@ -0,0 +1,695 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/numbers.h"
namespace xla {
namespace spmd {
Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
HloInstruction* hlo) {
TF_RET_CHECK(hlo->opcode() == HloOpcode::kConvolution);
auto lhs = GetPartitionedHlo(hlo->operand(0));
auto rhs = GetPartitionedHlo(hlo->operand(1));
TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
!rhs.sharding().IsTileMaximal());
const auto& dnums = hlo->convolution_dimension_numbers();
// Check if the operand shardings are aligned. Also we currently don't
// support partitioning non-spatial dimensions.
std::vector<int64> rhs_to_lhs_indices(hlo->shape().rank());
rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
dnums.input_batch_dimension();
rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
dnums.input_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
dnums.input_spatial_dimensions(i);
}
std::vector<int64> lhs_to_rhs_indices(hlo->shape().rank());
for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
}
Window window = hlo->window();
std::vector<int64> reversed_rhs_dims;
for (int64 i = 0; i < window.dimensions_size(); ++i) {
if (window.dimensions(i).window_reversal()) {
reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i));
}
}
if (!reversed_rhs_dims.empty()) {
// Make the reversed dims left-padded to prepare for window reversal.
auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims);
if (left_padded_rhs == nullptr) {
return DefaultAction(hlo);
}
left_padded_rhs->set_sharding(rhs.sharding());
rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state());
}
// Consider window reversal when resharding RHS or LHS. Note: this will not
// reverse the data in the shard. We use window reversal to do that.
auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding(
hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices),
reversed_rhs_dims);
auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding(
hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims),
lhs_to_rhs_indices);
auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
const HloSharding& rhs_sharding) {
return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
1 ||
rhs_sharding.tile_assignment().dim(
dnums.kernel_output_feature_dimension()) != 1;
};
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
return DefaultAction(hlo);
}
lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
rhs = rhs.PadWithValue(zero, reversed_rhs_dims);
} else {
if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
return DefaultAction(hlo);
}
lhs = lhs.PadWithValue(zero);
rhs =
rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
}
// Reshard LHS by exchanging halo such that each shard computes the partial
// sum of the full shape result, and add AllReduce.
//
// The size of halo on each dimension can be calculated from the projection
// onto the LHS that each RHS shard i needs to read. RHS and LHS below refer
// to the shard size of RHS and LHS, WC is the number of windows, and D is the
// window dilation.
//
// * offset(i): RHS * D * i - low_padding
// * limit(i): {(RHS - 1) * D + 1} * (i + 1) + (WC - 1) * stride - low_padding
//
// Since shard i has LHS of range [i * LHS, (i + 1) * LHS)
// * left-halo: i * LHS - offset(i)
// = (LHS - RHS * D) * i + low_padding
// * right-halo: limit(i) - (i + 1) * LHS
// = [{(RHS - 1) * D + 1} - LHS] * (i + 1) + (WC - 1) * stride - low_padding
std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
auto wd = window.dimensions(i);
if (wd.base_dilation() != 1) {
return DefaultAction(hlo);
}
int64 lhs_shard_size =
CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
int64 rhs_shard_size =
CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
shard_counts[i] = shard_count;
lhs_shard_sizes[i] = lhs_shard_size;
rhs_shard_sizes[i] = rhs_shard_size;
}
std::vector<OffsetCalculation> left_halo_size_functions(hlo->shape().rank());
std::vector<OffsetCalculation> right_halo_size_functions(hlo->shape().rank());
Window new_window = window;
auto partition_ordinals =
MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_);
HloInstruction* lhs_with_halo = lhs.hlo();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 lhs_shard_size = lhs_shard_sizes[i];
int64 rhs_shard_size = rhs_shard_sizes[i];
if (shard_counts[i] == 1) {
continue;
}
// Calculate the left and right halo sizes as described in the comments
// above.
auto wd = window.dimensions(i);
int64 padding_low = wd.padding_low();
int64 padding_high = wd.padding_high();
int64 base = lhs.base_shape().dimensions(lhs_dimension);
int64 window_count = 1 + (padding_low + padding_high + base -
(1 + (wd.size() - 1) * wd.window_dilation())) /
wd.stride();
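// Worked example (hypothetical numbers, for illustration only): with
// base = 10, padding_low = padding_high = 1, window size = 3,
// window_dilation = 1 and stride = 2, the effective window size is
// 1 + (3 - 1) * 1 = 3, so window_count = 1 + (1 + 1 + 10 - 3) / 2 = 5,
// which matches the usual output-size formula
// floor((base + pads - effective_window) / stride) + 1.
static_assert(1 + (1 + 1 + 10 - (1 + (3 - 1) * 1)) / 2 == 5,
              "illustrative window_count example");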
int64 rhs_shard_size_dilated =
(rhs_shard_size - 1) * wd.window_dilation() + 1;
left_halo_size_functions[lhs_dimension] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low,
1));
right_halo_size_functions[lhs_dimension] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
rhs_shard_size_dilated - lhs_shard_size,
rhs_shard_size_dilated - lhs_shard_size +
wd.stride() * (window_count - 1) - padding_low,
1));
// Exchange halo and concatenate.
int64 dim = dnums.input_spatial_dimensions(i);
int64 explicit_left_padding_on_full_shape = padding_low;
int64 shard_size_with_halo =
wd.stride() * (window_count - 1) + rhs_shard_size_dilated;
new_window.mutable_dimensions(i)->set_padding_low(0);
new_window.mutable_dimensions(i)->set_padding_high(0);
new_window.mutable_dimensions(i)->set_size(rhs_shard_size);
// offset_on_padded_shape and padded_full_shape_size are needed only if
// we want to mask out-of-range values in ExchangeHaloAndGetValidData().
// Since the default value for the collective-permute is zero and we also
// call PadWithValue() on both operands at the beginning, we don't need
// to mask here.
//
// TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
// if it's always safe.
auto offset_on_padded_shape =
OffsetCalculation(MultiplyAddDivideOffsetCalculation());
int64 padded_full_shape_size = 0;
auto concat = ExchangeHaloAndGetValidData(
lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim],
right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(),
offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), zero,
partition_ordinals[dim], collective_ops_creator_, next_channel_id_, &b_,
/*mask_invalid_region=*/false);
if (!concat) {
return DefaultAction(hlo);
}
lhs_with_halo = *concat;
}
SetPartitionedHlo(hlo, [&]() {
auto conv = b_.AddInstruction(HloInstruction::CreateConvolve(
hlo->shape(), lhs_with_halo, rhs.hlo(), hlo->feature_group_count(),
hlo->batch_group_count(), new_window,
hlo->convolution_dimension_numbers(), hlo->precision_config()));
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
NewChannel());
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo);
if (dot_dnums) {
// Use HandleDotHelper() for convs that are actually einsums.
spmd::DotGeneralDimsMapping mapping;
for (const auto& dims : dot_dnums->batch_dims) {
mapping.batch_dims.emplace_back();
mapping.batch_dims.back().lhs = dims.lhs;
mapping.batch_dims.back().rhs = dims.rhs;
mapping.batch_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->contracting_dims) {
mapping.contracting_dims.emplace_back();
mapping.contracting_dims.back().lhs = dims.lhs;
mapping.contracting_dims.back().rhs = dims.rhs;
mapping.contracting_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->lhs_non_contracting_dims) {
mapping.lhs_non_contracting_dims.emplace_back();
mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.lhs_non_contracting_dims.back().output = dims.output;
}
for (const auto& dims : dot_dnums->rhs_non_contracting_dims) {
mapping.rhs_non_contracting_dims.emplace_back();
mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
mapping.rhs_non_contracting_dims.back().output = dims.output;
}
auto create_sharded_conv =
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_conv,
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
*hlo, *dot_dnums, lhs_hlo, rhs_hlo));
return b->AddInstruction(std::move(sharded_conv));
};
return HandleDotHelper(hlo, mapping, create_sharded_conv);
}
auto lhs = GetPartitionedHlo(hlo->operand(0));
auto rhs = GetPartitionedHlo(hlo->operand(1));
const HloSharding& sharding = hlo->sharding();
const auto& dnums = hlo->convolution_dimension_numbers();
std::vector<int64> rhs_to_lhs_indices(hlo->shape().rank());
rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
dnums.input_batch_dimension();
rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
dnums.input_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
dnums.input_spatial_dimensions(i);
}
std::vector<int64> lhs_to_rhs_indices(hlo->shape().rank());
for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
}
auto aligned_rhs_sharding =
hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
auto aligned_lhs_sharding =
hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
// Handling cases where all the partitioned dimensions are parallel
// dimensions.
int64 lhs_parallel_dim_partitions = 1;
int64 rhs_parallel_dim_partitions = 1;
std::vector<int64> parallel_spatial_dims;
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dim = dnums.input_spatial_dimensions(i);
int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
const auto& wd = hlo->window().dimensions(i);
int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
// Only non-reversal windows are supported right now.
if (!wd.window_reversal() &&
dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
parallel_spatial_dims.emplace_back(i);
lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim);
rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim);
}
}
bool lhs_partition_dims_are_parallel =
(lhs_parallel_dim_partitions == num_partitions_);
bool rhs_partition_dims_are_parallel =
(rhs_parallel_dim_partitions == num_partitions_);
// If there is a parallel dim and all the partitioned dimensions are parallel
// dimensions in either LHS or RHS, simply create partitioned convolutions.
if (!parallel_spatial_dims.empty() &&
(lhs_partition_dims_are_parallel || rhs_partition_dims_are_parallel)) {
// Reshard LHS or RHS to partition at parallel dimensions as the other
// operand.
if (lhs_partition_dims_are_parallel) {
rhs = rhs.Reshard(aligned_rhs_sharding);
} else {
lhs = lhs.Reshard(aligned_lhs_sharding);
}
auto lhs_shard_shape =
MakePartitionedShape(lhs.base_shape(), lhs.sharding());
auto rhs_shard_shape =
MakePartitionedShape(rhs.base_shape(), rhs.sharding());
// Update convolution window.
auto new_window = hlo->window();
for (const auto& spatial_dim : parallel_spatial_dims) {
auto wd = new_window.mutable_dimensions(spatial_dim);
wd->set_size(lhs_shard_shape.dimensions(
dnums.input_spatial_dimensions(spatial_dim)));
wd->set_stride(std::max<int64>(1, wd->size() - 1));
wd->set_base_dilation(wd->size());
}
TF_ASSIGN_OR_RETURN(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
lhs_shard_shape, rhs_shard_shape, hlo->feature_group_count(),
hlo->batch_group_count(), new_window, dnums));
*sharded_conv_shape.mutable_layout() = hlo->shape().layout();
SetPartitionedHlo(hlo, [&]() {
auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve(
sharded_conv_shape, lhs.hlo(), rhs.hlo(), hlo->feature_group_count(),
hlo->batch_group_count(), new_window, dnums,
hlo->precision_config()));
sharded_conv->set_sharding(hlo->sharding());
return PartitionedHlo(sharded_conv, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
// Handling cases where both operands' shardings are aligned. We check that
// the LHS batch dimension is not partitioned because it is mapped to the
// output feature dimension in aligned_rhs_sharding, which is not the same
// dimension.
if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) {
if (options_.conv_halo_exchange_always_on_lhs) {
return HandleConvolutionTiledLhsAndRhs(hlo);
} else {
// Reshard RHS so that each shard computes the partial sum of the full
// shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs()
// that reshards LHS.
//
// The size of halo on each dimension can be calculated from the
// projection onto the RHS that shard i needs to read. RHS and LHS below
// refer to the shard size of RHS and LHS, WC is the number of windows,
// and D is the window dilation.
//
// * offset(i): LHS * i + low_padding - (WC - 1) * stride
// * limit(i): LHS * (i + 1) + low_padding
//
// Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D)
// * left-halo: i * RHS - offset(i)
// = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding
// * right-halo: limit(i) - (i + 1) * RHS
// = (i + 1) * (LHS - RHS * D) + low_padding
auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
const HloSharding& rhs_sharding) {
// We currently don't support partitioning input batch or output feature
// dimensions.
return lhs_sharding.tile_assignment().dim(
dnums.input_batch_dimension()) != 1 ||
rhs_sharding.tile_assignment().dim(
dnums.kernel_output_feature_dimension()) != 1;
};
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
if (ShapeSizeInBytes(lhs.base_shape()) <
ShapeSizeInBytes(rhs.base_shape())) {
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
return DefaultAction(hlo);
}
lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
rhs = rhs.PadWithValue(zero);
} else {
if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
return DefaultAction(hlo);
}
lhs = lhs.PadWithValue(zero);
rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero);
}
Window window = hlo->window();
std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
auto wd = window.dimensions(i);
if (wd.base_dilation() != 1 || wd.window_reversal()) {
return DefaultAction(hlo);
}
int64 lhs_shard_size = CeilOfRatio(
lhs.base_shape().dimensions(lhs_dimension), shard_count);
int64 rhs_shard_size = CeilOfRatio(
rhs.base_shape().dimensions(rhs_dimension), shard_count);
shard_counts[i] = shard_count;
lhs_shard_sizes[i] = lhs_shard_size;
rhs_shard_sizes[i] = rhs_shard_size;
}
std::vector<OffsetCalculation> left_halo_size_functions(
hlo->shape().rank());
std::vector<OffsetCalculation> right_halo_size_functions(
hlo->shape().rank());
Window new_window = window;
// Data structures needed for Pad and DynamicSlice on LHS if needed.
bool need_dynamic_slice_lhs = false;
auto partition_ordinals =
MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_);
std::vector<int64> zero_padding(hlo->shape().rank());
PaddingConfig pad_config =
window_util::MakeSymmetricPadding(zero_padding);
auto zero_s32 = b_.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
std::vector<HloInstruction*> dynamic_slice_start_indices(
hlo->shape().rank(), zero_s32);
Shape dynamic_slice_shape = lhs.hlo()->shape();
Shape pad_shape = lhs.hlo()->shape();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
int64 lhs_shard_size = lhs_shard_sizes[i];
int64 rhs_shard_size = rhs_shard_sizes[i];
if (shard_counts[i] == 1) {
continue;
}
// Calculate the left and right halo sizes as described in the comments
// above. It calculates the halo sizes with dilation, so we apply
// CeilOfRatio({left,right}_halo_size, window_dilation).
auto wd = window.dimensions(i);
int64 padding_low = wd.padding_low();
int64 padding_high = wd.padding_high();
int64 base = lhs.base_shape().dimensions(lhs_dimension);
int64 window_count =
1 + (padding_low + padding_high + base -
(1 + (wd.size() - 1) * wd.window_dilation())) /
wd.stride();
left_halo_size_functions[rhs_dimension] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
rhs_shard_size * wd.window_dilation() - lhs_shard_size,
(window_count - 1) * wd.stride() - padding_low +
wd.window_dilation() - 1,
wd.window_dilation()));
right_halo_size_functions[rhs_dimension] =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
lhs_shard_size - rhs_shard_size * wd.window_dilation(),
lhs_shard_size - rhs_shard_size * wd.window_dilation() +
padding_low + wd.window_dilation() - 1,
wd.window_dilation()));
// New RHS window size includes the maximum of both left and right
// halos.
int64 halo_size = left_halo_size_functions[rhs_dimension].MaxInRange(
1, shard_counts[i]) +
right_halo_size_functions[rhs_dimension].MaxInRange(
0, shard_counts[i] - 1);
int64 new_window_size =
rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size;
// The amount of new low padding could be dynamic (e.g., window_dilation
// != 1), which requires pad (to the maximum) and dynamic slice on LHS.
//
// If we consider the first window, the offset of the dilated RHS that
// aligns with the first valid LHS element for shard i is 'padding_low +
// LHS * i'. When the left halo is added to RHS, the offset of the first
// RHS element is (RHS * i - left_halo) * window_dilation. The
// difference between the two values is the amount of padding_low we
// need on LHS.
auto new_padding_low_function =
OffsetCalculation(
HloOpcode::kMultiply, left_halo_size_functions[rhs_dimension],
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
0, wd.window_dilation(), 1))) -
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
rhs_shard_size * wd.window_dilation() - lhs_shard_size,
-padding_low, 1));
int64 new_padding_low_max =
new_padding_low_function.MaxInRange(0, shard_counts[i]);
int64 new_padding_low = new_padding_low_max;
int64 new_padding_high = window_count * wd.stride() +
(new_window_size - 1) * wd.window_dilation() -
new_padding_low - lhs_shard_size;
// We do pad/dynamic-slice only when the padding is dynamic.
if (!new_padding_low_function.IsConstant()) {
need_dynamic_slice_lhs = true;
new_padding_low = 0;
pad_config.mutable_dimensions(lhs_dimension)
->set_edge_padding_low(new_padding_low_max);
pad_config.mutable_dimensions(lhs_dimension)
->set_edge_padding_high(new_padding_low_max);
pad_shape.set_dimensions(lhs_dimension,
lhs_shard_size + 2 * new_padding_low_max);
dynamic_slice_start_indices[lhs_dimension] =
(OffsetCalculation(MultiplyAddDivideOffsetCalculation(
0, new_padding_low_max, 1)) -
new_padding_low_function)
.Calculate(partition_ordinals[lhs_dimension], &b_);
dynamic_slice_shape.set_dimensions(
lhs_dimension, lhs_shard_size + new_padding_low_max);
}
// Since the convolution RHS operand size increased with halos, adjust
// the window config accordingly.
new_window.mutable_dimensions(i)->set_padding_low(new_padding_low);
new_window.mutable_dimensions(i)->set_padding_high(new_padding_high);
new_window.mutable_dimensions(i)->set_size(
rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size);
}
HloInstruction* conv_lhs = lhs.hlo();
if (need_dynamic_slice_lhs) {
auto pad = b_.AddInstruction(
HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config));
conv_lhs = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
dynamic_slice_shape, pad, dynamic_slice_start_indices,
dynamic_slice_shape.dimensions()));
}
// Exchange halo and concatenate.
HloInstruction* rhs_with_halo = rhs.hlo();
for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
int64 dim = dnums.kernel_spatial_dimensions(i);
int64 explicit_left_padding_on_full_shape =
left_halo_size_functions[dim].Calculate(0);
int64 shard_size_with_halo = new_window.dimensions(i).size();
// offset_on_padded_shape and padded_full_shape_size are needed only if
// we want to mask out-of-range values in ExchangeHaloAndGetValidData().
// Since the default value for the collective-permute is zero and we also
// call PadWithValue() on both operands at the beginning, we don't need
// to mask here.
//
// TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
// if it's always safe.
auto offset_on_padded_shape =
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) -
left_halo_size_functions[dim];
int64 padded_full_shape_size =
offset_on_padded_shape.Calculate(shard_counts[i] - 1) +
new_window.dimensions(i).size();
auto concat = ExchangeHaloAndGetValidData(
rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim],
right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(),
offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_),
zero, partition_ordinals[dim], collective_ops_creator_,
next_channel_id_, &b_, /*mask_invalid_region=*/false);
if (!concat) {
return DefaultAction(hlo);
}
rhs_with_halo = *concat;
}
SetPartitionedHlo(hlo, [&]() {
auto conv = b_.AddInstruction(HloInstruction::CreateConvolve(
hlo->shape(), conv_lhs, rhs_with_halo, hlo->feature_group_count(),
hlo->batch_group_count(), new_window, dnums,
hlo->precision_config()));
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
NewChannel());
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
}
if (!sharding.IsTileMaximal()) {
// We don't currently support sharding on output feature dimension.
if (sharding.tile_assignment().dim(dnums.output_feature_dimension()) > 1) {
return DefaultAction(hlo);
}
// Check if the operand and the output sharding are aligned.
std::vector<int64> input_to_output_indices(hlo->shape().rank());
input_to_output_indices[dnums.input_batch_dimension()] =
dnums.output_batch_dimension();
input_to_output_indices[dnums.input_feature_dimension()] =
dnums.output_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
input_to_output_indices[dnums.input_spatial_dimensions(i)] =
dnums.output_spatial_dimensions(i);
}
auto target_operand_sharding =
hlo_sharding_util::TransposeSharding(sharding, input_to_output_indices);
lhs = lhs.Reshard(target_operand_sharding);
// Replicate the RHS.
rhs = rhs.Reshard(HloSharding::Replicate());
// Convolution window config does not include batch and feature dimensions,
// whereas ReshardAsWindowedInput() expects the same number of window
// dimensions as the rank of the operand. So add two more trivial
// dimensions.
std::vector<int64> ones(hlo->shape().rank(), 1);
auto operand_window = window_util::MakeWindow(ones);
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
*operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) =
hlo->window().dimensions(i);
}
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
auto resharded_operand_and_window = lhs.ReshardAsWindowedInput(
operand_window, target_operand_sharding, zero);
if (!resharded_operand_and_window.has_value()) {
return DefaultAction(hlo);
}
Window new_window;
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
*new_window.add_dimensions() =
resharded_operand_and_window->shard_window.dimensions(
dnums.input_spatial_dimensions(i));
}
TF_ASSIGN_OR_RETURN(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
resharded_operand_and_window->sharded_input->shape(),
rhs.hlo()->shape(), hlo->feature_group_count(),
hlo->batch_group_count(), new_window, dnums));
auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
*sharded_conv_shape.mutable_layout() = shard_shape.layout();
SetPartitionedHlo(hlo, [&]() {
auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve(
sharded_conv_shape, resharded_operand_and_window->sharded_input,
rhs.hlo(), hlo->feature_group_count(), hlo->batch_group_count(),
new_window, dnums, hlo->precision_config()));
if (!resharded_operand_and_window->dynamic_slice_index_on_output
.has_value()) {
CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
return sharded_conv;
}
return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
shard_shape, sharded_conv,
*resharded_operand_and_window->dynamic_slice_index_on_output,
shard_shape.dimensions()));
});
return Status::OK();
}
return DefaultAction(hlo);
}
} // namespace spmd
} // namespace xla
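
The non-tile-maximal branch above starts by building a permutation from operand dimensions to output dimensions so the output sharding can be transposed onto the LHS operand, while the RHS is replicated. A minimal standalone sketch of that mapping, using plain integers instead of ConvolutionDimensionNumbers (function and variable names below are illustrative only):

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> InputToOutputIndices(
    int64_t rank, int64_t input_batch, int64_t output_batch,
    int64_t input_feature, int64_t output_feature,
    const std::vector<int64_t>& input_spatial,
    const std::vector<int64_t>& output_spatial) {
  std::vector<int64_t> mapping(rank, -1);
  mapping[input_batch] = output_batch;      // operand batch dim -> output batch dim
  mapping[input_feature] = output_feature;  // operand feature dim -> output feature dim
  for (size_t i = 0; i < input_spatial.size(); ++i) {
    mapping[input_spatial[i]] = output_spatial[i];  // spatial dims map pairwise
  }
  return mapping;
}

int main() {
  // NHWC operand and NHWC output: batch=0, feature=3, spatial={1,2}.
  auto mapping = InputToOutputIndices(/*rank=*/4, 0, 0, 3, 3, {1, 2}, {1, 2});
  for (int64_t m : mapping) std::cout << m << " ";  // prints: 0 1 2 3
  std::cout << "\n";
  return 0;
}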

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -885,5 +885,46 @@ int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
return sharding.tile_assignment().dim(dim);
}
absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
const HloSharding& source, const HloSharding& target) {
if (source.IsTileMaximal() || target.IsTileMaximal() ||
source.tile_assignment().num_dimensions() !=
target.tile_assignment().num_dimensions()) {
return absl::nullopt;
}
int64 source_dim = -1;
int64 target_dim = -1;
for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
if (source.tile_assignment().dim(i) > 1 &&
target.tile_assignment().dim(i) == 1) {
if (source_dim != -1) {
return absl::nullopt;
}
source_dim = i;
} else if (source.tile_assignment().dim(i) == 1 &&
target.tile_assignment().dim(i) > 1) {
if (target_dim != -1) {
return absl::nullopt;
}
target_dim = i;
} else if (source.tile_assignment().dim(i) !=
target.tile_assignment().dim(i)) {
return absl::nullopt;
}
}
if (source_dim == -1 || target_dim == -1 || source_dim == target_dim) {
return absl::nullopt;
}
return std::pair<int64, int64>(source_dim, target_dim);
}
bool CanReshardWithCollectivePermute(const HloSharding& source,
const HloSharding& target) {
return !source.IsTileMaximal() && !target.IsTileMaximal() &&
source.tile_assignment().dimensions() ==
target.tile_assignment().dimensions() &&
source.tile_assignment() != target.tile_assignment();
}
} // namespace spmd
} // namespace xla
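
A standalone sketch of GetReshardAllToAllSourceTargetDims above, using plain vectors of per-dimension tile counts in place of HloSharding; it only makes the dimension-matching rule explicit and does not reproduce the real API:

#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

std::optional<std::pair<int64_t, int64_t>> AllToAllDims(
    const std::vector<int64_t>& source_tiles,
    const std::vector<int64_t>& target_tiles) {
  if (source_tiles.size() != target_tiles.size()) return std::nullopt;
  int64_t source_dim = -1, target_dim = -1;
  for (size_t i = 0; i < source_tiles.size(); ++i) {
    if (source_tiles[i] > 1 && target_tiles[i] == 1) {
      if (source_dim != -1) return std::nullopt;  // more than one dim loses its tiling
      source_dim = i;
    } else if (source_tiles[i] == 1 && target_tiles[i] > 1) {
      if (target_dim != -1) return std::nullopt;  // more than one dim gains tiling
      target_dim = i;
    } else if (source_tiles[i] != target_tiles[i]) {
      return std::nullopt;  // any other mismatch cannot be one all-to-all
    }
  }
  if (source_dim == -1 || target_dim == -1 || source_dim == target_dim) {
    return std::nullopt;
  }
  return std::make_pair(source_dim, target_dim);
}

int main() {
  // Tiled 4-ways on dim 0, resharded to tiled 4-ways on dim 1: one all-to-all.
  auto dims = AllToAllDims({4, 1}, {1, 4});
  if (dims) std::cout << dims->first << " -> " << dims->second << "\n";  // 0 -> 1
  return 0;
}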

View File

@ -265,6 +265,15 @@ HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
// Check if a dimension is sharded.
int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);
// Returns the pair of source and target dimensions if the resharding can be
// done via all-to-all.
absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
const HloSharding& source, const HloSharding& target);
// Returns whether the resharding can be done via collective-permute.
bool CanReshardWithCollectivePermute(const HloSharding& source,
const HloSharding& target);
} // namespace spmd
} // namespace xla

View File

@ -270,8 +270,8 @@ message DebugOptions {
// Paths to files with ptx code.
repeated string xla_gpu_ptx_file = 127;
// Blacklist for cuDNN convolutions.
string xla_gpu_algorithm_blacklist_path = 128;
// Denylist for cuDNN convolutions.
string xla_gpu_algorithm_denylist_path = 128;
// Guarantee run-to-run determinism from reductions on XLA:GPU.
bool xla_gpu_deterministic_reductions = 130;

View File

@ -163,6 +163,7 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core/platform:platform_port",
"//tensorflow/core/util:abstract_stack_trace",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",

View File

@ -306,6 +306,7 @@ Status EagerOperation::Reset(
}
attrs_.Reset(op);
use_xla_ = false;
stack_trace_.reset();
is_function_ = is_function;
cancellation_manager_ = nullptr;
executor_ = executor ? executor : &ctx_.Executor();

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/abstract_stack_trace.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@ -120,6 +121,14 @@ class EagerOperation : public ImmediateExecutionOperation {
Status SetUseXla(bool enable) override;
void SetStackTrace(AbstractStackTrace stack_trace) override {
stack_trace_ = stack_trace;
}
absl::optional<AbstractStackTrace> GetStackTrace() override {
return stack_trace_;
}
Status Reset(const char* op, const char* device_name, bool remote,
EagerExecutor* executor,
const absl::optional<EagerRemoteFunctionParams>
@ -218,6 +227,7 @@ class EagerOperation : public ImmediateExecutionOperation {
VariantDevice device_;
bool use_xla_ = false;
absl::optional<AbstractStackTrace> stack_trace_;
bool is_function_; // Conceptually const, but can't be because of Reset
bool colocation_exempt_;
CancellationManager* cancellation_manager_ = nullptr; // Not owned.

View File

@ -634,7 +634,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
auto node = absl::make_unique<AsyncExecuteNode>(
&ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
graph_collector, op->GetCancellationManager(),
absl::Span<TensorHandle*>(retvals, num_outputs));
absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
// Release the inputs from the eager operation since the AsyncExecuteNode
// would have taken ownership. This allows the inputs to be forwarded if
// possible.

View File

@ -150,14 +150,16 @@ class AsyncExecuteNode : public EagerNode {
core::RefCountPtr<KernelAndDevice> kernel,
GraphCollector* graph_collector,
CancellationManager* cancellation_manager,
absl::Span<TensorHandle*> retvals)
absl::Span<TensorHandle*> retvals,
absl::optional<AbstractStackTrace> stack_trace)
: EagerNode(),
ctx_(ctx),
inputs_(inputs),
remote_func_params_(remote_func_params),
kernel_(std::move(kernel)),
graph_collector_(graph_collector),
cancellation_manager_(cancellation_manager) {
cancellation_manager_(cancellation_manager),
stack_trace_(stack_trace) {
// Copy the output handles, since the container for them might get
// destroyed.
for (auto handle : retvals) {
@ -194,10 +196,14 @@ class AsyncExecuteNode : public EagerNode {
}
++i;
}
const Status status = EagerKernelExecute(
Status status = EagerKernelExecute(
ctx_, inputs_, remote_func_params_, kernel_, graph_collector_,
cancellation_manager_, absl::MakeSpan(retvals_));
if (!status.ok()) {
if (stack_trace_.has_value()) {
status = Status(status.code(), status.error_message(),
stack_trace_->ToStackFrames());
}
Abort(status);
return status;
}
@ -227,6 +233,7 @@ class AsyncExecuteNode : public EagerNode {
core::RefCountPtr<KernelAndDevice> kernel_;
GraphCollector* graph_collector_;
CancellationManager* const cancellation_manager_;
absl::optional<AbstractStackTrace> stack_trace_;
absl::InlinedVector<TensorHandle*, 2> retvals_;
};
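
The change above threads the optional stack trace captured on the op into the async node, so a failed EagerKernelExecute can be rebuilt with user-level frames attached before Abort() runs. A small generic sketch of that pattern, with stand-in types rather than the real tensorflow::Status and AbstractStackTrace:

#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct StackFrame { std::string file; int line; std::string name; };

struct SimpleStatus {
  int code = 0;                    // 0 means OK
  std::string message;
  std::vector<StackFrame> frames;  // optional user-level frames
  bool ok() const { return code == 0; }
};

SimpleStatus RunKernel() { return {1, "resource exhausted", {}}; }

SimpleStatus RunWithTrace(const std::optional<std::vector<StackFrame>>& trace) {
  SimpleStatus status = RunKernel();
  if (!status.ok() && trace.has_value()) {
    // Rebuild the status so the captured frames travel with the error.
    status = SimpleStatus{status.code, status.message, *trace};
  }
  return status;
}

int main() {
  std::optional<std::vector<StackFrame>> trace =
      std::vector<StackFrame>{{"model.py", 42, "train_step"}};
  SimpleStatus s = RunWithTrace(trace);
  std::cout << s.message << " (" << s.frames.size() << " frame(s) attached)\n";
  return 0;
}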

View File

@ -159,6 +159,7 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor",
"//tensorflow/core/platform:tf32_utils",
"//tensorflow/core/profiler/lib:annotated_traceme",
"//tensorflow/core/profiler/lib:scoped_annotation",
"//third_party/eigen3",

View File

@ -277,6 +277,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// RendezvousMgr already aborted, shouldn't send RPC call any more
if (!call->status().ok()) {
DeregisterCall(call);
// NOTE: `*sess` can potentially be deleted before we return from
// `call->done()(...)`, so we must release the worker before calling the
// callback.
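
A minimal illustration of the ordering constraint in the note above, with generic stand-in types rather than the real rendezvous and worker-session classes: the done() callback may destroy the session, so the borrowed worker has to be handed back before the callback runs, never after.

#include <functional>
#include <iostream>
#include <string>

struct Session {
  std::string name;
  void ReleaseWorker() { std::cout << "worker released by " << name << "\n"; }
};

void FinishCall(Session* sess, const std::function<void()>& done) {
  sess->ReleaseWorker();  // safe: sess is still alive here
  done();                 // may delete sess; do not touch sess after this line
}

int main() {
  auto* sess = new Session{"session0"};
  FinishCall(sess, [sess] {
    delete sess;  // the callback tears down the session
    std::cout << "done callback ran\n";
  });
  return 0;
}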

View File

@ -972,9 +972,9 @@ double Node::OutputTime(absl::flat_hash_map<string, double>* input_times,
return output_times[long_name()];
}
std::shared_ptr<Node> Node::Snapshot(std::shared_ptr<Node> output) const {
std::shared_ptr<Node> Node::Snapshot() const {
NodePairList node_pairs;
auto result = SnapshotHelper(output, &node_pairs);
auto result = SnapshotHelper(nullptr, &node_pairs);
while (!node_pairs.empty()) {
auto node_pair = node_pairs.front();
@ -1346,7 +1346,7 @@ void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget) {
std::shared_ptr<Node> snapshot;
{
tf_shared_lock lock(mu_);
snapshot = output_->Snapshot(nullptr);
snapshot = output_->Snapshot();
}
VLOG(2) << "Starting optimization of tunable parameters with GradientDescent";
auto parameters = CollectTunableParameters(snapshot);
@ -1422,7 +1422,7 @@ void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget) {
std::shared_ptr<Node> snapshot;
{
tf_shared_lock lock(mu_);
snapshot = output_->Snapshot(nullptr);
snapshot = output_->Snapshot();
}
VLOG(2) << "Starting optimization of tunable parameters with HillClimb";
const double processing_time = TotalProcessingTime(snapshot);

View File

@ -339,8 +339,7 @@ class Node {
//
// The purpose for this method is to allow the model optimization logic to
// operate over immutable state while allowing concurrent model updates.
std::shared_ptr<Node> Snapshot(std::shared_ptr<Node> output) const
TF_LOCKS_EXCLUDED(mu_);
std::shared_ptr<Node> Snapshot() const TF_LOCKS_EXCLUDED(mu_);
// Returns the per-element processing time spent in this node.
double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_);
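
Snapshot() dropped its output parameter (callers always passed nullptr, and the helper now receives nullptr directly). A rough sketch of the snapshot-under-a-lock pattern the comment above describes, with stand-in classes rather than the real tf.data model types; the real implementation uses an explicit node-pair work list instead of recursion:

#include <iostream>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <vector>

class Node {
 public:
  explicit Node(double t) : processing_time_(t) {}

  void RecordTime(double t) {  // called concurrently by the running pipeline
    std::unique_lock<std::shared_mutex> l(mu_);
    processing_time_ += t;
  }

  void AddInput(std::shared_ptr<Node> in) {
    std::unique_lock<std::shared_mutex> l(mu_);
    inputs_.push_back(std::move(in));
  }

  std::shared_ptr<Node> Snapshot() const {  // copy taken under a shared lock
    std::shared_lock<std::shared_mutex> l(mu_);
    auto copy = std::make_shared<Node>(processing_time_);
    for (const auto& in : inputs_) copy->inputs_.push_back(in->Snapshot());
    return copy;
  }

  double processing_time() const {
    std::shared_lock<std::shared_mutex> l(mu_);
    return processing_time_;
  }

 private:
  mutable std::shared_mutex mu_;
  double processing_time_;
  std::vector<std::shared_ptr<Node>> inputs_;
};

int main() {
  auto root = std::make_shared<Node>(1.0);
  root->AddInput(std::make_shared<Node>(2.0));
  std::shared_ptr<Node> snapshot = root->Snapshot();  // immutable view for optimization
  root->RecordTime(5.0);                              // live node keeps changing
  std::cout << snapshot->processing_time() << " vs " << root->processing_time() << "\n";
  return 0;
}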

View File

@ -756,7 +756,7 @@ TEST(SnapshotTest, Model) {
cur_node = cur_node->inputs().front();
}
std::shared_ptr<Node> root_copy = root->Snapshot(nullptr);
std::shared_ptr<Node> root_copy = root->Snapshot();
cur_node = root;
std::shared_ptr<Node> cur_node_copy = root_copy;

View File

@ -293,7 +293,7 @@ class NodeTypeAttrMap {
}
// Note that the mappings generated here include inputs/outputs with fixed
// types. This makes the mappings complete (all inputs and outputs are
// included), and allows the graph rewriter to propagate black paint
// included), and allows the graph rewriter to propagate deny paint
// from/through ops with fixed types.
io2type_entry.first.reserve(input_arg_inds.size());
for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
@ -843,10 +843,10 @@ DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
}
Status ValidateLists(const gtl::FlatSet<string>& allow_list,
const gtl::FlatSet<string>& black_list,
const gtl::FlatSet<string>& gray_list,
const gtl::FlatSet<string>& deny_list,
const gtl::FlatSet<string>& infer_list,
const gtl::FlatSet<string>& clear_list) {
std::vector<gtl::FlatSet<string>> lists{allow_list, black_list, gray_list,
std::vector<gtl::FlatSet<string>> lists{allow_list, deny_list, infer_list,
clear_list};
std::multiset<string> counts;
for (const auto& list : lists) {
@ -967,23 +967,23 @@ class AutoMixedPrecisionImpl {
bool SupportsF16(const NodeTypeId& node_type) const;
const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
bool IsSourceOrSinkOp(const string& op) const;
void FindFloat32TensorListOpClustersAndBlacklistUnsafe(
void FindFloat32TensorListOpClustersAndDenylistUnsafe(
std::vector<absl::flat_hash_set<const NodeDef*>>* clusters,
absl::flat_hash_set<int>* black_set) const;
absl::flat_hash_set<int>* deny_set) const;
void FindTensorListImplicitFloat32Edges(
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
std::vector<NodeTypeIdEdge>* implicit_data_edges) const;
void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
void PropagateBlackFwdThroughClearAndGray(
absl::flat_hash_set<int>* black_set) const;
void PropagateDenyFwdThroughClearAndInfer(
absl::flat_hash_set<int>* deny_set) const;
void ForceColorMatchBetweenTensorListOps(
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
absl::flat_hash_set<int>* allow_set,
absl::flat_hash_set<int>* black_set) const;
void AddClearAndGrayToAllowIfBetweenAllow(
const absl::flat_hash_set<int>& black_set,
absl::flat_hash_set<int>* deny_set) const;
void AddClearAndInferToAllowIfBetweenAllow(
const absl::flat_hash_set<int>& deny_set,
absl::flat_hash_set<int>* allow_set) const;
void PropagateAllowThroughClear(const absl::flat_hash_set<int>& black_set,
void PropagateAllowThroughClear(const absl::flat_hash_set<int>& deny_set,
absl::flat_hash_set<int>* allow_set) const;
Status ForceColorMatchOnRecurrentEdges(
absl::flat_hash_set<int>* allow_set) const;
@ -1006,8 +1006,8 @@ class AutoMixedPrecisionImpl {
bool force_all_fp16_;
AutoMixedPrecisionMode mode_;
gtl::FlatSet<string> f16_allowlist_;
gtl::FlatSet<string> f16_blacklist_;
gtl::FlatSet<string> f16_graylist_;
gtl::FlatSet<string> f16_denylist_;
gtl::FlatSet<string> f16_inferlist_;
gtl::FlatSet<string> f16_clearlist_;
absl::flat_hash_set<const NodeDef*> should_process_nodes_;
DataType target_dtype_; // Either DT_HALF or DT_BFLOAT16
@ -1083,12 +1083,12 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
for (const auto& x : mp_lists->AllowList()) {
f << x << "\n";
}
f << "\nBlackList:\n";
for (const auto& x : mp_lists->BlackList()) {
f << "\nDenyList:\n";
for (const auto& x : mp_lists->DenyList()) {
f << x << "\n";
}
f << "\nGrayList:\n";
for (const auto& x : mp_lists->GrayList()) {
f << "\nInferList:\n";
for (const auto& x : mp_lists->InferList()) {
f << x << "\n";
}
f << "\nClearList:\n";
@ -1255,11 +1255,11 @@ Status AutoMixedPrecisionImpl::Optimize() {
std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
get_mixed_precision_lists();
f16_allowlist_ = mp_lists->AllowList();
f16_blacklist_ = mp_lists->BlackList();
f16_graylist_ = mp_lists->GrayList();
f16_denylist_ = mp_lists->DenyList();
f16_inferlist_ = mp_lists->InferList();
f16_clearlist_ = mp_lists->ClearList();
TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_blacklist_,
f16_graylist_, f16_clearlist_));
TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
f16_inferlist_, f16_clearlist_));
size_t timestamp = Env::Default()->NowMicros() / 1000;
TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
@ -1294,11 +1294,11 @@ Status AutoMixedPrecisionImpl::Optimize() {
TF_RETURN_IF_ERROR(
graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));
absl::flat_hash_set<int> black_set;
absl::flat_hash_set<int> deny_set;
std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
FindFloat32TensorListOpClustersAndBlacklistUnsafe(&tensor_list_clusters,
&black_set);
FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
&deny_set);
std::vector<NodeTypeIdEdge> ephemeral_edges;
for (const auto& cluster : tensor_list_clusters) {
VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
@ -1320,14 +1320,14 @@ Status AutoMixedPrecisionImpl::Optimize() {
// This is done under the assumption that allowlist ops are always
// numerically-safe in f16 and that they are the most important ops for
// improving performance.
// 2) Add nodes to the black_set iff they are numerically-dangerous (aka
// "blacklist" ops) or they are on a forward path from a blacklist node to
// a black/gray node (including the node at the end of the path) through
// non-numerically-dangerous ops (aka "greylist" and "clearlist" ops).
// 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
// "denylist" ops) or they are on a forward path from a denylist node to
// a deny/infer node (including the node at the end of the path) through
// non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
// This is done to prevent numerically-dangerous ops and their downstream
// effects from being changed to f16, which would risk breaking the
// numerical accuracy of the model.
// 3) For all remaining nodes that are not considered dangerous (greylist
// 3) For all remaining nodes that are not considered dangerous (inferlist
// and clearlist ops), find those that are between (i.e., both upstream
// and downstream of) allow nodes, and add them to the allow_set.
// This is done to avoid unnecessary casts between allowlist ops.
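
A standalone sketch of pass 2 described above, run on a tiny adjacency-list graph: node colors are a plain enum standing in for the op lists, and the two phases mirror the backwards search for clear nodes upstream of a deny/infer node followed by the forward deny propagation.

#include <iostream>
#include <vector>

enum class List { kAllow, kInfer, kDeny, kClear };

// Returns which nodes end up painted DENY, given per-node list membership and
// forward adjacency lists.
std::vector<bool> PropagateDenyForward(const std::vector<List>& lists,
                                       const std::vector<std::vector<int>>& outs) {
  const int n = static_cast<int>(lists.size());
  std::vector<std::vector<int>> ins(n);
  for (int u = 0; u < n; ++u)
    for (int v : outs[u]) ins[v].push_back(u);

  // Phase 1: clear nodes that are upstream of a deny/infer node, found by
  // walking backwards from every deny/infer root through clearlist nodes only.
  std::vector<bool> upstream(n, false);
  for (int root = 0; root < n; ++root) {
    if (lists[root] != List::kDeny && lists[root] != List::kInfer) continue;
    std::vector<int> stack = {root};
    upstream[root] = true;
    while (!stack.empty()) {
      int node = stack.back(); stack.pop_back();
      for (int pred : ins[node]) {
        if (!upstream[pred] && lists[pred] == List::kClear) {
          upstream[pred] = true;
          stack.push_back(pred);
        }
      }
    }
  }

  // Phase 2: push deny forward from denylist nodes, but only through nodes
  // marked upstream in phase 1; allow nodes stop the propagation.
  std::vector<bool> deny(n, false);
  for (int root = 0; root < n; ++root) {
    if (deny[root] || lists[root] != List::kDeny) continue;
    std::vector<int> stack = {root};
    deny[root] = true;
    while (!stack.empty()) {
      int node = stack.back(); stack.pop_back();
      for (int succ : outs[node]) {
        if (!deny[succ] && upstream[succ]) {
          deny[succ] = true;
          stack.push_back(succ);
        }
      }
    }
  }
  return deny;
}

int main() {
  // deny -> infer -> clear -> infer -> clear -> allow -> infer
  std::vector<List> lists = {List::kDeny,  List::kInfer, List::kClear, List::kInfer,
                             List::kClear, List::kAllow, List::kInfer};
  std::vector<std::vector<int>> outs = {{1}, {2}, {3}, {4}, {5}, {6}, {}};
  auto deny = PropagateDenyForward(lists, outs);
  for (size_t i = 0; i < lists.size(); ++i)
    std::cout << "node " << i << (deny[i] ? ": DENY\n" : ": unchanged\n");
  // Nodes 0-3 become DENY, nodes 4-6 stay unchanged, matching the example in
  // the documentation of PropagateDenyFwdThroughClearAndInfer.
  return 0;
}
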
@ -1346,29 +1346,29 @@ Status AutoMixedPrecisionImpl::Optimize() {
return Status::OK();
}
VLOG(2) << "Beginning pass 2 to propagate black forwards from blacklist ops "
"through clear/graylist ops";
PropagateBlackFwdThroughClearAndGray(&black_set);
VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
"through clear/inferlist ops";
PropagateDenyFwdThroughClearAndInfer(&deny_set);
VLOG(2) << "Finished pass 2";
VLOG(2) << "Forcing color match between data structure ops";
for (const auto& cluster : tensor_list_clusters) {
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set);
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
}
VLOG(2) << "Beginning pass 3 to set clear and gray nodes to allow if they "
VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
"are between allow ops";
AddClearAndGrayToAllowIfBetweenAllow(black_set, &allow_set);
AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
VLOG(2) << "Finished pass 3";
VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
"clearlist ops";
PropagateAllowThroughClear(black_set, &allow_set);
PropagateAllowThroughClear(deny_set, &allow_set);
VLOG(2) << "Finished pass 4";
VLOG(2) << "Forcing color match between data structure ops";
for (const auto& cluster : tensor_list_clusters) {
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set);
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
}
VLOG(2) << "Forcing color match on loop edges";
@ -1426,11 +1426,11 @@ bool AutoMixedPrecisionImpl::IsSourceOrSinkOp(const string& op) const {
// Finds all clusters of float32 Tensor List nodes that are connected via their
// handle edges. Unsafe clusters (those with unprocessable nodes, or with edges
// that cross untraversable boundaries via _Arg, _Ret, PartitionedCall etc.
// nodes) are added to black_set. The caller should paint all nodes in a cluster
// nodes) are added to deny_set. The caller should paint all nodes in a cluster
// the same color, as they may all refer to the same Tensor List.
void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndBlacklistUnsafe(
void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndDenylistUnsafe(
std::vector<absl::flat_hash_set<const NodeDef*>>* tensor_list_clusters,
absl::flat_hash_set<int>* black_set) const {
absl::flat_hash_set<int>* deny_set) const {
absl::flat_hash_set<const NodeDef*> tensor_list_prop_set;
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
@ -1463,7 +1463,7 @@ void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndBlacklistUnsafe(
cluster.insert(node);
if (!ShouldProcess(*node)) {
// The cluster contains an un-processable node.
black_set->insert(root_fp32_idx);
deny_set->insert(root_fp32_idx);
}
// TODO(benbarsdell): In a theoretical pathological
// case of a Tensor List of Tensor List handles, the
@ -1471,7 +1471,7 @@ void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndBlacklistUnsafe(
// sink.
} else if (IsSourceOrSinkOp(node->op())) {
// The cluster crosses an untraversable boundary.
black_set->insert(root_fp32_idx);
deny_set->insert(root_fp32_idx);
}
}));
tensor_list_clusters->push_back(cluster);
@ -1534,21 +1534,21 @@ void AutoMixedPrecisionImpl::AddAllowlistOps(
}
}
// Adds nodes to black_set iff they are on the blacklist or they are on a
// forward path from a blacklist node to a black/gray node (including the node
// at the end of the path) through clear and gray nodes.
// E.g., black -> gray -> clear -> gray -> clear -> allow -> gray
// becomes: black -> black -> black -> black -> clear -> allow -> gray.
void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
absl::flat_hash_set<int>* black_set) const {
// Adds nodes to deny_set iff they are on the denylist or they are on a
// forward path from a denylist node to a deny/infer node (including the node
// at the end of the path) through clear and infer nodes.
// E.g., deny -> infer -> clear -> infer -> clear -> allow -> infer
// becomes: deny -> deny -> deny -> deny -> clear -> allow -> infer.
void AutoMixedPrecisionImpl::PropagateDenyFwdThroughClearAndInfer(
absl::flat_hash_set<int>* deny_set) const {
if (force_all_fp16_) return;
// Find clear nodes that are upstream of black or gray.
absl::flat_hash_set<int> upstream_of_black_or_gray_set;
// Find clear nodes that are upstream of deny or infer.
absl::flat_hash_set<int> upstream_of_deny_or_infer_set;
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
if (!(f16_blacklist_.count(root.node->op()) ||
f16_graylist_.count(root.node->op()))) {
if (!(f16_denylist_.count(root.node->op()) ||
f16_inferlist_.count(root.node->op()))) {
continue;
}
DfsTypeTraversal(graph_type_view_, {&root},
@ -1556,42 +1556,42 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
DfsTypePredicates::Enter([&](int idx) -> bool {
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
return idx == root_idx ||
(!upstream_of_black_or_gray_set.count(idx) &&
(!upstream_of_deny_or_infer_set.count(idx) &&
f16_clearlist_.count(item.node->op()));
}),
DfsTypeCallbacks::PreOrder([&](int idx) {
upstream_of_black_or_gray_set.insert(idx);
upstream_of_deny_or_infer_set.insert(idx);
}));
}
// Propagate black forward through nodes in upstream_of_black_or_gray_set.
// Propagate deny forward through nodes in upstream_of_deny_or_infer_set.
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
if (black_set->count(root_idx) || !f16_blacklist_.count(root.node->op())) {
if (deny_set->count(root_idx) || !f16_denylist_.count(root.node->op())) {
continue;
}
DfsTypeTraversal(
graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
DfsTypePredicates::Enter([&](int idx) -> bool {
return idx == root_idx || (!black_set->count(idx) &&
upstream_of_black_or_gray_set.count(idx));
return idx == root_idx || (!deny_set->count(idx) &&
upstream_of_deny_or_infer_set.count(idx));
}),
DfsTypeCallbacks::PreOrder([&](int idx) {
bool inserted = black_set->insert(idx).second;
bool inserted = deny_set->insert(idx).second;
if (VLOG_IS_ON(2) && inserted) {
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
VLOG(2) << "Painting type " << item.type_attr.DebugString()
<< " of " << item.node->op() << " node "
<< item.node->name() << " BLACK";
<< item.node->name() << " DENY";
}
}));
}
}
void AutoMixedPrecisionImpl::AddClearAndGrayToAllowIfBetweenAllow(
const absl::flat_hash_set<int>& black_set,
void AutoMixedPrecisionImpl::AddClearAndInferToAllowIfBetweenAllow(
const absl::flat_hash_set<int>& deny_set,
absl::flat_hash_set<int>* allow_set) const {
// Find clear/graylist ops that are downstream of allow ops.
// Find clear/inferlist ops that are downstream of allow ops.
absl::flat_hash_set<int> downstream_of_allow_set;
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
@ -1605,13 +1605,13 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToAllowIfBetweenAllow(
return idx == root_idx ||
(!downstream_of_allow_set.count(idx) &&
!f16_allowlist_.count(item.node->op()) &&
!black_set.count(idx) && ShouldProcess(*item.node) &&
!deny_set.count(idx) && ShouldProcess(*item.node) &&
// TODO(benbarsdell): Consider allowing propagation through
// ops that are already float16 in order to reduce the number
// of casts.
IsFloat32(item) && SupportsF16(item) &&
(f16_clearlist_.count(item.node->op()) ||
f16_graylist_.count(item.node->op())));
f16_inferlist_.count(item.node->op())));
}),
DfsTypeCallbacks::PreOrder(
[&](int idx) { downstream_of_allow_set.insert(idx); }));
@ -1645,7 +1645,7 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToAllowIfBetweenAllow(
}
void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
const absl::flat_hash_set<int>& black_set,
const absl::flat_hash_set<int>& deny_set,
absl::flat_hash_set<int>* allow_set) const {
// Propagate allow from allow nodes through clearlist ops.
absl::flat_hash_set<int> clear_prop_set;
@ -1661,7 +1661,7 @@ void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
DfsTypePredicates::Enter([&](int idx) -> bool {
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
return idx == root_idx ||
(!allow_set->count(idx) && !black_set.count(idx) &&
(!allow_set->count(idx) && !deny_set.count(idx) &&
ShouldProcess(*item.node) && IsFloat32(item) &&
SupportsF16(item) &&
(f16_clearlist_.count(item.node->op())) &&
@ -1727,14 +1727,14 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
if (allow_set->erase(merge_idx)) {
VLOG(2) << "Painting type T of Merge node "
<< graph_type_view_.GetNode(merge_idx)->node->name()
<< " BLACK to match the color of its sibling Merge nodes "
<< " DENY to match the color of its sibling Merge nodes "
"with common NextIteration node "
<< node.name();
}
}
if (allow_set->erase(nextiter_idx)) {
VLOG(2) << "Painting type T of NextIteration node " << node.name()
<< " BLACK to match the color of its output Merge node(s)";
<< " DENY to match the color of its output Merge node(s)";
}
} else {
if (allow_set->insert(nextiter_idx).second) {
@ -1751,8 +1751,8 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
absl::flat_hash_set<int>* allow_set,
absl::flat_hash_set<int>* black_set) const {
bool any_black = false;
absl::flat_hash_set<int>* deny_set) const {
bool any_deny = false;
bool any_allow = false;
std::vector<int> node_type_idxs;
node_type_idxs.reserve(tensor_list_nodes.size());
@ -1766,24 +1766,24 @@ void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
node_type_idxs.push_back(maybe_node_type_idx.value());
}
for (int node_type_idx : node_type_idxs) {
if (black_set->count(node_type_idx)) {
any_black = true;
if (deny_set->count(node_type_idx)) {
any_deny = true;
break;
} else if (allow_set->count(node_type_idx)) {
any_allow = true;
}
}
if (!any_black && !any_allow) return;
if (!any_deny && !any_allow) return;
for (int node_type_idx : node_type_idxs) {
const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
<< node_type.node->op() << " node " << node_type.node->name() << " "
<< (any_black ? "BLACK" : "ALLOW")
<< (any_deny ? "DENY" : "ALLOW")
<< " because at least one of its siblings is "
<< (any_black ? "BLACK" : "ALLOW");
if (any_black) {
<< (any_deny ? "DENY" : "ALLOW");
if (any_deny) {
allow_set->erase(node_type_idx);
black_set->insert(node_type_idx);
deny_set->insert(node_type_idx);
} else {
allow_set->insert(node_type_idx);
}

View File

@ -23,7 +23,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
// Represents the four lists of ops: the allow list, gray list, black list, and
// Represents the four lists of ops: the allow list, infer list, deny list, and
// clear list. These lists determine which ops are converted to fp16/bf16
// (referred to as 'f16' for short) and which ops stay as fp32.
class AutoMixedPrecisionLists {
@ -36,13 +36,13 @@ class AutoMixedPrecisionLists {
virtual gtl::FlatSet<string> AllowList() = 0;
// Returns the set of ops that can run in f16 and are considered numerically-
// safe (for execution in f16), but which may be made unsafe by an upstream
// blacklist op.
virtual gtl::FlatSet<string> GrayList() = 0;
// denylist op.
virtual gtl::FlatSet<string> InferList() = 0;
// Returns the set of ops that are considered numerically-dangerous (i.e.,
// unsafe for execution in f16) and whose effects may also be observed in
// downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to
// the Exp).
virtual gtl::FlatSet<string> BlackList() = 0;
virtual gtl::FlatSet<string> DenyList() = 0;
// Returns the set of ops that do not have numerically-significant effects
// (i.e., they are always considered safe for execution in f16 precision), and
// can run in f16.
@ -51,10 +51,11 @@ class AutoMixedPrecisionLists {
protected:
// Adds or removes ops from the list if certain environment variables are set.
static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
CHECK(list_name == "ALLOWLIST" || list_name == "GRAYLIST" || // Crash OK.
list_name == "BLACKLIST" || list_name == "CLEARLIST" ||
CHECK(list_name == "ALLOWLIST" || list_name == "INFERLIST" || // Crash OK.
list_name == "DENYLIST" || list_name == "CLEARLIST" ||
// TODO(reedwm): for bkwds compat; remove when no longer necessary:
list_name == "WHITELIST");
list_name == "WHITELIST" || list_name == "GRAYLIST" ||
list_name == "BLACKLIST");
string add_env_var =
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
string remove_env_var =
@ -154,7 +155,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
return list;
}
gtl::FlatSet<string> GrayList() override {
gtl::FlatSet<string> InferList() override {
if (IsPseudoFastMath()) {
return gtl::FlatSet<string>{};
}
@ -204,11 +205,14 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
"Tanh",
"TanhGrad",
};
UpdateList("INFERLIST", &list);
// For backwards compatibility, keeping the original env variable here.
// TODO(reedwm): This should be removed if we don't have active users.
UpdateList("GRAYLIST", &list);
return list;
}
gtl::FlatSet<string> BlackList() override {
gtl::FlatSet<string> DenyList() override {
if (IsPseudoFastMath()) {
return gtl::FlatSet<string>{};
}
@ -224,6 +228,9 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
"SparseSoftmaxCrossEntropyWithLogits",
"Sum",
};
UpdateList("DENYLIST", &list);
// For backwards compatibility, keeping the original env variable here.
// TODO(reedwm): This should be removed if we don't have active users.
UpdateList("BLACKLIST", &list);
return list;
}
@ -344,7 +351,7 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
AutoMixedPrecisionListsMkl() {}
// Only ops which are supported by MKL in bfloat16 should be added to the
// allow list, gray list, or clear list.
// allow list, infer list, or clear list.
gtl::FlatSet<string> AllowList() override {
auto list = gtl::FlatSet<string>{"Conv2D",
"Conv2DBackpropFilter",
@ -360,10 +367,13 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
"BatchMatMulV2"};
UpdateList("ALLOWLIST", &list);
// For backwards compatibility, keeping the original env variable here.
// TODO(reedwm): This should be removed if we don't have active users.
UpdateList("WHITELIST", &list);
return list;
}
gtl::FlatSet<string> GrayList() override {
gtl::FlatSet<string> InferList() override {
auto list = gtl::FlatSet<string>{
"Add",
"AddN",
@ -384,11 +394,14 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
"Mul",
"Sub",
};
UpdateList("INFERLIST", &list);
// For backwards compatibility, keeping the original env variable here.
// TODO(reedwm): This should be removed if we don't have active users.
UpdateList("GRAYLIST", &list);
return list;
}
gtl::FlatSet<string> BlackList() override {
gtl::FlatSet<string> DenyList() override {
auto list = gtl::FlatSet<string>{
"Exp",
"Expm1",
@ -401,6 +414,9 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
"SparseSoftmaxCrossEntropyWithLogits",
"Sum",
};
UpdateList("DENYLIST", &list);
// For backwards compatibility, keeping the original env variable here.
// TODO(reedwm): This should be removed if we don't have active users.
UpdateList("BLACKLIST", &list);
return list;
}
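
The UpdateList() hook above lets users adjust these lists through TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_<LIST>_ADD / _REMOVE environment variables. A rough sketch of how such an override can work; the comma-separated format and the _REMOVE suffix are assumptions here, since this diff only shows the _ADD variable name being built:

#include <cstdlib>
#include <iostream>
#include <set>
#include <sstream>
#include <string>

std::set<std::string> SplitCsv(const char* value) {
  std::set<std::string> out;
  if (value == nullptr) return out;
  std::stringstream ss(value);
  std::string item;
  while (std::getline(ss, item, ',')) {
    if (!item.empty()) out.insert(item);
  }
  return out;
}

void UpdateList(const std::string& list_name, std::set<std::string>* list) {
  const std::string prefix = "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_";
  // Ops named in <prefix><LIST>_ADD are inserted, ops named in
  // <prefix><LIST>_REMOVE are erased.
  for (const std::string& op :
       SplitCsv(std::getenv((prefix + list_name + "_ADD").c_str())))
    list->insert(op);
  for (const std::string& op :
       SplitCsv(std::getenv((prefix + list_name + "_REMOVE").c_str())))
    list->erase(op);
}

int main() {
  std::set<std::string> infer_list = {"Add", "Mul", "Tanh"};
  // e.g. export TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_INFERLIST_REMOVE=Tanh
  UpdateList("INFERLIST", &infer_list);
  for (const auto& op : infer_list) std::cout << op << "\n";
  return 0;
}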

View File

@ -160,7 +160,7 @@ class AutoMixedPrecisionTest : public GrapplerTest {
return AddNode(name, op, inputs, attributes, graph);
}
void TestSimpleUnaryGrayOp(
void TestSimpleUnaryInferOp(
double input_min, double input_max, double atol, double rtol,
const std::function<Output(const tensorflow::Scope&, Output)>&
test_op_factory) {
@ -170,8 +170,8 @@ class AutoMixedPrecisionTest : public GrapplerTest {
GenerateIdentityMatrix<DT_FLOAT>(size, size));
Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
Output gry1 = test_op_factory(s.WithOpName("gry1"), allow1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, eye);
Output infer1 = test_op_factory(s.WithOpName("infer1"), allow1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, eye);
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
GrapplerItem item;
item.fetch = {"fetch1"};
@ -191,7 +191,7 @@ class AutoMixedPrecisionTest : public GrapplerTest {
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
DT_FLOAT);
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
auto tensors = EvaluateNodes(output, item.fetch, feed);
@ -209,10 +209,10 @@ class AutoMixedPrecisionTest : public GrapplerTest {
TEST_F(AutoMixedPrecisionTest, NoOp) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
GrapplerItem item;
@ -230,9 +230,9 @@ TEST_F(AutoMixedPrecisionTest, NoOp) {
GraphView output_view(&output);
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
auto tensors = EvaluateNodes(output, item.fetch);
@ -284,16 +284,16 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
TEST_F(AutoMixedPrecisionTest, Simple) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
Output gry2 = ops::Log(s.WithOpName("gry2"), clr3);
Output clr4 = ops::Relu(s.WithOpName("clr4"), gry2);
Output blk2 = ops::SparseMatMul(s.WithOpName("blk2"), clr4, clr4);
Output clr5 = ops::Relu(s.WithOpName("clr5"), blk2);
Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
GrapplerItem item;
@ -310,16 +310,16 @@ TEST_F(AutoMixedPrecisionTest, Simple) {
GraphView output_view(&output);
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("gry2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("Ta").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("Tb").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
auto tensors = EvaluateNodes(output, item.fetch);
@ -374,13 +374,13 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
Output blk1 = ops::Exp(s.WithOpName("blk1"), gry1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), blk1);
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
Output deny1 = ops::Exp(s.WithOpName("deny1"), infer1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), deny1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
Output blk2 = ops::Exp(s.WithOpName("blk2"), clr3);
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
Output deny2 = ops::Exp(s.WithOpName("deny2"), clr3);
Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
GrapplerItem item;
item.fetch = {"allow1", "clr2", "clr3"};
@ -398,12 +398,12 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
auto tensors = EvaluateNodes(output, item.fetch);
@ -419,11 +419,11 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1);
Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
Output allow2 =
ops::MatMul(s.WithOpName("allow2").WithDevice(
"/job:localhost/replica:0/task:0/device:CPU:0"),
gry1, gry1);
infer1, infer1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
@ -443,7 +443,7 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
@ -521,9 +521,9 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
.x_backprop;
Output gry1 = ops::Add(s.WithOpName("gry1"), fbn1, bng1);
Output infer1 = ops::Add(s.WithOpName("infer1"), fbn1, bng1);
Output allow2 =
ops::Conv2D(s.WithOpName("allow2"), gry1, weight, {1, 1, 1, 1}, "SAME",
ops::Conv2D(s.WithOpName("allow2"), infer1, weight, {1, 1, 1, 1}, "SAME",
ops::Conv2D::DataFormat("NHWC"));
Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
@ -547,7 +547,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
EXPECT_EQ(output_view.GetNode("bng1")->op(), "FusedBatchNormGradV2");
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
auto tensors = EvaluateNodes(output, item.fetch);
@ -563,10 +563,10 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
Output gry1 =
ops::AddN(s.WithOpName("gry1"),
Output infer1 =
ops::AddN(s.WithOpName("infer1"),
{clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
GrapplerItem item;
@ -587,7 +587,7 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
EXPECT_EQ(type, DT_HALF);
}
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
auto tensors = EvaluateNodes(output, item.fetch);
@ -633,17 +633,17 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) {
TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
Output ent1 =
ops::internal::Enter(s.WithOpName("ent1"), blk1, "loop1").output;
ops::internal::Enter(s.WithOpName("ent1"), deny1, "loop1").output;
// Note that the second input is later replaced with "nxt1".
Output mrg1 = ops::Merge(s.WithOpName("mrg1"), {ent1, ent1}).output;
// For simplicity, the loop condition is constant false.
Output con1 = ops::Const(s.WithOpName("con1"), false, {});
Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), swt1.output_true);
Output allow1 = ops::MatMul(s.WithOpName("allow1"), gry1, gry1);
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), swt1.output_true);
Output allow1 = ops::MatMul(s.WithOpName("allow1"), infer1, infer1);
Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
@ -671,14 +671,14 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
GraphView output_view(&output);
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
// Note that mrg1 gets painted black because it is between blk1 and gry1. This
// forces nxt1 and mrg2 to be painted black as well (they would otherwise be
// painted allow because they are clear and have a direct path to allow1).
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
// Note that mrg1 gets painted deny because it is between deny1 and infer1.
// This forces nxt1 and mrg2 to be painted deny as well (they would otherwise
// be painted allow because they are clear and have a direct path to allow1).
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
@ -711,8 +711,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
shape, DT_FLOAT)
.item;
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
auto tl1w3 =
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
Output tl1r2 =
@ -748,7 +748,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
@ -776,8 +776,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
tl1w2.output_handle, shape, DT_FLOAT)
.tensor;
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
auto tl1w3 =
ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
@ -811,7 +811,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
@ -835,8 +835,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
shape, DT_FLOAT)
.tensor;
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
// This tests that an allow-painted object node (tl2) will force an unpainted
@ -863,7 +863,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
@ -902,8 +902,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
Output tl3r1 =
ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
.tensor;
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl3r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl3r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
GrapplerItem item;
@ -922,7 +922,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
const char* type_key = "element_dtype";
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl3")->attr().at(type_key).type(), DT_HALF);
@ -967,22 +967,25 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
tensorflow::Input shape = {32, 32};
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1);
Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
auto tl1w1 = ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, gry1);
auto _gry1 = tensorflow::ops::AsNodeOut(s, gry1);
auto tl1w1 =
ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, infer1);
auto _infer1 = tensorflow::ops::AsNodeOut(s, infer1);
auto _tl1w1_handle = tensorflow::ops::AsNodeOut(s, tl1w1.output_handle);
auto builder =
tensorflow::NodeBuilder("Func1", "Func1", s.graph()->op_registry());
tensorflow::Node* func1_op;
TF_CHECK_OK(
builder.Input(_tl1w1_handle).Input(_gry1).Finalize(s.graph(), &func1_op));
TF_CHECK_OK(builder.Input(_tl1w1_handle)
.Input(_infer1)
.Finalize(s.graph(), &func1_op));
Output func1_handle(func1_op, 0);
Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), func1_handle,
shape, DT_FLOAT)
.tensor;
auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
auto tl2w1 = ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, gry1);
auto tl2w1 =
ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, infer1);
Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
tl2w1.output_handle, shape, DT_FLOAT)
.tensor;
@ -1004,7 +1007,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
const char* type_key = "element_dtype";
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_HALF);
@ -1069,7 +1072,7 @@ TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
}
TEST_F(AutoMixedPrecisionTest, EluOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
-5, 5, 1.0e-3, 1.0e-3,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Elu(scope, input);
@ -1077,7 +1080,7 @@ TEST_F(AutoMixedPrecisionTest, EluOp) {
}
TEST_F(AutoMixedPrecisionTest, ErfOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
-5, 5, 1.0e-3, -1,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Erf(scope, input);
@ -1085,7 +1088,7 @@ TEST_F(AutoMixedPrecisionTest, ErfOp) {
}
TEST_F(AutoMixedPrecisionTest, ErfcOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
-5, 5, 1.0e-3, -1,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Erfc(scope, input);
@ -1093,7 +1096,7 @@ TEST_F(AutoMixedPrecisionTest, ErfcOp) {
}
TEST_F(AutoMixedPrecisionTest, InvOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
0.01, 10, -1, 1.0e-3,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Inv(scope, input);
@ -1101,7 +1104,7 @@ TEST_F(AutoMixedPrecisionTest, InvOp) {
}
TEST_F(AutoMixedPrecisionTest, LogOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
0.01, 10, 1.0e-3, 2.0e-3,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Log(scope, input);
@ -1109,7 +1112,7 @@ TEST_F(AutoMixedPrecisionTest, LogOp) {
}
TEST_F(AutoMixedPrecisionTest, Log1pOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
-0.99, 9, 1.0e-3, 5.0e-3,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Log1p(scope, input);
@ -1117,7 +1120,7 @@ TEST_F(AutoMixedPrecisionTest, Log1pOp) {
}
TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
-8, 8, -1, 1.0e-2,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::LogSoftmax(scope, input);
@ -1125,7 +1128,7 @@ TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
}
TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
0.01, 10, -1, 1.0e-3,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Reciprocal(scope, input);
@ -1133,7 +1136,7 @@ TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
}
TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
-5, 5, 1.0e-3, -1,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Sigmoid(scope, input);
@ -1141,7 +1144,7 @@ TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
}
TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
-8, 8, 2.0e-3, -1,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Softmax(scope, input);
@ -1149,7 +1152,7 @@ TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
}
TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
-5, 5, 1.0e-3, 1.0e-3,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Softplus(scope, input);
@ -1157,7 +1160,7 @@ TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
}
TEST_F(AutoMixedPrecisionTest, SqrtOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
0, 10, 1.0e-3, 1.0e-3,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Sqrt(scope, input);
@ -1165,7 +1168,7 @@ TEST_F(AutoMixedPrecisionTest, SqrtOp) {
}
TEST_F(AutoMixedPrecisionTest, TanhOp) {
TestSimpleUnaryGrayOp(
TestSimpleUnaryInferOp(
-5, 5, 1.0e-3, -1,
[](const tensorflow::Scope& scope, Output input) -> Output {
return ops::Tanh(scope, input);
@ -1229,16 +1232,16 @@ TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
TEST_F(AutoMixedPrecisionMklTest, Simple) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
Output blk2 = ops::Log(s.WithOpName("blk2"), clr3);
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4);
Output clr5 = ops::Relu(s.WithOpName("clr5"), blk3);
Output deny2 = ops::Log(s.WithOpName("deny2"), clr3);
Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
Output deny3 = ops::SparseMatMul(s.WithOpName("deny3"), clr4, clr4);
Output clr5 = ops::Relu(s.WithOpName("clr5"), deny3);
Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
GrapplerItem item;
@ -1255,16 +1258,16 @@ TEST_F(AutoMixedPrecisionMklTest, Simple) {
GraphView output_view(&output);
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Ta").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Tb").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Ta").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Tb").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
auto tensors = EvaluateNodes(output, item.fetch);
@ -1294,8 +1297,8 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
shape, DT_FLOAT)
.item;
Output gry1 = ops::Mul(s.WithOpName("gry1"), tl1r1, tl1r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
Output infer1 = ops::Mul(s.WithOpName("infer1"), tl1r1, tl1r1);
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
auto tl1w3 =
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
Output tl1r2 =
@ -1335,7 +1338,7 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
DT_BFLOAT16);
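For readers following the rename of the former "gray" and "black" lists to "infer" and "deny" in these tests, a minimal, standalone C++ sketch of the allow/infer/clear/deny classification follows. The enum and the op lists are illustrative assumptions drawn from the test graphs above, not the optimizer's actual tables.

// Standalone sketch (not TensorFlow code). Op membership below is inferred
// from the tests above (MatMul -> allow, Sqrt/Mul/Tanh -> infer, Relu -> clear,
// Exp/Log/SparseMatMul -> deny in the MKL variant) and is illustrative only.
#include <iostream>
#include <set>
#include <string>

enum class MixedPrecisionCategory { kAllow, kInfer, kClear, kDeny };

MixedPrecisionCategory Classify(const std::string& op) {
  // "Allow": compute-heavy ops expected to run in reduced precision.
  static const std::set<std::string> allow = {"MatMul", "Conv2D"};
  // "Infer": ops whose precision follows that of their inputs.
  static const std::set<std::string> infer = {"Sqrt", "Mul", "Tanh"};
  // "Clear": ops that are safe in either precision and simply propagate it.
  static const std::set<std::string> clear = {"Relu", "Identity"};
  if (allow.count(op)) return MixedPrecisionCategory::kAllow;
  if (infer.count(op)) return MixedPrecisionCategory::kInfer;
  if (clear.count(op)) return MixedPrecisionCategory::kClear;
  // Everything else is treated as "deny" and kept in float32.
  return MixedPrecisionCategory::kDeny;
}

int main() {
  for (const std::string op : {"MatMul", "Sqrt", "Relu", "Exp"}) {
    std::cout << op << " -> " << static_cast<int>(Classify(op)) << "\n";
  }
  return 0;
}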

View File

@ -36,8 +36,8 @@ namespace internal {
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
// All the nodes that should be blacklisted and not swapped.
bool IsBlacklisted(const NodeDef& node) {
// All the nodes that should be denylisted and not swapped.
bool IsDenylisted(const NodeDef& node) {
return
// Collective ops should not be swapped.
IsCollective(node) ||
@ -94,8 +94,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
bool* is_candidate) {
*is_candidate = false;
// Make sure we are not a blacklisted op.
if (IsBlacklisted(node)) {
// Make sure we are not a denylisted op.
if (IsDenylisted(node)) {
return Status::OK();
}
@ -215,7 +215,7 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
// Checks if a node is a candidate to pin to Host.
// The rough algorithm is as follows:
// 1] Check if node is blacklisted.
// 1] Check if node is denylisted.
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
// ints, and pinned to Host).
@ -230,7 +230,7 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
}
// Skip these node types.
if (IsBlacklisted(node)) {
if (IsDenylisted(node)) {
return Status::OK();
}
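As a rough illustration of the three-step candidate check described in the comment above (denylist check, host-kernel check, host-friendly input/output check), a standalone sketch follows. The Node struct and the helper predicates are simplified placeholders, not Grappler's actual types or APIs.

// Standalone sketch of the control flow only; the real optimizer consults the
// kernel registry and inspects inferred shapes/dtypes rather than these stubs.
#include <string>
#include <vector>

struct Node {
  std::string op;
  std::vector<int> output_sizes;  // Number of elements per output tensor.
};

// Mirrors the small-tensor threshold shown in the diff above.
constexpr int kTensorMaxSize = 64;

bool IsDenylisted(const Node& node) {
  // Placeholder: e.g. collective ops should never be swapped.
  return node.op == "CollectiveReduce";
}

bool HasHostKernel(const Node& node) {
  // Placeholder for a kernel-registry lookup.
  return node.op != "Conv2D";
}

bool OutputsAreHostFriendly(const Node& node) {
  for (int size : node.output_sizes) {
    if (size > kTensorMaxSize) return false;  // Only small tensors qualify.
  }
  return true;
}

// 1] denylist, 2] host kernel, 3] host-friendly outputs (inputs omitted here).
bool IsHostCandidate(const Node& node) {
  if (IsDenylisted(node)) return false;
  if (!HasHostKernel(node)) return false;
  return OutputsAreHostFriendly(node);
}

int main() {
  Node shape_node{"Shape", {4}};
  return IsHostCandidate(shape_node) ? 0 : 1;
}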

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_
#include <stddef.h>
#include <cstddef>
#include <functional>
#include <memory>
@ -176,6 +177,61 @@ class BasicBatchScheduler : public BatchScheduler<TaskType> {
// parameter.
int max_enqueued_batches = 10;
// If true, an input task (i.e., input of `BasicBatchScheduler::Schedule`)
// with a large size (i.e., larger than the largest value of
// `allowed_batch_sizes`) will be split into multiple smaller batch tasks
// and possibly put into different batches for processing. If false, each
// input task is put into one batch as a whole for processing.
//
// API note:
// The value of this option doesn't affect processing output given the same
// input; it affects implementation details as stated below:
// 1. It improves batching efficiency by eliminating unnecessary padding: when
// an open batch has M free slots and an input of size N (N > M) is scheduled,
// the input can be split to fill the remaining slots of the open batch
// instead of padding the batch.
// 2. `max_batch_size` limits the size of an input task, while
// `max_execution_batch_size` limits the size of each task that is processed.
// For example, an API user can submit an input of size 128 when
// `max_execution_batch_size` is 32; the implementation can split the input
// into 4 x 32, schedule concurrent processing, and return the concatenated
// results for all 128 elements.
bool enable_large_batch_splitting = false;
// `split_input_task_func` specifies how to split `input_task` into
// `output_tasks`.
//
// `input_task`: the task to be split.
// `first_output_task_size`: the size of the first output task.
// `input_batch_size_limit`: the maximum size of each batch.
// `output_tasks`: the list of output tasks after splitting.
//
// REQUIRED:
// 1) All `output_tasks` should be non-empty tasks.
// 2) Sizes of `output_tasks` add up to size of `input_task`.
//
// NOTE:
// Instantiations of `TaskType` may vary, so it's up to the caller to define
// how to split input tasks (e.g., which members to access).
std::function<Status(std::unique_ptr<TaskType>* input_task,
int first_output_task_size, int input_batch_size_limit,
std::vector<std::unique_ptr<TaskType>>* output_tasks)>
split_input_task_func;
// The maximum size of each enqueued batch (i.e., in `batches_`).
//
// The scheduler may form batches of any size between 1 and this number
// (inclusive). If there is a need to quantize the batch sizes, i.e. only
// submit batches whose size is in a small set of allowed sizes, that can be
// done by adding padding in the process-batch callback.
//
// REQUIRES:
// - If enable_large_batch_splitting is true, `max_execution_batch_size` is
// less than or equal to `max_batch_size`.
// - If enable_large_batch_splitting is false, `max_execution_batch_size` is
// equal to `max_batch_size`.
int max_execution_batch_size = 10;
// The following options are typically only overridden by test code.
// The environment to use.
@ -231,6 +287,12 @@ Status BasicBatchScheduler<TaskType>::Create(
options.batch_timeout_micros;
shared_scheduler_queue_options.max_enqueued_batches =
options.max_enqueued_batches;
shared_scheduler_queue_options.enable_large_batch_splitting =
options.enable_large_batch_splitting;
shared_scheduler_queue_options.split_input_task_func =
options.split_input_task_func;
shared_scheduler_queue_options.max_execution_batch_size =
options.max_execution_batch_size;
std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue;
TF_RETURN_IF_ERROR(shared_scheduler->AddQueue(shared_scheduler_queue_options,
process_batch_callback,
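To make the splitting contract above concrete, a standalone sketch of a split function that satisfies the stated requirements follows: every output task is non-empty and the output sizes sum to the input size, with the first chunk capped by the free slots of the open batch and later chunks by the execution batch size limit. SimpleTask, the bool return value, and the sizes used in main are illustrative stand-ins, not the scheduler's real task type or Status-returning signature.

// Standalone sketch; a real split_input_task_func would use the scheduler's
// TaskType and return a Status instead of a bool.
#include <algorithm>
#include <memory>
#include <vector>

struct SimpleTask {
  explicit SimpleTask(int n) : size(n) {}
  int size;
};

bool SplitInputTask(std::unique_ptr<SimpleTask>* input_task,
                    int first_output_task_size, int input_batch_size_limit,
                    std::vector<std::unique_ptr<SimpleTask>>* output_tasks) {
  int remaining = (*input_task)->size;
  if (remaining <= 0 || first_output_task_size <= 0 ||
      input_batch_size_limit <= 0) {
    return false;  // Invalid arguments.
  }
  // First chunk fills the remaining slots of the currently open batch.
  int chunk = std::min(remaining, first_output_task_size);
  output_tasks->push_back(std::make_unique<SimpleTask>(chunk));
  remaining -= chunk;
  // Subsequent chunks are bounded by the execution batch size limit.
  while (remaining > 0) {
    chunk = std::min(remaining, input_batch_size_limit);
    output_tasks->push_back(std::make_unique<SimpleTask>(chunk));
    remaining -= chunk;
  }
  return true;  // Output sizes now add up to the original input size.
}

int main() {
  auto input = std::make_unique<SimpleTask>(128);
  std::vector<std::unique_ptr<SimpleTask>> outputs;
  // An open batch with 16 free slots and an execution limit of 32 splits a
  // 128-element input into chunks of 16 + 32 + 32 + 32 + 16.
  SplitInputTask(&input, /*first_output_task_size=*/16,
                 /*input_batch_size_limit=*/32, &outputs);
  return outputs.size() == 5 ? 0 : 1;
}

Adapted to the real TaskType and a Status return, a function of this shape could be supplied as `split_input_task_func` when `enable_large_batch_splitting` is true.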

Some files were not shown because too many files have changed in this diff.