Merge branch 'master' of https://github.com/tensorflow/tensorflow into TF_GetName
commit 7a113f40d9

.bazelrc (6 lines changed)
@ -84,7 +84,8 @@
# release_gpu_common: Common options for GPU builds on Linux and Windows.
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux PU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.

# Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
@ -570,3 +571,6 @@ build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain

build:release_cpu_windows --config=release_common
build:release_cpu_windows --announce_rc

RELEASE.md (10 lines changed)
@ -38,6 +38,7 @@
  * Calling ops with a python constants or numpy values is now consistent with
    tf.convert_to_tensor behavior. This avoids operations like tf.reshape
    truncating inputs such as from int64 to int32.
  * Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensor` arguments.
  * `tf.data`:
    * Added optional `exclude_cols` parameter to CsvDataset. This parameter is
      the complement of `select_cols`; at most one of these should be specified.
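A minimal usage sketch for the `tf.sparse.map_values` entry above (assuming the TF 2.x eager API; the tensor contents are made up for illustration):

```python
import tensorflow as tf

# A small sparse tensor with values [1.0, 2.0] at positions (0, 0) and (1, 2).
st = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                            values=[1.0, 2.0],
                            dense_shape=[2, 3])

# map_values applies the function to the `.values` of the SparseTensor while
# leaving the indices and dense_shape untouched.
squared = tf.sparse.map_values(tf.square, st)
print(squared.values)  # [1.0, 4.0]
```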
@ -53,7 +54,8 @@
  * `tf.function`/AutoGraph:
    * <ADD RELEASE NOTES HERE>
  * `tf.lite`:
    * <ADD RELEASE NOTES HERE>
    * Better support for ops with high-dimensional broadcasting inputs by adding
      `BroadcastTo` ops when necessary.
  * `tf.random`:
    * <ADD RELEASE NOTES HERE>
  * Math and Linear Algebra:
@ -65,9 +67,9 @@
  * Tracing and Debugging:
    * <ADD RELEASE NOTES HERE>
  * Other:
    * We have replaced uses of "whitelist" with "allowlist" where possible.
      Please see https://developers.google.com/style/word-list#blacklist for more
      context.
    * We have replaced uses of "whitelist" and "blacklist" with "allowlist"
      and "denylist" where possible. Please see
      https://developers.google.com/style/word-list#blacklist for more context.
    * <ADD RELEASE NOTES HERE>

## Thanks to our Contributors
@ -532,16 +532,14 @@ selects.config_setting_group(
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
# To pass open source testing in the pip Kokoros.
|
||||
"//bazel_pip/tensorflow/...",
|
||||
"//learning/brain/swift/x10/...",
|
||||
"//perftools/accelerators/xprof/api/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
"//third_party/swift/tensorflow_apis/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/python/estimator/...",
|
||||
"//tensorflow_models/official/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
"//third_party/swift/tensorflow_apis/...",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -240,6 +240,8 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/experimental/gradients:math_grad",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
@ -308,6 +310,8 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/util:abstract_stack_trace",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
|
||||
using tensorflow::dyn_cast;
|
||||
using tensorflow::string;
|
||||
using tensorflow::gtl::ArraySlice;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tracing {
|
||||
@ -138,20 +139,23 @@ class GraphOperation : public TracingOperation {
|
||||
|
||||
Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrString has not been implemented yet.");
|
||||
tensorflow::StringPiece s(data, length);
|
||||
op_->node_builder.Attr(attr_name, s);
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrInt(const char* attr_name, int64_t value) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrInt has not been implemented yet.");
|
||||
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
||||
"64-bit int types should match in size");
|
||||
op_->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrFloat(const char* attr_name, float value) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrFloat has not been implemented yet.");
|
||||
op_->node_builder.Attr(attr_name, value);
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrBool(const char* attr_name, bool value) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrBool has not been implemented yet.");
|
||||
op_->node_builder.Attr(attr_name, value);
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrType(const char* const attr_name, DataType value) override {
|
||||
if (!op_) {
|
||||
@ -164,8 +168,15 @@ class GraphOperation : public TracingOperation {
|
||||
}
|
||||
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrShape has not been implemented yet.");
|
||||
PartialTensorShape shape;
|
||||
if (num_dims >= 0) {
|
||||
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
||||
"64-bit int types should match in size");
|
||||
shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
|
||||
reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
|
||||
}
|
||||
op_->node_builder.Attr(attr_name, shape);
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperation* value) override {
|
||||
@ -174,8 +185,10 @@ class GraphOperation : public TracingOperation {
|
||||
}
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrFunctionName has not been implemented yet.");
|
||||
tensorflow::NameAttrList func_name;
|
||||
func_name.set_name(string(value, value + length));
|
||||
op_->node_builder.Attr(attr_name, func_name);
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrTensor(const char* attr_name,
|
||||
AbstractTensorInterface* tensor) override {
|
||||
@ -184,33 +197,71 @@ class GraphOperation : public TracingOperation {
|
||||
}
|
||||
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||
const size_t* lengths, int num_values) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrStringList has not been implemented yet.");
|
||||
if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
|
||||
op_->colocation_constraints.clear();
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
|
||||
lengths[i]);
|
||||
}
|
||||
} else {
|
||||
std::vector<tensorflow::StringPiece> v;
|
||||
v.reserve(num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
|
||||
}
|
||||
op_->node_builder.Attr(attr_name, v);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
int num_values) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrFloatList has not been implemented yet.");
|
||||
op_->node_builder.Attr(attr_name,
|
||||
ArraySlice<const float>(values, num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrIntList has not been implemented yet.");
|
||||
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
||||
"64-bit int types should match in size");
|
||||
op_->node_builder.Attr(
|
||||
attr_name,
|
||||
ArraySlice<const tensorflow::int64>(
|
||||
reinterpret_cast<const tensorflow::int64*>(values), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrTypeList(const char* attr_name, const DataType* values,
|
||||
int num_values) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrTypeList has not been implemented yet.");
|
||||
op_->node_builder.Attr(attr_name,
|
||||
ArraySlice<const DataType>(values, num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||
int num_values) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrBoolList has not been implemented yet.");
|
||||
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
b[i] = values[i];
|
||||
}
|
||||
op_->node_builder.Attr(attr_name,
|
||||
ArraySlice<const bool>(b.get(), num_values));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) override {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"SetAttrShapeList has not been implemented yet.");
|
||||
std::vector<PartialTensorShape> shapes;
|
||||
shapes.reserve(num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
if (num_dims[i] < 0) {
|
||||
shapes.emplace_back();
|
||||
} else {
|
||||
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
||||
"64-bit int types should match in size");
|
||||
shapes.emplace_back(ArraySlice<tensorflow::int64>(
|
||||
reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
|
||||
}
|
||||
}
|
||||
op_->node_builder.Attr(attr_name, shapes);
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
|
@ -23,6 +23,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
@ -42,55 +44,10 @@ class CppGradients
|
||||
}
|
||||
};
|
||||
|
||||
// Creates an Identity op.
|
||||
Status Identity(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr identity_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(
|
||||
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
|
||||
if (isa<tracing::TracingOperation>(identity_op.get())) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
|
||||
->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// =================== Register gradients for Add ============================
|
||||
class AddGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
|
||||
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
grad_outputs->resize(2);
|
||||
std::vector<AbstractTensorHandle*> identity_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
|
||||
absl::MakeSpan(identity_outputs), "Id0"));
|
||||
(*grad_outputs)[0] = identity_outputs[0];
|
||||
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
|
||||
absl::MakeSpan(identity_outputs), "Id1"));
|
||||
(*grad_outputs)[1] = identity_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~AddGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AbstractContext* ctx_;
|
||||
};
|
||||
|
||||
GradientFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
return new AddGradientFunction(op.ctx);
|
||||
}
|
||||
|
||||
Status RegisterGradients(GradientRegistry* registry) {
|
||||
return registry->Register("Add", AddRegisterer);
|
||||
}
|
||||
|
||||
// =================== End gradient registrations ============================
|
||||
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
@ -26,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/util/abstract_stack_trace.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
@ -44,6 +46,12 @@ class ImmediateExecutionOperation : public AbstractOperation {
|
||||
// Experimental
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
|
||||
// Set stack trace to be used for potential async error reporting.
|
||||
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;
|
||||
|
||||
// Returns the stack trace set by `SetStackTrace` if exists.
|
||||
virtual absl::optional<AbstractStackTrace> GetStackTrace() = 0;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractOperation* ptr) {
|
||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||
|
@ -25,7 +25,9 @@ cc_library(
|
||||
"//tensorflow:windows": get_win_copts(),
|
||||
}),
|
||||
deps = [
|
||||
":expiring_lru_cache",
|
||||
":gcs_helper",
|
||||
":ram_file_block_cache",
|
||||
"//tensorflow/c:env",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
|
@ -28,6 +28,27 @@ limitations under the License.
|
||||
// This filesystem will support `gs://` URI schemes.
|
||||
namespace gcs = google::cloud::storage;
|
||||
|
||||
// The environment variable that overrides the block size for aligned reads from
|
||||
// GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes).
|
||||
constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB";
|
||||
constexpr size_t kDefaultBlockSize = 64 * 1024 * 1024;
|
||||
// The environment variable that overrides the max size of the LRU cache of
|
||||
// blocks read from GCS. Specified in MB.
|
||||
constexpr char kMaxCacheSize[] = "GCS_READ_CACHE_MAX_SIZE_MB";
|
||||
constexpr size_t kDefaultMaxCacheSize = 0;
|
||||
// The environment variable that overrides the maximum staleness of cached file
|
||||
// contents. Once any block of a file reaches this staleness, all cached blocks
|
||||
// will be evicted on the next read.
|
||||
constexpr char kMaxStaleness[] = "GCS_READ_CACHE_MAX_STALENESS";
|
||||
constexpr uint64_t kDefaultMaxStaleness = 0;
|
||||
|
||||
constexpr char kStatCacheMaxAge[] = "GCS_STAT_CACHE_MAX_AGE";
|
||||
constexpr uint64_t kStatCacheDefaultMaxAge = 5;
|
||||
// The environment variable that overrides the maximum number of entries in the
|
||||
// Stat cache.
|
||||
constexpr char kStatCacheMaxEntries[] = "GCS_STAT_CACHE_MAX_ENTRIES";
|
||||
constexpr size_t kStatCacheDefaultMaxEntries = 1024;
|
||||
|
||||
// How to upload new data when Flush() is called multiple times.
|
||||
// By default the entire file is reuploaded.
|
||||
constexpr char kAppendMode[] = "GCS_APPEND_MODE";
|
||||
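The constants above are read from environment variables, so the GCS read cache can be tuned without rebuilding. A hedged sketch of how a user might set them; the variable names come from the constants above, while the specific values, the bucket/object path, and the assumption that the modular GCS plugin is the one serving `gs://` are illustrative only:

```python
import os

# Hypothetical cache tuning. Sizes are in MB; staleness/age are assumed to be
# in seconds; GCS_READ_CACHE_MAX_SIZE_MB=0 would leave the block cache disabled.
os.environ["GCS_READ_CACHE_BLOCK_SIZE_MB"] = "16"
os.environ["GCS_READ_CACHE_MAX_SIZE_MB"] = "256"
os.environ["GCS_READ_CACHE_MAX_STALENESS"] = "60"
os.environ["GCS_STAT_CACHE_MAX_AGE"] = "5"
os.environ["GCS_STAT_CACHE_MAX_ENTRIES"] = "1024"

import tensorflow as tf  # imported after the variables are set

# The plugin reads these variables when the filesystem is initialized, so they
# must be in place before the first gs:// access.
with tf.io.gfile.GFile("gs://some-bucket/some-object", "rb") as f:  # hypothetical path
    data = f.read(1024)
```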
@ -82,28 +103,15 @@ static void MaybeAppendSlash(std::string* name) {
|
||||
name->push_back('/');
|
||||
}
|
||||
|
||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
typedef struct GCSFile {
|
||||
const std::string bucket;
|
||||
const std::string object;
|
||||
gcs::Client* gcs_client; // not owned
|
||||
} GCSFile;
|
||||
|
||||
void Cleanup(TF_RandomAccessFile* file) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
delete gcs_file;
|
||||
}
|
||||
|
||||
// TODO(vnvo2409): Adding cache.
|
||||
// `google-cloud-cpp` is working on a feature that we may want to use.
|
||||
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
|
||||
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
|
||||
char* buffer, TF_Status* status) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
auto stream = gcs_file->gcs_client->ReadObject(
|
||||
gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n));
|
||||
// A helper function to actually read the data from GCS.
|
||||
static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
|
||||
size_t buffer_size, char* buffer,
|
||||
gcs::Client* gcs_client, TF_Status* status) {
|
||||
std::string bucket, object;
|
||||
ParseGCSPath(path, false, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return -1;
|
||||
auto stream = gcs_client->ReadObject(
|
||||
bucket, object, gcs::ReadRange(offset, offset + buffer_size));
|
||||
TF_SetStatusFromGCSStatus(stream.status(), status);
|
||||
if ((TF_GetCode(status) != TF_OK) &&
|
||||
(TF_GetCode(status) != TF_OUT_OF_RANGE)) {
|
||||
@ -115,11 +123,92 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
|
||||
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
|
||||
return -1;
|
||||
}
|
||||
if (read != n) {
|
||||
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
|
||||
}
|
||||
stream.read(buffer, read);
|
||||
return read;
|
||||
return stream.gcount();
|
||||
}
|
||||
|
||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
using ReadFn =
|
||||
std::function<int64_t(const std::string& path, uint64_t offset, size_t n,
|
||||
char* buffer, TF_Status* status)>;
|
||||
typedef struct GCSFile {
|
||||
const std::string path;
|
||||
const bool is_cache_enable;
|
||||
const uint64_t buffer_size;
|
||||
ReadFn read_fn;
|
||||
absl::Mutex buffer_mutex;
|
||||
uint64_t buffer_start ABSL_GUARDED_BY(buffer_mutex);
|
||||
bool buffer_end_is_past_eof ABSL_GUARDED_BY(buffer_mutex);
|
||||
std::string buffer ABSL_GUARDED_BY(buffer_mutex);
|
||||
|
||||
GCSFile(std::string path, bool is_cache_enable, uint64_t buffer_size,
|
||||
ReadFn read_fn)
|
||||
: path(path),
|
||||
is_cache_enable(is_cache_enable),
|
||||
buffer_size(buffer_size),
|
||||
read_fn(std::move(read_fn)),
|
||||
buffer_mutex(),
|
||||
buffer_start(0),
|
||||
buffer_end_is_past_eof(false),
|
||||
buffer() {}
|
||||
} GCSFile;
|
||||
|
||||
void Cleanup(TF_RandomAccessFile* file) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
delete gcs_file;
|
||||
}
|
||||
|
||||
// `google-cloud-cpp` is working on a feature that we may want to use.
|
||||
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
|
||||
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
|
||||
char* buffer, TF_Status* status) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
if (gcs_file->is_cache_enable || n > gcs_file->buffer_size) {
|
||||
return gcs_file->read_fn(gcs_file->path, offset, n, buffer, status);
|
||||
} else {
|
||||
absl::MutexLock l(&gcs_file->buffer_mutex);
|
||||
size_t buffer_end = gcs_file->buffer_start + gcs_file->buffer.size();
|
||||
size_t copy_size = 0;
|
||||
if (offset < buffer_end && gcs_file->buffer_start) {
|
||||
copy_size = (std::min)(n, static_cast<size_t>(buffer_end - offset));
|
||||
memcpy(buffer,
|
||||
gcs_file->buffer.data() + (offset - gcs_file->buffer_start),
|
||||
copy_size);
|
||||
}
|
||||
bool consumed_buffer_to_eof =
|
||||
offset + copy_size >= buffer_end && gcs_file->buffer_end_is_past_eof;
|
||||
if (copy_size < n && !consumed_buffer_to_eof) {
|
||||
gcs_file->buffer_start = offset + copy_size;
|
||||
gcs_file->buffer.resize(gcs_file->buffer_size);
|
||||
auto read_fill_buffer = gcs_file->read_fn(
|
||||
gcs_file->path, gcs_file->buffer_start, gcs_file->buffer_size,
|
||||
&(gcs_file->buffer[0]), status);
|
||||
gcs_file->buffer_end_is_past_eof =
|
||||
(TF_GetCode(status) == TF_OUT_OF_RANGE);
|
||||
if (read_fill_buffer >= 0) gcs_file->buffer.resize(read_fill_buffer);
|
||||
if (TF_GetCode(status) != TF_OK &&
|
||||
TF_GetCode(status) != TF_OUT_OF_RANGE) {
|
||||
// Empty the buffer to avoid caching bad reads.
|
||||
gcs_file->buffer.resize(0);
|
||||
return -1;
|
||||
}
|
||||
size_t remaining_copy =
|
||||
(std::min)(n - copy_size, gcs_file->buffer.size());
|
||||
memcpy(buffer + copy_size, gcs_file->buffer.data(), remaining_copy);
|
||||
copy_size += remaining_copy;
|
||||
if (copy_size < n) {
|
||||
// Forget the end-of-file flag to allow for clients that poll on the
|
||||
// same file.
|
||||
gcs_file->buffer_end_is_past_eof = false;
|
||||
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
|
||||
return copy_size;
|
||||
}
|
||||
}
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return copy_size;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tf_random_access_file
|
||||
@ -290,11 +379,53 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region) {
|
||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_gcs_filesystem {
|
||||
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
|
||||
// TODO(vnvo2409): Use partial reponse for better performance.
|
||||
// TODO(vnvo2409): We could do some cleanups like `return TF_SetStatus`.
|
||||
// TODO(vnvo2409): Refactor the filesystem implementation when
|
||||
// https://github.com/googleapis/google-cloud-cpp/issues/4482 is done.
|
||||
GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
|
||||
: gcs_client(gcs_client), block_cache_lock() {
|
||||
const char* append_mode = std::getenv(kAppendMode);
|
||||
compose = (append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));
|
||||
|
||||
uint64_t value;
|
||||
block_size = kDefaultBlockSize;
|
||||
size_t max_bytes = kDefaultMaxCacheSize;
|
||||
uint64_t max_staleness = kDefaultMaxStaleness;
|
||||
|
||||
// Apply the overrides for the block size (MB), max bytes (MB), and max
|
||||
// staleness (seconds) if provided.
|
||||
if (absl::SimpleAtoi(std::getenv(kBlockSize), &value)) {
|
||||
block_size = value * 1024 * 1024;
|
||||
}
|
||||
if (absl::SimpleAtoi(std::getenv(kMaxCacheSize), &value)) {
|
||||
max_bytes = static_cast<size_t>(value * 1024 * 1024);
|
||||
}
|
||||
if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
|
||||
max_staleness = value;
|
||||
}
|
||||
|
||||
auto gcs_client_ptr = &this->gcs_client;
|
||||
file_block_cache = std::make_unique<RamFileBlockCache>(
|
||||
block_size, max_bytes, max_staleness,
|
||||
[gcs_client_ptr](const std::string& filename, size_t offset,
|
||||
size_t buffer_size, char* buffer, TF_Status* status) {
|
||||
return LoadBufferFromGCS(filename, offset, buffer_size, buffer,
|
||||
gcs_client_ptr, status);
|
||||
});
|
||||
|
||||
uint64_t stat_cache_max_age = kStatCacheDefaultMaxAge;
|
||||
size_t stat_cache_max_entries = kStatCacheDefaultMaxEntries;
|
||||
if (absl::SimpleAtoi(std::getenv(kStatCacheMaxAge), &value)) {
|
||||
stat_cache_max_age = value;
|
||||
}
|
||||
if (absl::SimpleAtoi(std::getenv(kStatCacheMaxEntries), &value)) {
|
||||
stat_cache_max_entries = static_cast<size_t>(value);
|
||||
}
|
||||
stat_cache = std::make_unique<ExpiringLRUCache<GcsFileStat>>(
|
||||
stat_cache_max_age, stat_cache_max_entries);
|
||||
}
|
||||
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||
google::cloud::StatusOr<gcs::Client> client =
|
||||
gcs::Client::CreateDefaultClient();
|
||||
@ -303,12 +434,7 @@ void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||
return;
|
||||
}
|
||||
|
||||
const char* append_mode = std::getenv(kAppendMode);
|
||||
bool compose =
|
||||
(append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));
|
||||
|
||||
filesystem->plugin_filesystem =
|
||||
new GCSFile({std::move(client.value()), compose});
|
||||
filesystem->plugin_filesystem = new GCSFile(std::move(client.value()));
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
@ -325,8 +451,32 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
|
||||
bool is_cache_enabled;
|
||||
{
|
||||
absl::MutexLock l(&gcs_file->block_cache_lock);
|
||||
is_cache_enabled = gcs_file->file_block_cache->IsCacheEnabled();
|
||||
}
|
||||
auto read_fn = [gcs_file, is_cache_enabled](
|
||||
const std::string& path, uint64_t offset, size_t n,
|
||||
char* buffer, TF_Status* status) -> int64_t {
|
||||
// TODO(vnvo2409): Check for `stat_cache`.
|
||||
int64_t read = 0;
|
||||
if (is_cache_enabled) {
|
||||
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
|
||||
read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
|
||||
} else {
|
||||
read = LoadBufferFromGCS(path, offset, n, buffer, &gcs_file->gcs_client,
|
||||
status);
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return -1;
|
||||
if (read < n)
|
||||
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
|
||||
else
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return read;
|
||||
};
|
||||
file->plugin_file = new tf_random_access_file::GCSFile(
|
||||
{std::move(bucket), std::move(object), &gcs_file->gcs_client});
|
||||
std::move(path), is_cache_enabled, gcs_file->block_size, read_fn);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,8 @@
|
||||
|
||||
#include "google/cloud/storage/client.h"
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h"
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
|
||||
@ -45,10 +47,23 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region);
|
||||
} // namespace tf_read_only_memory_region
|
||||
|
||||
namespace tf_gcs_filesystem {
|
||||
typedef struct GcsFileStat {
|
||||
TF_FileStatistics base;
|
||||
int64_t generation_number;
|
||||
} GcsFileStat;
|
||||
|
||||
typedef struct GCSFile {
|
||||
google::cloud::storage::Client gcs_client; // owned
|
||||
bool compose;
|
||||
absl::Mutex block_cache_lock;
|
||||
std::shared_ptr<RamFileBlockCache> file_block_cache
|
||||
ABSL_GUARDED_BY(block_cache_lock);
|
||||
uint64_t block_size; // Reads smaller than block_size will trigger a read
|
||||
// of block_size.
|
||||
std::unique_ptr<ExpiringLRUCache<GcsFileStat>> stat_cache;
|
||||
GCSFile(google::cloud::storage::Client&& gcs_client);
|
||||
} GCSFile;
|
||||
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
||||
void Cleanup(TF_Filesystem* filesystem);
|
||||
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include <aws/core/config/AWSProfileConfigLoader.h>
|
||||
#include <aws/core/utils/FileSystemUtils.h>
|
||||
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
|
||||
#include <aws/s3/model/CopyObjectRequest.h>
|
||||
#include <aws/s3/model/GetObjectRequest.h>
|
||||
#include <aws/s3/model/HeadBucketRequest.h>
|
||||
#include <aws/s3/model/HeadObjectRequest.h>
|
||||
@ -545,7 +546,10 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
GetS3Client(s3_file);
|
||||
GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
|
||||
|
||||
// We need to delete `file->plugin_file` in case of errors.
|
||||
// We need to delete `file->plugin_file` in case of errors. We set
|
||||
// `file->plugin_file` to `nullptr` in order to avoid segment fault when
|
||||
// calling deleter of `unique_ptr`.
|
||||
file->plugin_file = nullptr;
|
||||
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile*)> writer(
|
||||
file, [](TF_WritableFile* file) {
|
||||
if (file != nullptr && file->plugin_file != nullptr) {
|
||||
@ -561,10 +565,14 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
|
||||
new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
|
||||
if (file != nullptr) {
|
||||
tf_random_access_file::Cleanup(file);
|
||||
if (file->plugin_file != nullptr)
|
||||
tf_random_access_file::Cleanup(file);
|
||||
delete file;
|
||||
}
|
||||
});
|
||||
// We set `reader->plugin_file` to `nullptr` in order to avoid segment fault
|
||||
// when calling deleter of `unique_ptr`
|
||||
reader->plugin_file = nullptr;
|
||||
NewRandomAccessFile(filesystem, path, reader.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
@ -678,7 +686,7 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
|
||||
TF_ReadOnlyMemoryRegion* region,
|
||||
TF_Status* status) {
|
||||
Aws::String bucket, object;
|
||||
ParseS3Path(path, true, &bucket, &object, status);
|
||||
ParseS3Path(path, false, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
|
||||
@ -695,10 +703,14 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
|
||||
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
|
||||
new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
|
||||
if (file != nullptr) {
|
||||
tf_random_access_file::Cleanup(file);
|
||||
if (file->plugin_file != nullptr)
|
||||
tf_random_access_file::Cleanup(file);
|
||||
delete file;
|
||||
}
|
||||
});
|
||||
// We set `reader->plugin_file` to `nullptr` in order to avoid segment fault
|
||||
// when calling deleter of `unique_ptr`
|
||||
reader->plugin_file = nullptr;
|
||||
NewRandomAccessFile(filesystem, path, reader.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
auto read =
|
||||
@ -710,6 +722,67 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
static void SimpleCopyFile(const Aws::String& source,
|
||||
const Aws::String& bucket_dst,
|
||||
const Aws::String& object_dst, S3File* s3_file,
|
||||
TF_Status* status) {
|
||||
Aws::S3::Model::CopyObjectRequest copy_object_request;
|
||||
copy_object_request.WithCopySource(source)
|
||||
.WithBucket(bucket_dst)
|
||||
.WithKey(object_dst);
|
||||
auto copy_object_outcome =
|
||||
s3_file->s3_client->CopyObject(copy_object_request);
|
||||
if (!copy_object_outcome.IsSuccess())
|
||||
TF_SetStatusFromAWSError(copy_object_outcome.GetError(), status);
|
||||
else
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
};
|
||||
|
||||
static void MultiPartCopy(const Aws::String& source,
|
||||
const Aws::String& bucket_dst,
|
||||
const Aws::String& object_dst, const size_t num_parts,
|
||||
const uint64_t file_size, S3File* s3_file,
|
||||
TF_Status* status){};
|
||||
|
||||
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
|
||||
TF_Status* status) {
|
||||
auto file_size = GetFileSize(filesystem, src, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
if (file_size == 0)
|
||||
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"Source is a directory or empty file");
|
||||
|
||||
Aws::String bucket_src, object_src;
|
||||
ParseS3Path(src, false, &bucket_src, &object_src, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
Aws::String copy_src = bucket_src + "/" + object_src;
|
||||
|
||||
Aws::String bucket_dst, object_dst;
|
||||
ParseS3Path(dst, false, &bucket_dst, &object_dst, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
|
||||
auto chunk_size =
|
||||
s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];
|
||||
size_t num_parts = 1;
|
||||
if (file_size > chunk_size) num_parts = ceil((float)file_size / chunk_size);
|
||||
if (num_parts == 1)
|
||||
SimpleCopyFile(copy_src, bucket_dst, object_dst, s3_file, status);
|
||||
else if (num_parts > 10000)
|
||||
TF_SetStatus(
|
||||
status, TF_UNIMPLEMENTED,
|
||||
absl::StrCat("MultiPartCopy with number of parts more than 10000 is "
|
||||
"not supported. Your object ",
|
||||
src, " required ", num_parts,
|
||||
" as multi_part_copy_part_size is set to ", chunk_size,
|
||||
". You can control this part size using the environment "
|
||||
"variable S3_MULTI_PART_COPY_PART_SIZE to increase it.")
|
||||
.c_str());
|
||||
else
|
||||
MultiPartCopy(copy_src, bucket_dst, object_dst, num_parts, file_size,
|
||||
s3_file, status);
|
||||
}
|
||||
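To make the `CopyFile` branch above concrete, a back-of-the-envelope sketch of the part-count arithmetic. The 100 MB object and 8 MB part size are made-up numbers; in the real code the part size comes from `multi_part_chunk_sizes` and can be raised via `S3_MULTI_PART_COPY_PART_SIZE`:

```python
import math

def choose_copy_strategy(file_size, chunk_size):
    """Mirrors the num_parts selection in CopyFile above (illustrative only)."""
    num_parts = 1 if file_size <= chunk_size else math.ceil(file_size / chunk_size)
    if num_parts == 1:
        return "SimpleCopyFile"
    if num_parts > 10000:
        return "unimplemented: raise S3_MULTI_PART_COPY_PART_SIZE"
    return f"MultiPartCopy with {num_parts} parts"

# A hypothetical 100 MB object with an 8 MB part size needs ceil(100 / 8) = 13 parts.
print(choose_copy_strategy(100 * 1024 * 1024, 8 * 1024 * 1024))
```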
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_s3_filesystem
|
||||
|

tensorflow/c/experimental/gradients/BUILD (new file, 23 lines)
@ -0,0 +1,23 @@
|
||||
# Library of gradient functions.
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "math_grad",
|
||||
srcs = ["math_grad.cc"],
|
||||
hdrs = [
|
||||
"math_grad.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:gradients",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
],
|
||||
)
|

tensorflow/c/experimental/gradients/math_grad.cc (new file, 54 lines)
@ -0,0 +1,54 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
|
||||
using tensorflow::ops::Identity;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace {
|
||||
|
||||
class AddGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
|
||||
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
grad_outputs->resize(2);
|
||||
std::vector<AbstractTensorHandle*> identity_outputs(1);
|
||||
// TODO(b/145674566): Handle name unification in tracing code.
|
||||
// TODO(b/161805092): Support broadcasting.
|
||||
TF_RETURN_IF_ERROR(ops::Identity(
|
||||
ctx_, {grad_inputs[0]}, absl::MakeSpan(identity_outputs), "Identity0"));
|
||||
(*grad_outputs)[0] = identity_outputs[0];
|
||||
TF_RETURN_IF_ERROR(ops::Identity(
|
||||
ctx_, {grad_inputs[0]}, absl::MakeSpan(identity_outputs), "Identity1"));
|
||||
(*grad_outputs)[1] = identity_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~AddGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AbstractContext* ctx_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
GradientFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
return new AddGradientFunction(op.ctx);
|
||||
}
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
@ -12,34 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
||||
|
||||
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "absl/random/random.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
static int64 overridden_node_id = -1;
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace internal {
|
||||
|
||||
void OverrideNodeIdForTesting(const int64 node_id) {
|
||||
overridden_node_id = node_id;
|
||||
}
|
||||
|
||||
uint64 GetNodeId() {
|
||||
if (overridden_node_id > -1) {
|
||||
return overridden_node_id;
|
||||
} else {
|
||||
return absl::Uniform(absl::SharedBitGen(), uint64{0},
|
||||
std::numeric_limits<uint64>::max());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
namespace gradients {
|
||||
GradientFunction* AddRegisterer(const ForwardOperation& op);
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|

tensorflow/c/experimental/ops/BUILD (new file, 24 lines)
@ -0,0 +1,24 @@
|
||||
# Experimental ops. These will eventually be replaced by machine-generated versions.
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "array_ops",
|
||||
srcs = [
|
||||
"array_ops.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"array_ops.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|

tensorflow/c/experimental/ops/array_ops.cc (new file, 38 lines)
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
// Creates an Identity op.
|
||||
Status Identity(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr identity_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(
|
||||
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
|
||||
if (isa<tensorflow::tracing::TracingOperation>(identity_op.get())) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
|
||||
->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
|
||||
int num_retvals = 1;
|
||||
return identity_op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
@ -12,27 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
|
||||
|
||||
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
|
||||
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Implementation details of distributed_tpu_rewrite_pass.cc, please DO NOT
|
||||
// depend on these.
|
||||
namespace internal {
|
||||
|
||||
// When set to a value >= 0, overrides the node_id. Used for getting
|
||||
// deterministic node_ids during testing.
|
||||
void OverrideNodeIdForTesting(int64 node_id);
|
||||
|
||||
// Retrieves the node id, used to make some node names unique in the rewrite
|
||||
// pass.
|
||||
uint64 GetNodeId();
|
||||
|
||||
} // namespace internal
|
||||
namespace ops {
|
||||
Status Identity(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
|
@ -111,10 +111,6 @@ TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory(
|
||||
TF_CAPI_EXPORT extern void TF_KernelBuilder_Priority(
|
||||
TF_KernelBuilder* kernel_builder, int32_t priority_number);
|
||||
|
||||
typedef struct string_view string_view;
|
||||
|
||||
TF_CAPI_EXPORT extern string_view TF_GetName(TF_KernelBuilder* kernel_builder);
|
||||
|
||||
// Register the given kernel builder with the TensorFlow runtime. If
|
||||
// registration fails, the given status will be populated.
|
||||
//
|
||||
|
@ -53,7 +53,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#include <iostream>
|
||||
struct MyCustomKernel {
|
||||
bool created;
|
||||
bool compute_called;
|
||||
@ -162,9 +161,6 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
|
||||
ASSERT_TRUE(delete_called);
|
||||
}
|
||||
|
||||
TEST(TestKernel, TestGetKernelName) {
|
||||
}
|
||||
|
||||
class DummyDevice : public DeviceBase {
|
||||
public:
|
||||
explicit DummyDevice(Env* env) : DeviceBase(env) {}
|
||||
|
@ -1829,7 +1829,7 @@ TEST(XlaCompilationTest, XLALiteAllowlist) {
|
||||
}
|
||||
EXPECT_TRUE(unknow_op.empty())
|
||||
<< "Someone added support for a new TF opeations inside XLA. They must "
|
||||
"be included in the XLALite allowlist or blacklist:\n"
|
||||
"be included in the XLALite allowlist or denylist:\n"
|
||||
<< absl::StrJoin(unknow_op, "\n");
|
||||
}
|
||||
} // namespace
|
||||
|
@ -74,7 +74,6 @@ We have several choices on how to lower the host-side part from LHLO:
|
||||
* (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as
|
||||
TFRT ops are interpreted by C++ code.
|
||||
* (Con) host side is under development and not tested.
|
||||
* (Con) the JAX integration isn’t clear from a runtime point of view
|
||||
* Jitted CPU code
|
||||
* (Pro) great lower-ability. Create a few loops and conditions and it's
|
||||
done.
|
||||
@ -84,8 +83,7 @@ We have several choices on how to lower the host-side part from LHLO:
|
||||
dynamic loading, etc).
|
||||
* Existing (interpreting) XLA runtime
|
||||
|
||||
Tentative conclusion: Use jitted CPU code during the transition, and optionally
|
||||
adopt TFRT in the end.
|
||||
Decision: adopt TFRT, but also support jitting CPU code in TFRT.
|
||||
|
||||
## Migrating Device LLVM IR (Task 3)
|
||||
|
||||
@ -114,7 +112,7 @@ end state of each XLA op:
|
||||
* (Cost) Will be throw-away work if we want to ultimately migrate to
|
||||
Standard.
|
||||
* (Benefit) It is easy and mechanical. Can be done in a short period.
|
||||
* (Benefit) It doesn't benefit more compared to a).
|
||||
* (Benefit) It doesn't benefit more compared to (1).
|
||||
1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops:
|
||||
* (Cost) Lifting existing emitters to Standard introduces some challenges.
|
||||
Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring
|
||||
@ -134,6 +132,19 @@ end state of each XLA op:
|
||||
* (Benefit) unified stack; community support; portability; more
|
||||
optimization potentials.
|
||||
|
||||
Conclusions:
|
||||
|
||||
* Don't go for (2). (1) or (3) are just better than (2). (2) costs more than
|
||||
(1), since it requires a lot of mechanical refactoring. With (1) we can
|
||||
still achieve the goal of enabling XLA to pick up MLIR emitters. This is by
|
||||
doing LHLO -> LLVM IR -> run legacy device emitters.
|
||||
* ElementalIrEmitter ops go for (4), but not incrementally. There is no way to
|
||||
do it op by op, because all elementally-emitted ops are connected into the
|
||||
same graph. This work can also serve as a unification point of several
|
||||
on-going forces (xla/service/mlir\_gpu, the kernel generator, Linalg).
|
||||
* All other ops go for (1). As a stretch goal, they might be migrated to (3)
|
||||
or (4).
|
||||
|
||||
## Prioritization
|
||||
|
||||
While all three tasks mentioned above are parallelizable, under limited
|
||||
@ -210,26 +221,19 @@ The exact profiling can't be easily done for MLIR-generated ops, since:
|
||||
|
||||
### Step 3: (Task 2) Migrating Thunks
|
||||
|
||||
This step migrates all host ops and library calls. This step will eliminate most
|
||||
of the thunks and produce serializable MLIR instead.
|
||||
|
||||
There are roughly three kinds of thunks:
|
||||
|
||||
As a note, there are roughly three kinds of thunks:
|
||||
* KernelThunk, which launches a kernel.
|
||||
* Control flow thunks, which has host control flow logic (conditional, while,
|
||||
for, sequence) and launch body kernels.
|
||||
* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.
|
||||
|
||||
The **bottom line** is to:
|
||||
The plan is:
|
||||
* Make Thunks (de)serializable.
|
||||
* Help improve TFRT to a state where it can support these semantics.
|
||||
* As the state improves, migrate individual thunks incrementally.
|
||||
|
||||
* Create a Thunk dialect that provides (de)serialize logic for all existing
|
||||
C++-based Thunks.
|
||||
* Change emitters to emit a graph of Thunk dialect.
|
||||
|
||||
**Optionally**, we can relieve some thunks from C++ implementation. KernelThunk
|
||||
can lower to the GPU LaunchKernelOp. Control flow thunks can leverage the CFG
|
||||
Dialect for loops and conditions, combined with LaunchKernelOp. This optional
|
||||
step requires profiling and stream support.
|
||||
These action items are only partially ordered. The actual execution order /
|
||||
engineering parallelism is to be evaluated as it goes.
|
||||
|
||||
### Step 4: (Task 3) Migrated ElementalIrEmitter
|
||||
|
||||
|
@ -568,6 +568,26 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "legalize_gather_to_torch_index_select",
|
||||
srcs = ["lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc"],
|
||||
hdrs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
|
||||
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
|
||||
],
|
||||
deps = [
|
||||
":hlo",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "legalize_tanh_to_approximation",
|
||||
srcs = ["lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc"],
|
||||
@ -717,6 +737,7 @@ cc_library(
|
||||
":hlo_dialect_registration",
|
||||
":hlo_legalize_to_lhlo",
|
||||
":legalize_control_flow",
|
||||
":legalize_gather_to_torch_index_select",
|
||||
":legalize_tanh_to_approximation",
|
||||
":legalize_to_linalg",
|
||||
":legalize_to_standard",
|
||||
|
@ -41,6 +41,10 @@ void PopulateComplexLoweringPatterns(MLIRContext *context,
|
||||
void PopulateOptimizeMHLOPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
// Rewrite patterns for gather to equivalent torch index select legalization.
|
||||
void PopulateGatherToTorchIndexSelectPatterns(
|
||||
mlir::MLIRContext *context, OwningRewritePatternList *patterns);
|
||||
|
||||
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
|
||||
MLIRContext *ctx);
|
||||
|
||||
|
@ -1468,7 +1468,7 @@ static LogicalResult Verify(PadOp op) {
|
||||
|
||||
static LogicalResult Verify(ReshapeOp op) {
|
||||
// If the operand type is dynamically shaped there is nothing to verify.
|
||||
auto operand_ty = op.operand().getType().cast<RankedTensorType>();
|
||||
auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
|
||||
if (!operand_ty || !operand_ty.hasStaticShape()) return success();
|
||||
|
||||
// If the operand type is statically shaped (not required) the number of
|
||||
|
@ -0,0 +1,152 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
|
||||
using OpRewritePattern<GatherOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(GatherOp gather,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto start_indices = gather.start_indices();
|
||||
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return failure();
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return failure();
    }

    int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
        index_vector_dim) {
      return failure();
    }

    // Index select only works across a single dimension.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return failure();
    }

    // Only support the default case for start_index_map.
    if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map()
                .getValue(0)
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return failure();
    }

    // Offset dimensions should be the defaults.
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank() - index_vector_dim) {
      return failure();
    }

    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if ((it.index() + index_vector_dim) != it.value()) {
        return failure();
      }
    }

    for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return failure();
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return failure();
      }
    }

    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());

    for (auto dim : operand_ty.getShape().drop_front()) {
      index_select_shape.push_back(dim);
    }

    if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
        dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
            1 ||
        dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
      return failure();
    }

    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(0));

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);

    return success();
  }
};

struct LegalizeGatherToTorchIndexSelect
    : public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> {
  /// Perform the lowering of standard dialect operations to approximations.
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
} // namespace

void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<GatherIsTorchIndexSelect>(context);
}

static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass(
    "mhlo-legalize-gather-to-torch-index-select",
    "Legalizes gathers to a torch index select.");

} // namespace mhlo
} // namespace mlir

@ -0,0 +1,41 @@
// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s

// CHECK-LABEL: @gather_to_index_select
func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME: batch_dims = 0 : i64,
  // CHECK-SAME: dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x3x4xf32>
}

// CHECK-LABEL: @scalar_gather_to_index_select
func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<i32>) -> tensor<1x4xf32> {
  // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
  // CHECK-SAME: batch_dims = 0 : i64,
  // CHECK-SAME: dim = 0 : i64
  // CHECK-SAME: } : (tensor<5x4xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<i32>) -> tensor<1x4xf32>

  // CHECK: return [[RES]]
  return %0 : tensor<1x4xf32>
}

// CHECK-LABEL: @gather_no_lowering_subslice
func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32>
  return %0 : tensor<1x3x3xf32>
}

// CHECK-LABEL: @gather_no_lowering_multidim
func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> {
  // CHECK: "mhlo.gather"
  %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32>
  return %0 : tensor<1x3x4xf32>
}

@ -220,18 +220,14 @@ cc_library(
    ],
    deps = [
        ":tensorflow_lite_ops_inc_gen",
        ":validators",
        "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/lite/schema:schema_fbs",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:DerivedAttributeOpInterface",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LoopLikeInterface",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:SideEffects",
        "@llvm-project//mlir:StandardOps",

@ -349,7 +345,6 @@ cc_library(
        "transforms/passes.h",
    ],
    deps = [
        ":common",
        ":lstm_utils",
        ":stateful_ops_utils",
        ":tensorflow_lite",

@ -369,7 +364,6 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/kernels:tensor_list",
        "//tensorflow/core/platform:logging",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:Support",

@ -400,7 +394,6 @@ cc_library(
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",

@ -434,7 +427,6 @@ cc_library(
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",

@ -457,7 +449,6 @@ cc_library(
        "//tensorflow/lite/tools/optimize/sparsity:format_converter",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",

@ -609,8 +600,6 @@ cc_library(
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
        "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:statusor",

@ -651,7 +640,6 @@ cc_library(
        ":flatbuffer_tflite_operator_lib",
        ":tensorflow_lite",
        ":tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:statusor",

@ -724,7 +712,6 @@ cc_library(
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MlirTranslateMain",

@ -858,10 +845,8 @@ cc_library(
        "//tensorflow/core:core_cpu_base",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Transforms",
    ],
)

@ -151,10 +151,13 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
    return errors::Unimplemented("Only support a single exported name.");
  }

  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = true;

  TF_ASSIGN_OR_RETURN(auto module,
                      ImportSavedModel(model_flags.saved_model_dir(),
                                       model_flags.saved_model_version(), tags,
                                       exported_names, &context));
                                       exported_names, specs, &context));

  if (!model_flags.input_arrays().empty() ||
      !model_flags.output_arrays().empty()) {

@ -81,7 +81,6 @@ cc_library(
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",

@ -144,6 +144,10 @@ int main(int argc, char **argv) {

  StatusOr<mlir::OwningModuleRef> module;

  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = upgrade_legacy;
  specs.prune_unused_nodes = true;

  // TODO(b/147435528): We need to test the e2e behavior once the graph freezing
  // inside mlir is done.
  if (import_saved_model_object_graph || import_saved_model_signature_defs) {

@ -168,12 +172,10 @@ int main(int argc, char **argv) {
      return kTrFailure;
    }

    module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
                                          tags, exported_names, &context);
    module =
        tensorflow::ImportSavedModel(input_file_name, saved_model_version, tags,
                                     exported_names, specs, &context);
  } else {
    tensorflow::GraphImportConfig specs;
    specs.upgrade_legacy = upgrade_legacy;
    specs.prune_unused_nodes = true;
    module = tensorflow::LoadFromGraphdefOrMlirSource(
        input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
        specs, debug_info_file, input_arrays, input_dtypes, input_shapes,

@ -186,7 +186,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
StatusOr<mlir::OwningModuleRef> ImportSavedModel(
    const std::string& input_filename, const int saved_model_version,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
    mlir::MLIRContext* context) {
  if (saved_model_version == 2) {
    auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
        input_filename, tags, exported_names, context);

@ -194,7 +195,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
    return module_or.ConsumeValueOrDie();
  } else if (saved_model_version == 1) {
    auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
        input_filename, tags, exported_names, context);
        input_filename, tags, exported_names, context, specs.upgrade_legacy);

    if (!module_or.status().ok()) return module_or.status();
    return module_or.ConsumeValueOrDie();

@ -48,7 +48,8 @@ LoadFromGraphdefOrMlirSource(
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
    const std::string& input_filename, const int saved_model_version,
    const std::unordered_set<std::string>& tags,
    absl::Span<std::string> exported_names, mlir::MLIRContext* context);
    absl::Span<std::string> exported_names, const GraphImportConfig& specs,
    mlir::MLIRContext* context);

// Taking a MLIR module in TF executor dialect and a set of parameters,
// applies a set of passes to convert the module to TF Lite dialect and

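For reference, a minimal caller-side sketch of the updated ImportSavedModel interface declared above, assuming the declarations from this header (and the usual absl/standard-library headers) are already in scope. The helper name LoadSavedModelForConversion, the "serve" tag, and the "serving_default" exported name are illustrative placeholders; GraphImportConfig and its upgrade_legacy/prune_unused_nodes fields are the ones used in the hunks above.

// A sketch only, not part of this commit: shows how a caller passes the new
// GraphImportConfig argument to the updated ImportSavedModel signature.
stream_executor::port::StatusOr<mlir::OwningModuleRef>
LoadSavedModelForConversion(const std::string& saved_model_dir,
                            mlir::MLIRContext* context) {
  // Illustrative tag set and exported names; real callers take these from
  // command-line flags or the converter's model flags.
  std::unordered_set<std::string> tags = {"serve"};
  std::vector<std::string> exported_names = {"serving_default"};

  // The new GraphImportConfig argument carries importer options that callers
  // previously could not control through ImportSavedModel.
  tensorflow::GraphImportConfig specs;
  specs.upgrade_legacy = true;
  specs.prune_unused_nodes = true;

  // saved_model_version == 1 selects the SignatureDefs-based importer, as in
  // the implementation hunk above.
  return tensorflow::ImportSavedModel(saved_model_dir,
                                      /*saved_model_version=*/1, tags,
                                      absl::MakeSpan(exported_names), specs,
                                      context);
}
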
@ -37,22 +37,19 @@ class HasRankAtMost<int n> : Constraint<
|
||||
// Multi-pattern consisting of matching stand-alone convolution op followed by
|
||||
// activation op.
|
||||
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
|
||||
def : Pat<(ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias,
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w)),
|
||||
(TFL_Conv2DOp $input, $filter, $bias,
|
||||
$h_factor, $w_factor, ActFnAttr,
|
||||
$padding, $stride_h, $stride_w),
|
||||
[(HasOneUse $conv_out)]>;
|
||||
def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias,
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier)),
|
||||
(TFL_DepthwiseConv2DOp $input, $filter, $bias,
|
||||
$h_factor, $w_factor, ActFnAttr,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
[(HasOneUse $conv_out)]>;
|
||||
def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat<
|
||||
(ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias, $h_factor,
|
||||
$w_factor, TFL_AF_None, $padding, $stride_h, $stride_w)),
|
||||
(TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr,
|
||||
$padding, $stride_h, $stride_w),
|
||||
[(HasOneUse $conv_out)]>;
|
||||
def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat<
|
||||
(ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias, $h_factor,
|
||||
$w_factor, TFL_AF_None, $padding, $stride_h, $stride_w,
|
||||
$multiplier)),
|
||||
(TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor,
|
||||
ActFnAttr, $padding, $stride_h, $stride_w, $multiplier),
|
||||
[(HasOneUse $conv_out)]>;
|
||||
}
|
||||
|
||||
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
|
||||
@ -73,33 +70,29 @@ class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
|
||||
// constant folding the bias and the binary op's constant operand. The following
|
||||
// pattern restricts to float constant values for now.
|
||||
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
|
||||
def : Pat<(binaryOp (TFL_Conv2DOp:$output $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_Conv2DOp $input, $filter,
|
||||
(binaryOp (ConstantOp $bias),
|
||||
(ConstantOp $value), TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w),
|
||||
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
|
||||
(HasOneUse $output)]>;
|
||||
def : Pat<(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_DepthwiseConv2DOp $input, $filter,
|
||||
(binaryOp (ConstantOp $bias),
|
||||
(ConstantOp $value),
|
||||
TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
|
||||
(HasOneUse $output)]>;
|
||||
def FuseBinaryOpWithConv#binaryOp : Pat<
|
||||
(binaryOp (TFL_Conv2DOp:$output $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor,
|
||||
TFL_AF_None, $padding, $stride_h, $stride_w),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_Conv2DOp $input, $filter,
|
||||
(binaryOp (ConstantOp $bias),
|
||||
(ConstantOp $value), TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
|
||||
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
|
||||
(HasOneUse $output)]>;
|
||||
def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
|
||||
(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
|
||||
$stride_w, $multiplier),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_DepthwiseConv2DOp $input, $filter,
|
||||
(binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
|
||||
(HasOneUse $output)]>;
|
||||
}
|
||||
foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
|
||||
defm : FuseBinaryOpToPrecedingAffine<binaryOp>;
|
||||
@ -116,43 +109,43 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall<
|
||||
// The following pattern restricts to float constant values for now.
|
||||
|
||||
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
|
||||
def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
|
||||
(ConstantOp F32ElementsAttr:$filter),
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_DepthwiseConv2DOp $input,
|
||||
(BinaryOp (ConstantOp $filter),
|
||||
(ConstantOp
|
||||
(ExpandTo4DForDepthwiseConv $value)),
|
||||
TFL_AF_None),
|
||||
(BinaryOp (ConstantOp $bias),
|
||||
(ConstantOp $value),
|
||||
TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w,
|
||||
$multiplier),
|
||||
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
|
||||
(HasOneUse $output)]>;
|
||||
def : Pat<(BinaryOp (TFL_Conv2DOp:$conv_output $input,
|
||||
(ConstantOp F32ElementsAttr:$filter),
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_Conv2DOp $input,
|
||||
(BinaryOp (ConstantOp $filter),
|
||||
(ConstantOp (ExpandTo4DForConv $value)),
|
||||
TFL_AF_None),
|
||||
(BinaryOp (ConstantOp $bias),
|
||||
(ConstantOp $value),
|
||||
TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn,
|
||||
$padding, $stride_h, $stride_w),
|
||||
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
|
||||
(HasOneUse $conv_output)]>;
|
||||
def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
|
||||
(ConstantOp F32ElementsAttr:$filter),
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
|
||||
$stride_w, $multiplier),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_DepthwiseConv2DOp $input,
|
||||
(BinaryOp
|
||||
(ConstantOp $filter),
|
||||
(ConstantOp (ExpandTo4DForDepthwiseConv $value)),
|
||||
TFL_AF_None),
|
||||
(BinaryOp
|
||||
(ConstantOp $bias),
|
||||
(ConstantOp $value),
|
||||
TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn, $padding, $stride_h,
|
||||
$stride_w, $multiplier),
|
||||
[(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
|
||||
(HasOneUse $output)]>;
|
||||
def FuseMulOrDivWithConv#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_Conv2DOp:$conv_output $input,
|
||||
(ConstantOp F32ElementsAttr:$filter),
|
||||
(ConstantOp F32ElementsAttr:$bias),
|
||||
$h_factor, $w_factor, TFL_AF_None,
|
||||
$padding, $stride_h, $stride_w),
|
||||
(ConstantOp F32ElementsAttr:$value), $act_fn),
|
||||
(TFL_Conv2DOp $input,
|
||||
(BinaryOp (ConstantOp $filter),
|
||||
(ConstantOp (ExpandTo4DForConv $value)),
|
||||
TFL_AF_None),
|
||||
(BinaryOp (ConstantOp $bias),
|
||||
(ConstantOp $value),
|
||||
TFL_AF_None),
|
||||
$h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
|
||||
[(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
|
||||
(HasOneUse $conv_output)]>;
|
||||
}
|
||||
|
||||
foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in
|
||||
@ -177,7 +170,7 @@ class OperandHasRank<int n> : Constraint<
|
||||
CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;
|
||||
|
||||
// Matching HardSwish
|
||||
def : Pat<
|
||||
def MatchHardSwishPattern1 : Pat<
|
||||
(TFL_MulOp
|
||||
(TFL_MulOp
|
||||
$x, (TFL_AddOp
|
||||
@ -190,7 +183,7 @@ def : Pat<
|
||||
(TFL_HardSwishOp $x),
|
||||
[(EqualOperands $x, $y)]>;
|
||||
|
||||
def : Pat<
|
||||
def MatchHardSwishPattern2 : Pat<
|
||||
(TFL_MulOp
|
||||
$x,
|
||||
(TFL_MulOp
|
||||
@ -207,7 +200,7 @@ def : Pat<
|
||||
// Matching HardSwish with extra FakeQuant. These FakeQuant ops were due to
|
||||
// incorrect placement in the quantization aware training.
|
||||
// TODO(b/149735743): We should make the placement automatically.
|
||||
def : Pat<
|
||||
def MatchHardSwishQuantized : Pat<
|
||||
(TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp
|
||||
(TFL_MulOp
|
||||
$x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp
|
||||
@ -238,7 +231,8 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
|
||||
// This pattern constructs L2NormalizationOp from
|
||||
// Mul->Rsqrt->Sum->Square Or
|
||||
// Div->sqrt->Sum->Square
|
||||
def : Pat<(FirstOp $operand1,
|
||||
def L2NormalizePattern1#FirstOp#SecondOp : Pat<
|
||||
(FirstOp $operand1,
|
||||
(SecondOp
|
||||
(TFL_SumOp
|
||||
(TFL_SquareOp:$sq_op $square_operand),
|
||||
@ -251,7 +245,8 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
|
||||
|
||||
// Below patterns for L2Normalize when there is an Add or Maximum
|
||||
// adding or clamping to a small constant scalar.
|
||||
def : Pat<(FirstOp $operand1,
|
||||
def L2NormalizePattern2#FirstOp#SecondOp : Pat<
|
||||
(FirstOp $operand1,
|
||||
(SecondOp
|
||||
(TFL_AddOp
|
||||
(TFL_SumOp
|
||||
@ -265,7 +260,8 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
|
||||
(L2NormValidReduceIndex $sq_op, $axis),
|
||||
(ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
|
||||
|
||||
def : Pat<(FirstOp $operand1,
|
||||
def L2NormalizePattern3#FirstOp#SecondOp : Pat<
|
||||
(FirstOp $operand1,
|
||||
(SecondOp
|
||||
(TFL_MaximumOp
|
||||
(TFL_SumOp
|
||||
@ -302,14 +298,16 @@ def HaveSameType : Constraint<CPred<"$0.getType(), $1.getType()">>;
|
||||
// Pattern for skipping Tile if it is mainly for broadcasting and the
|
||||
// Op is already supporting broadcasting.
|
||||
multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
|
||||
def : Pat<(BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
|
||||
$operand, $act_func),
|
||||
(BinaryOp $input, $operand, $act_func),
|
||||
def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat<
|
||||
(BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
|
||||
$operand, $act_func),
|
||||
(BinaryOp $input, $operand, $act_func),
|
||||
[(OperandsBroadcastToOutputType $input, $operand, $result)]>;
|
||||
|
||||
def : Pat<(BinaryOp:$result $operand,
|
||||
(TFL_TileOp $input, (ConstantOp $tile)), $act_func),
|
||||
(BinaryOp $operand, $input, $act_func),
|
||||
def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
|
||||
(BinaryOp:$result $operand,
|
||||
(TFL_TileOp $input, (ConstantOp $tile)), $act_func),
|
||||
(BinaryOp $operand, $input, $act_func),
|
||||
[(OperandsBroadcastToOutputType $operand, $input, $result)]>;
|
||||
}
|
||||
|
||||
@ -318,9 +316,10 @@ multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
|
||||
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
||||
[TFL_Relu6Op, TFL_AF_Relu6],
|
||||
[TFL_Relu1Op, TFL_AF_Relu1]] in {
|
||||
def : Pat<(actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
|
||||
(BinaryOp $lhs, $rhs, actFnPair[1]),
|
||||
[(HasOneUse $binary_out)]>;
|
||||
def FuseBinaryWithActivation#BinaryOp#actFnPair[0] : Pat<
|
||||
(actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
|
||||
(BinaryOp $lhs, $rhs, actFnPair[1]),
|
||||
[(HasOneUse $binary_out)]>;
|
||||
}
|
||||
}
|
||||
|
||||
@ -340,21 +339,22 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
|
||||
// transformation, the shape of the binary op result is [40x1600], which
|
||||
// couldn't be reshaped to [1,40,40]. `IsTailOfShape` constraint is added to
|
||||
// make sure $rhs is the tail shape of $lhs.
|
||||
def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
|
||||
(ConstantOp:$rhs $a), TFL_AF_None),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
|
||||
// The broadcasting of "BinaryOp" only happens in the lower
|
||||
// dimensions, and the higher dimensions are same, so we know the
|
||||
// result and input of the "BinaryOp" in the source pattern have
|
||||
// the same shape, which is defined by `shape`.
|
||||
[(IsTailOfShape $rhs, $lhs),
|
||||
(HasOneUse $lhs),
|
||||
// The result of the new "BinaryOp" will have the same shape as
|
||||
// `input`. In other words, the shape of the `Reshape` op are not
|
||||
// changed after the transformation.
|
||||
(IsTailOfShape $rhs, $input),
|
||||
(HasRankAtMost<5> $input),
|
||||
(HasRankAtMost<5> $rhs)]>;
|
||||
def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
|
||||
(ConstantOp:$rhs $a), TFL_AF_None),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
|
||||
// The broadcasting of "BinaryOp" only happens in the lower
|
||||
// dimensions, and the higher dimensions are same, so we know the
|
||||
// result and input of the "BinaryOp" in the source pattern have
|
||||
// the same shape, which is defined by `shape`.
|
||||
[(IsTailOfShape $rhs, $lhs),
|
||||
(HasOneUse $lhs),
|
||||
// The result of the new "BinaryOp" will have the same shape as
|
||||
// `input`. In other words, the shape of the `Reshape` op are not
|
||||
// changed after the transformation.
|
||||
(IsTailOfShape $rhs, $input),
|
||||
(HasRankAtMost<5> $input),
|
||||
(HasRankAtMost<5> $rhs)]>;
|
||||
}
|
||||
|
||||
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
@ -370,19 +370,20 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
// transformation, the shape of the binary op result is [40x1600], which
|
||||
// couldn't be reshaped to [1,40,40]. `IsTailOfShape` constraint is added to
|
||||
// make sure $rhs is the tail shape of $lhs.
|
||||
def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
|
||||
(ConstantOp:$rhs $a)),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
|
||||
// The broadcasting of "BinaryOp" only happens in the lower
|
||||
// dimensions, and the higher dimensions are same, so we know the
|
||||
// result and input of the "BinaryOp" in the source pattern have
|
||||
// the same shape, which is defined by `shape`.
|
||||
[(IsTailOfShape $rhs, $lhs),
|
||||
(HasOneUse $lhs),
|
||||
// The result of the new "BinaryOp" will have the same shape as
|
||||
// `input`. In other words, the shape of the `Reshape` op are not
|
||||
// changed after the transformation.
|
||||
(IsTailOfShape $rhs, $input)]>;
|
||||
def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
|
||||
(ConstantOp:$rhs $a)),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
|
||||
// The broadcasting of "BinaryOp" only happens in the lower
|
||||
// dimensions, and the higher dimensions are same, so we know the
|
||||
// result and input of the "BinaryOp" in the source pattern have
|
||||
// the same shape, which is defined by `shape`.
|
||||
[(IsTailOfShape $rhs, $lhs),
|
||||
(HasOneUse $lhs),
|
||||
// The result of the new "BinaryOp" will have the same shape as
|
||||
// `input`. In other words, the shape of the `Reshape` op are not
|
||||
// changed after the transformation.
|
||||
(IsTailOfShape $rhs, $input)]>;
|
||||
}
|
||||
|
||||
// Reorder the element-wise value operations and the element move operations,
|
||||
@ -392,9 +393,10 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
|
||||
TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp] in {
|
||||
foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
|
||||
TFL_ReshapeOp, TFL_TransposeOp] in {
|
||||
def : Pat<(ValueOp:$value (MoveOp:$move $input, $move_def)),
|
||||
(MoveOp (ValueOp $input), $move_def),
|
||||
[(HasOneUse $move)]>;
|
||||
def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat<
|
||||
(ValueOp:$value (MoveOp:$move $input, $move_def)),
|
||||
(MoveOp (ValueOp $input), $move_def),
|
||||
[(HasOneUse $move)]>;
|
||||
}
|
||||
}
|
||||
|
||||
@ -403,16 +405,16 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
|
||||
def GetShape: NativeCodeCall<"GetShape($0)">;
|
||||
|
||||
// Convert squeeze to reshape if possible.
|
||||
def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
|
||||
(TFL_ReshapeOp $input,
|
||||
(ConstantOp (GetShape $squeeze_op))),
|
||||
[(AnyStaticShapeTensor $squeeze_op)]>;
|
||||
def ConvertSqueezeToReshape : Pat<
|
||||
(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
|
||||
(TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))),
|
||||
[(AnyStaticShapeTensor $squeeze_op)]>;
|
||||
|
||||
// Convert expand_dims to reshape if possible.
|
||||
def : Pat<(TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
|
||||
(TFL_ReshapeOp $input,
|
||||
(ConstantOp (GetShape $expand_dims_op))),
|
||||
[(AnyStaticShapeTensor $expand_dims_op)]>;
|
||||
def ConvertExpandDimsToReshape : Pat<
|
||||
(TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
|
||||
(TFL_ReshapeOp $input, (ConstantOp (GetShape $expand_dims_op))),
|
||||
[(AnyStaticShapeTensor $expand_dims_op)]>;
|
||||
|
||||
class FloatValueEquals<string val> : Constraint<CPred<
|
||||
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
|
||||
@ -420,25 +422,27 @@ class FloatValueEquals<string val> : Constraint<CPred<
|
||||
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
|
||||
|
||||
// ReLU patterns
|
||||
def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
|
||||
(ConstantOp $NegOne)),
|
||||
(ConstantOp $One)),
|
||||
(TFL_Relu1Op $input),
|
||||
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
|
||||
def MatchRelu1Pattern1 : Pat<
|
||||
(TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $NegOne)),
|
||||
(ConstantOp $One)),
|
||||
(TFL_Relu1Op $input),
|
||||
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
|
||||
|
||||
def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
|
||||
(ConstantOp $One)),
|
||||
(ConstantOp $NegOne)),
|
||||
(TFL_Relu1Op $input),
|
||||
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
|
||||
def MatchRelu1Pattern2 : Pat<
|
||||
(TFL_MaximumOp (TFL_MinimumOp $input, (ConstantOp $One)),
|
||||
(ConstantOp $NegOne)),
|
||||
(TFL_Relu1Op $input),
|
||||
[(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>;
|
||||
|
||||
def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
|
||||
(ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
|
||||
$input2),
|
||||
(TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
|
||||
[(ConstDoubleValueLessThan<"1"> $alpha),
|
||||
(EqualOperands $input1, $input2),
|
||||
(HasOneUse $mul_out)]>;
|
||||
def MatchLeakyRelu : Pat<
|
||||
(TFL_MaximumOp
|
||||
(TFL_MulOp:$mul_out $input1,
|
||||
(ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
|
||||
$input2),
|
||||
(TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
|
||||
[(ConstDoubleValueLessThan<"1"> $alpha),
|
||||
(EqualOperands $input1, $input2),
|
||||
(HasOneUse $mul_out)]>;
|
||||
|
||||
def RemoveTrivialCast : Pat<(TFL_CastOp:$output $input),
|
||||
(replaceWithValue $input),
|
||||
@ -451,23 +455,25 @@ def PReluAlphaRankCheck : Constraint<
|
||||
|
||||
// PReLU pattern from Keras:
|
||||
// f(x) = Relu(x) + (-alpha * Relu(-x))
|
||||
def : Pat<(TFL_AddOp
|
||||
(TFL_ReluOp:$relu_out $input1),
|
||||
(TFL_MulOp:$mul_out
|
||||
(TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
|
||||
$neg_alpha,
|
||||
TFL_AF_None),
|
||||
TFL_AF_None),
|
||||
(TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
|
||||
[(EqualOperands $input1, $input2),
|
||||
(PReluAlphaRankCheck $neg_alpha, $input1),
|
||||
(HasOneUse $relu_out),
|
||||
(HasOneUse $mul_out),
|
||||
(HasOneUse $input_neg_out)]>;
|
||||
def MatchPRelu : Pat<
|
||||
(TFL_AddOp
|
||||
(TFL_ReluOp:$relu_out $input1),
|
||||
(TFL_MulOp:$mul_out
|
||||
(TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
|
||||
$neg_alpha,
|
||||
TFL_AF_None),
|
||||
TFL_AF_None),
|
||||
(TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
|
||||
[(EqualOperands $input1, $input2),
|
||||
(PReluAlphaRankCheck $neg_alpha, $input1),
|
||||
(HasOneUse $relu_out),
|
||||
(HasOneUse $mul_out),
|
||||
(HasOneUse $input_neg_out)]>;
|
||||
|
||||
// The constant folding in this pass might produce constant in the tf dialect.
|
||||
// This rule is to legalize these constant to the tfl dialect.
|
||||
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
|
||||
def LegalizeConstOp : Pat<
|
||||
(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
|
||||
|
||||
// Reorders adds to allow constant folding.
|
||||
// Add --> Add $input, $constantA
|
||||
@ -476,13 +482,14 @@ def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
|
||||
// Add --> $input
|
||||
// \--> Add ($constantA, $constantB)
|
||||
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
|
||||
def : Pat<(TFL_AddOp
|
||||
(TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
|
||||
(ConstantOp $b), ActFun),
|
||||
(TFL_AddOp $input,
|
||||
(TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
|
||||
ActFun),
|
||||
[(HasOneUse $first_output)]>;
|
||||
def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat<
|
||||
(TFL_AddOp
|
||||
(TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
|
||||
(ConstantOp $b), ActFun),
|
||||
(TFL_AddOp $input,
|
||||
(TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
|
||||
ActFun),
|
||||
[(HasOneUse $first_output)]>;
|
||||
}
|
||||
|
||||
// We can eliminate Relu from Relu(SquaredDifference(x, y)),
|
||||
|
@ -143,8 +143,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
type = FunctionType::get(types, result_types, &getContext());
|
||||
}
|
||||
|
||||
auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type,
|
||||
ArrayRef<NamedAttribute>{});
|
||||
auto outlined_func = builder.create<FuncOp>(while_op.getLoc(), name, type);
|
||||
outlined_func.getBody().takeBody(region);
|
||||
Region& func_region = outlined_func.getBody();
|
||||
|
||||
|
@ -301,8 +301,8 @@ bool IslandOp::WrapsSingleOp() {
|
||||
namespace {
|
||||
|
||||
LogicalResult Verify(IslandOp island) {
|
||||
if (island.GetBody().empty())
|
||||
return island.emitOpError() << "expects a non-empty body";
|
||||
if (!island.GetBody().args_empty())
|
||||
return island.emitOpError() << "expects body without any arguments";
|
||||
|
||||
Operation &yield = island.GetBody().back();
|
||||
if (!isa<YieldOp>(yield))
|
||||
|
@ -232,6 +232,20 @@ func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
|
||||
|
||||
// -----
|
||||
|
||||
// Check that an island body doesn't have any block arguments.
|
||||
func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
|
||||
tf_executor.graph {
|
||||
"tf_executor.island"() ({
|
||||
// expected-error@-1 {{expects body without any arguments}}
|
||||
^entry(%arg: tensor<2xi32>):
|
||||
tf_executor.yield
|
||||
}) : () -> (!tf_executor.control)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that an island body can't be empty.
|
||||
func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) {
|
||||
tf_executor.graph {
|
||||
|
@ -9,16 +9,15 @@
|
||||
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
|
||||
func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
|
||||
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>) -> () {
|
||||
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
|
||||
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
|
||||
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_7]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
// CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {_xla_outside_compilation = "0", value = dense<"inference"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
%2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
|
||||
return
|
||||
}
|
||||
@ -34,20 +33,19 @@ func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x
|
||||
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
|
||||
func @check_enqueue_ops_update_for_training(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
|
||||
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>) -> () {
|
||||
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
|
||||
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
|
||||
|
||||
%2 = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
|
||||
%3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
"tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> ()
|
||||
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_6]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
// CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {_xla_outside_compilation = "0", value = dense<"train"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
%4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
|
||||
return
|
||||
}
|
||||
|
@ -77,8 +77,7 @@ Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder,
|
||||
ArrayRef<Type>{RankedTensorType::get(
|
||||
{static_cast<int64_t>(buffer_type.getShape().size())},
|
||||
getElementTypeOrSelf(index.getType()))},
|
||||
ArrayRef<Value>{index, zeros_tensor, CreateScalarConst(0, builder, loc)},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{index, zeros_tensor, CreateScalarConst(0, builder, loc)});
|
||||
}
|
||||
|
||||
Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc,
|
||||
@ -95,15 +94,14 @@ Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc,
|
||||
auto slice = builder.create<TF::SliceOp>(
|
||||
loc, ArrayRef<Type>{slice_type},
|
||||
ArrayRef<Value>{buffer, GetIndicesForElement(index, buffer, builder, loc),
|
||||
size_const},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
size_const});
|
||||
if (keep_slice_shape) return slice;
|
||||
auto element_type = RankedTensorType::get(buffer_type.getShape().drop_front(),
|
||||
buffer_type.getElementType());
|
||||
auto reshape = builder.create<TF::ReshapeOp>(
|
||||
loc, ArrayRef<Type>{element_type},
|
||||
ArrayRef<Value>{slice, GetR1Const(element_type.getShape(), builder, loc)},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{slice,
|
||||
GetR1Const(element_type.getShape(), builder, loc)});
|
||||
return reshape.output();
|
||||
}
|
||||
|
||||
@ -120,15 +118,13 @@ Value SetElement(Value index, Value buffer, Value element, OpBuilder builder,
|
||||
if (element.getType() != slice_type) {
|
||||
update_slice = builder.create<TF::ReshapeOp>(
|
||||
loc, ArrayRef<Type>{slice_type},
|
||||
ArrayRef<Value>{element, GetR1Const(slice_shape, builder, loc)},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{element, GetR1Const(slice_shape, builder, loc)});
|
||||
}
|
||||
return builder
|
||||
.create<TF::XlaDynamicUpdateSliceOp>(
|
||||
loc, ArrayRef<Type>{buffer.getType()},
|
||||
ArrayRef<Value>{buffer, update_slice,
|
||||
GetIndicesForElement(index, buffer, builder, loc)},
|
||||
ArrayRef<NamedAttribute>{})
|
||||
GetIndicesForElement(index, buffer, builder, loc)})
|
||||
.output();
|
||||
}
|
||||
|
||||
@ -140,8 +136,7 @@ Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc) {
|
||||
auto size_type = GetSizeType(builder);
|
||||
return builder.create<TF::ReshapeOp>(
|
||||
loc, ArrayRef<Type>{size_type},
|
||||
ArrayRef<Value>{scalar, GetR1Const(size_type.getShape(), builder, loc)},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{scalar, GetR1Const(size_type.getShape(), builder, loc)});
|
||||
}
|
||||
|
||||
LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
|
||||
@ -171,13 +166,12 @@ LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
|
||||
if (getElementTypeOrSelf(zero.getType()) != element_dtype) {
|
||||
zero = builder.create<TF::CastOp>(
|
||||
op->getLoc(), ArrayRef<Type>{RankedTensorType::get({}, element_dtype)},
|
||||
ArrayRef<Value>{zero}, ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{zero});
|
||||
}
|
||||
auto buffer_type = RankedTensorType::get(buffer_shape, element_dtype);
|
||||
auto broadcast = builder.create<TF::BroadcastToOp>(
|
||||
op->getLoc(), ArrayRef<Type>{buffer_type},
|
||||
ArrayRef<Value>{zero, GetR1Const(buffer_shape, builder, op->getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{zero, GetR1Const(buffer_shape, builder, op->getLoc())});
|
||||
*buffer = broadcast.output();
|
||||
return success();
|
||||
}
|
||||
@ -241,27 +235,24 @@ Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) {
|
||||
ArrayRef<Type>{getElementTypeOrSelf(local_var.getType())
|
||||
.cast<TF::ResourceType>()
|
||||
.getSubtypes()[0]},
|
||||
ArrayRef<Value>{local_var}, ArrayRef<NamedAttribute>{})
|
||||
ArrayRef<Value>{local_var})
|
||||
.value();
|
||||
}
|
||||
|
||||
// Creates an AssignVariableOp on a local variable.
|
||||
TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value,
|
||||
OpBuilder builder, Location loc) {
|
||||
return builder.create<TF::AssignVariableOp>(loc, ArrayRef<Type>{},
|
||||
ArrayRef<Value>{local_var, value},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
return builder.create<TF::AssignVariableOp>(
|
||||
loc, ArrayRef<Type>{}, ArrayRef<Value>{local_var, value});
|
||||
}
|
||||
|
||||
Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc) {
|
||||
if (getElementTypeOrSelf(a.getType()) == builder.getI1Type()) {
|
||||
return builder.create<TF::LogicalOrOp>(loc, ArrayRef<Type>{a.getType()},
|
||||
ArrayRef<Value>{a, b},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{a, b});
|
||||
}
|
||||
return builder.create<TF::AddV2Op>(loc, ArrayRef<Type>{a.getType()},
|
||||
ArrayRef<Value>{a, b},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{a, b});
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -303,15 +294,13 @@ Value GatherElements(Value indices, Value buffer, OpBuilder builder,
|
||||
return builder.create<TF::SliceOp>(
|
||||
loc, ArrayRef<Type>{slice_type},
|
||||
ArrayRef<Value>{buffer, GetR1Const(slice_starts, builder, loc),
|
||||
GetR1Const(result_shape, builder, loc)},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
GetR1Const(result_shape, builder, loc)});
|
||||
}
|
||||
auto result_type =
|
||||
RankedTensorType::get(result_shape, buffer_type.getElementType());
|
||||
return builder.create<TF::GatherV2Op>(
|
||||
loc, ArrayRef<Type>{result_type},
|
||||
ArrayRef<Value>{buffer, indices, CreateScalarConst(0, builder, loc)},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{buffer, indices, CreateScalarConst(0, builder, loc)});
|
||||
}
|
||||
|
||||
Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
|
||||
@ -334,8 +323,7 @@ Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
|
||||
auto index = builder.create<TF::SliceOp>(
|
||||
loc, ArrayRef<Type>{GetSizeType(builder)},
|
||||
ArrayRef<Value>{indices, GetR1Const({i}, builder, loc),
|
||||
GetR1Const({1}, builder, loc)},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
GetR1Const({1}, builder, loc)});
|
||||
auto old_slice =
|
||||
GetElement(index, buffer, builder, loc, /*keep_slice_shape=*/true);
|
||||
starts_in_update[0] = i;
|
||||
@ -344,8 +332,7 @@ Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
|
||||
builder
|
||||
.create<TF::SliceOp>(
|
||||
loc, ArrayRef<Type>{old_slice.getType()},
|
||||
ArrayRef<Value>{updates, update_slice_starts, slice_sizes},
|
||||
ArrayRef<NamedAttribute>{})
|
||||
ArrayRef<Value>{updates, update_slice_starts, slice_sizes})
|
||||
.output();
|
||||
slice = AccumulateBuffers(old_slice, slice, builder, loc);
|
||||
buffer = SetElement(index, buffer, slice, builder, loc);
|
||||
|
@ -185,8 +185,8 @@ IslandOp CreateNewIsland(IslandOp parent, IslandOp child,
|
||||
|
||||
Operation* old_island = insert_position == kParentIsland ? parent : child;
|
||||
OpBuilder builder(old_island);
|
||||
auto new_island = builder.create<IslandOp>(
|
||||
old_island->getLoc(), result_types, operands, ArrayRef<NamedAttribute>{});
|
||||
auto new_island =
|
||||
builder.create<IslandOp>(old_island->getLoc(), result_types, operands);
|
||||
new_island.body().push_back(new Block);
|
||||
return new_island;
|
||||
}
|
||||
|
@ -105,8 +105,8 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
||||
// Create the outlined function
|
||||
SmallString<32> name = kOutlinedFuncPrefix;
|
||||
name += llvm::Twine(prefix_id++).str();
|
||||
auto outlined_func = OpBuilder(ctx).create<FuncOp>(
|
||||
island_op.getLoc(), name, func_type, ArrayRef<NamedAttribute>());
|
||||
auto outlined_func =
|
||||
OpBuilder(ctx).create<FuncOp>(island_op.getLoc(), name, func_type);
|
||||
outlined_symbol_table.insert(outlined_func);
|
||||
|
||||
// We will "steal" the body of the island and replace it with a call to the
|
||||
|
@ -190,7 +190,7 @@ tf_executor::IslandOp CreateOutputBarrierIsland(
|
||||
builder->setInsertionPoint(island_op);
|
||||
auto island_output_sink = builder->create<tf_executor::IslandOp>(
|
||||
island_op.getLoc(), llvm::to_vector<8>(island_op.getResultTypes()),
|
||||
island_operands, llvm::ArrayRef<NamedAttribute>{});
|
||||
island_operands);
|
||||
island_output_sink.body().push_back(new Block);
|
||||
return island_output_sink;
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
|
||||
builder.setInsertionPoint(user);
|
||||
ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
|
||||
user->getLoc(), ArrayRef<Type>{tensor_type},
|
||||
ArrayRef<Value>{var_handle_op}, ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{var_handle_op});
|
||||
user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
|
||||
user->erase();
|
||||
}
|
||||
|
@ -124,8 +124,7 @@ void ExtractSingleBlockRegion(Region& region, StringRef name,
|
||||
auto type = FunctionType::get(input_types, return_types, region.getContext());
|
||||
|
||||
// Create new function and extract region body into the function.
|
||||
auto outlined_func =
|
||||
builder.create<FuncOp>(loc, name, type, ArrayRef<NamedAttribute>{});
|
||||
auto outlined_func = builder.create<FuncOp>(loc, name, type);
|
||||
Region& func_region = outlined_func.getBody();
|
||||
func_region.takeBody(region);
|
||||
Block& first_block = func_region.front();
|
||||
|
@ -558,15 +558,13 @@ void AddLoadsStoresOutsideControlFlowOp(
|
||||
auto operand = caller->getOperand(index);
|
||||
builder.setInsertionPoint(caller);
|
||||
new_operands[index] = builder.create<TF::ReadVariableOp>(
|
||||
caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand});
|
||||
caller->setOperand(index, new_operands[index]);
|
||||
if (updated_index < 0) continue;
|
||||
builder.setInsertionPointAfter(caller);
|
||||
builder.create<TF::AssignVariableOp>(
|
||||
caller->getLoc(), ArrayRef<Type>{},
|
||||
ArrayRef<Value>{operand, caller->getResult(updated_index)},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{operand, caller->getResult(updated_index)});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -409,11 +409,9 @@ LogicalResult HandleStackV2Op(
|
||||
ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
|
||||
stack.getContext()));
|
||||
auto local_var = builder.create<TF::MlirLocalVarOp>(
|
||||
stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
|
||||
auto local_size_var = builder.create<TF::MlirLocalVarOp>(
|
||||
stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{});
|
||||
// Zero-initialize the local vars.
|
||||
cutil::WriteLocalVariable(local_size_var,
|
||||
cutil::GetR1Const({0LL}, builder, stack.getLoc()),
|
||||
@ -446,8 +444,7 @@ LogicalResult HandleStackPushV2Op(
|
||||
cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
|
||||
index = builder.create<TF::AddV2Op>(
|
||||
push.getLoc(), ArrayRef<Type>{index.getType()},
|
||||
ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())});
|
||||
cutil::WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
|
||||
push.erase();
|
||||
return success();
|
||||
@ -467,8 +464,7 @@ LogicalResult HandleStackPopV2Op(
|
||||
auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
|
||||
auto new_size = builder.create<TF::SubOp>(
|
||||
pop.getLoc(), ArrayRef<Type>{size.getType()},
|
||||
ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())});
|
||||
auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc());
|
||||
pop.replaceAllUsesWith(pop_val);
|
||||
// Update the size.
|
||||
|
@ -166,8 +166,7 @@ LogicalResult HandleTensorArrayV3Op(
|
||||
ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
|
||||
ta.getContext()));
|
||||
auto local_var = builder.create<TF::MlirLocalVarOp>(
|
||||
ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
|
||||
cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc());
|
||||
ta.handle().replaceAllUsesWith(local_var);
|
||||
// The flow output is just a way for the front end to enforce ordering among
|
||||
@ -227,8 +226,7 @@ LogicalResult HandleTensorArrayWriteV3Op(
|
||||
elem = builder.create<TF::ReshapeOp>(
|
||||
write.getLoc(), ArrayRef<Type>{slice_type},
|
||||
ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
|
||||
write.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
write.getLoc())});
|
||||
elem =
|
||||
cutil::AccumulateBuffers(elem, original_elem, builder, write.getLoc());
|
||||
}
|
||||
@ -261,8 +259,7 @@ LogicalResult HandleTensorArrayConcatV3Op(
|
||||
ArrayRef<Type>{
|
||||
RankedTensorType::get(shape, buffer_type.getElementType())},
|
||||
ArrayRef<Value>{buffer,
|
||||
cutil::GetR1Const(shape, builder, concat.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
cutil::GetR1Const(shape, builder, concat.getLoc())});
|
||||
concat.value().replaceAllUsesWith(buffer);
|
||||
|
||||
// Create the lengths as a list of the same value (element size).
|
||||
@ -302,8 +299,7 @@ LogicalResult HandleTensorArraySplitV3Op(
|
||||
buffer_shape, elem_type.getElementType())},
|
||||
ArrayRef<Value>{split.value(),
|
||||
cutil::GetR1Const(buffer_shape, builder,
|
||||
split.getLoc())},
|
||||
ArrayRef<NamedAttribute>{})
|
||||
split.getLoc())})
|
||||
.output();
|
||||
// Accumulate with the old buffer.
|
||||
auto old_buffer =
|
||||
@ -339,8 +335,7 @@ LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
|
||||
Operation* op, Value* var) {
|
||||
OpBuilder builder(op);
|
||||
*var = builder.create<TF::MlirLocalVarOp>(
|
||||
op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{});
|
||||
Value buffer;
|
||||
auto buffer_type = getElementTypeOrSelf(local_var_type)
|
||||
.cast<TF::ResourceType>()
|
||||
|
@ -438,7 +438,7 @@ LogicalResult HandleTensorListFromTensorOp(
|
||||
OpBuilder builder(list);
|
||||
Value buffer = builder.create<TF::IdentityOp>(
|
||||
list.getLoc(), ArrayRef<Type>{list.tensor().getType()},
|
||||
ArrayRef<Value>{list.tensor()}, ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{list.tensor()});
|
||||
auto type = buffer.getType().cast<TensorType>();
|
||||
if (!type.hasStaticShape()) {
|
||||
return list.emitOpError("TensorListFromTensorOp input has unknown shape.");
|
||||
@ -468,8 +468,7 @@ LogicalResult HandleTensorListPushBackOp(
|
||||
cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc());
|
||||
auto new_size = builder.create<TF::AddV2Op>(
|
||||
push.getLoc(), ArrayRef<Type>{size.getType()},
|
||||
ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())});
|
||||
push.output_handle().replaceAllUsesWith(new_buffer);
|
||||
(*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
|
||||
push.erase();
|
||||
@ -491,12 +490,10 @@ LogicalResult HandleTensorListPopBackOp(
|
||||
auto size = it->getSecond().size;
|
||||
OpBuilder builder(pop);
|
||||
auto new_buffer = builder.create<TF::IdentityOp>(
|
||||
pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer});
|
||||
auto new_size = builder.create<TF::SubOp>(
|
||||
pop.getLoc(), ArrayRef<Type>{size.getType()},
|
||||
ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())});
|
||||
auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc());
|
||||
pop.output_handle().replaceAllUsesWith(new_buffer);
|
||||
pop.tensor().replaceAllUsesWith(element);
|
||||
@ -567,8 +564,7 @@ LogicalResult HandleTensorListLengthOp(
|
||||
ArrayRef<Type>{RankedTensorType::get(
|
||||
{}, getElementTypeOrSelf(current_size.getType()))},
|
||||
ArrayRef<Value>{current_size,
|
||||
cutil::GetR1Const({}, builder, length.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
cutil::GetR1Const({}, builder, length.getLoc())});
|
||||
length.length().replaceAllUsesWith(reshape);
|
||||
}
|
||||
length.erase();
|
||||
|
@ -154,8 +154,7 @@ TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch,
|
||||
Value input, OpBuilder* builder) {
|
||||
return builder->create<TF::TPUCopyWithLayoutOp>(
|
||||
execute_launch.getLoc(), llvm::ArrayRef<Type>{input.getType()},
|
||||
llvm::ArrayRef<Value>{input, get_layout.layout()},
|
||||
llvm::ArrayRef<NamedAttribute>{});
|
||||
llvm::ArrayRef<Value>{input, get_layout.layout()});
|
||||
}
|
||||
|
||||
// Performs transformation for a non-replicated input.
|
||||
|
@ -206,8 +206,7 @@ TF::_HostComputeMlirOp CreateHostCompute(
|
||||
device_output_types.push_back(output.getType());
|
||||
SetHostComputeInsertion(builder, cluster_ops, inputs);
|
||||
auto host_compute = builder->create<TF::_HostComputeMlirOp>(
|
||||
tpu_cluster.getLoc(), device_output_types, inputs.getArrayRef(),
|
||||
llvm::ArrayRef<NamedAttribute>{});
|
||||
tpu_cluster.getLoc(), device_output_types, inputs.getArrayRef());
|
||||
host_compute.setAttr(kAncestorsAttr, builder->getArrayAttr({}));
|
||||
host_compute.setAttr(kShapesAttr, builder->getArrayAttr({}));
|
||||
host_compute.setAttr(kKeyAttr, builder->getStringAttr(communication_key));
|
||||
|
@ -473,9 +473,8 @@ LogicalResult BuildExecuteOp(
|
||||
if (failed(result)) return failure();
|
||||
|
||||
// TPUExecute has same output types as cluster_func.
|
||||
*execute_op = builder->create<TF::TPUExecuteOp>(
|
||||
cluster_func.getLoc(), output_types, inputs,
|
||||
llvm::ArrayRef<NamedAttribute>{});
|
||||
*execute_op = builder->create<TF::TPUExecuteOp>(cluster_func.getLoc(),
|
||||
output_types, inputs);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -13,24 +13,29 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
namespace {
|
||||
|
||||
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
|
||||
constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer";
|
||||
|
||||
struct TPUUpdateEmbeddingEnqueueOpInputs
|
||||
@ -86,7 +91,8 @@ LogicalResult FindTPUEmbeddingOps(
|
||||
LogicalResult UpdateEmbeddingEnqueueOpInput(
|
||||
const llvm::StringMap<Operation*>& enqueue_op_map,
|
||||
const llvm::StringMap<Operation*>& recv_activation_op_map,
|
||||
const llvm::StringMap<Operation*>& send_gradient_op_map) {
|
||||
const llvm::StringMap<Operation*>& send_gradient_op_map,
|
||||
OpBuilder* builder) {
|
||||
for (const auto& it : enqueue_op_map) {
|
||||
const auto& embedding_attr = it.getKey();
|
||||
Operation* embedding_op = it.second;
|
||||
@ -96,21 +102,36 @@ LogicalResult UpdateEmbeddingEnqueueOpInput(
|
||||
<< TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";
|
||||
|
||||
// TPU Embedding enqueue ops take different inputs depending on whether
|
||||
// graph is in training mode or in eval/prediction mode. The inputs to the
|
||||
// enqueue ops are present/listed as operands to SelectV2 op. Then branch
|
||||
// operand of the SelectV2 op represents input to take during training
|
||||
// and else branch operand represents input to take during
|
||||
// prediction/evaluation. If SendTPUEmbeddingGradients op exists in the
|
||||
// graph, then graph is in training mode, so correctly forward the input
|
||||
// of SelectV2 op as operand to the TPU embedding enqueue op.
|
||||
// graph is in training mode or in eval/prediction mode. During training,
|
||||
// the mode parameter for TPUEmbeddingEnqueue op must be `train` and for
|
||||
// evaluation or prediction, mode must be set to `inference`.
|
||||
// If SendTPUEmbeddingGradients op exists in the graph, then graph is
|
||||
// in training mode, so create a const op with value `train` and use the
|
||||
// output value of the constant as an operand to the TPU embedding
|
||||
// enqueue op.
|
||||
bool is_training = send_gradient_op_map.count(embedding_attr);
|
||||
for (auto enqueue_operand : embedding_op->getOperands()) {
|
||||
if (auto select = llvm::dyn_cast_or_null<TF::SelectV2Op>(
|
||||
enqueue_operand.getDefiningOp())) {
|
||||
enqueue_operand.replaceAllUsesWith(is_training ? select.t()
|
||||
: select.e());
|
||||
}
|
||||
}
|
||||
|
||||
// The last operand of TPUEmbeddingEnqueue ops is the mode which
|
||||
// represents whether graph is in training mode or in evaluation mode.
|
||||
auto& mode_enqueue_operand =
|
||||
embedding_op->getOpOperand(embedding_op->getNumOperands() - 1);
|
||||
|
||||
llvm::SmallVector<StringRef, 1> mode_string_value;
|
||||
mode_string_value.emplace_back(is_training ? "train" : "inference");
|
||||
builder->setInsertionPoint(embedding_op);
|
||||
auto enqueue_mode = builder->create<TF::ConstOp>(
|
||||
embedding_op->getLoc(),
|
||||
DenseStringElementsAttr::get(
|
||||
RankedTensorType::get({}, builder->getType<TF::StringType>()),
|
||||
mode_string_value));
|
||||
|
||||
auto outside_compilation_attr =
|
||||
embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
|
||||
if (outside_compilation_attr)
|
||||
enqueue_mode.setAttr(kXlaOutsideCompilationAttr,
|
||||
outside_compilation_attr);
|
||||
|
||||
mode_enqueue_operand.set(enqueue_mode);
|
||||
}
|
||||
|
||||
return success();
|
||||
@ -140,8 +161,9 @@ void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() {
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
if (failed(UpdateEmbeddingEnqueueOpInput(
|
||||
enqueue_op_map, recv_activation_op_map, send_gradient_op_map)))
|
||||
if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map,
|
||||
recv_activation_op_map,
|
||||
send_gradient_op_map, &builder)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -521,8 +521,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
replicate.GetNumReplicatedBlockArguments() - 1));
|
||||
builder.setInsertionPoint(execute_launch);
|
||||
auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
|
||||
execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands,
|
||||
llvm::ArrayRef<NamedAttribute>{});
|
||||
execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands);
|
||||
WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
|
||||
execute_launch.device());
|
||||
|
||||
@ -579,8 +578,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
default_state_key.getResult());
|
||||
// Unformat op.
|
||||
auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
|
||||
while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands,
|
||||
llvm::ArrayRef<NamedAttribute>{});
|
||||
while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands);
|
||||
WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
|
||||
execute_launch.device());
|
||||
builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
|
||||
|
@ -21,10 +21,6 @@ package_group(
|
||||
includes = [
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
packages = [
|
||||
# To pass open source testing in the pip Kokoros.
|
||||
"//bazel_pip/tensorflow/compiler/tests/...",
|
||||
],
|
||||
)
|
||||
|
||||
package_group(
|
||||
@ -34,7 +30,6 @@ package_group(
|
||||
],
|
||||
packages = [
|
||||
# To pass open source testing in the pip Kokoros.
|
||||
"//bazel_pip/tensorflow/compiler/tests/...",
|
||||
"//platforms/xla/tests/neural_nets",
|
||||
],
|
||||
)
|
||||
|
@ -311,7 +311,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
if 'GPU' in self.device:
|
||||
# TODO(b/32333178)
|
||||
self.skipTest('Current implementation of RandomStandardNormal kernel '
|
||||
'is very slow on GPU, and has been blacklisted.')
|
||||
'is very slow on GPU, and has been denylisted.')
|
||||
with self.test_scope():
|
||||
data_format = 'channels_last'
|
||||
conv = convolutional.Conv2D(
|
||||
|
@ -4946,7 +4946,18 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
                                   node_def.name());
  }
  nvinfer1::ITensor* tensor = inputs.at(0).tensor();

  if (!params->use_implicit_batch && tensor->getDimensions().d[1] == -1) {
    // This check is to make sure that channel dimension is known during
    // conversion.
    //
    // We check this only in explicit batch mode and reject an op with unknown
    // channel dimension during segmentation. In implicit batch mode we have
    // known shapes during conversion even though the shapes may not be known
    // during segmentation (see the actual argument for input_shapes when
    // ConvertGraphDefToEngine is called from TRTEngineOp::BuildEngine).
    return errors::InvalidArgument("Channel dimension must be static, at ",
                                   node_def.name());
  }
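The new guard above is easy to restate outside the converter: in NCHW layout the channel dimension is index 1, and TensorRT marks an unknown (dynamic) dimension with -1. A minimal standalone sketch of that validation in plain C++, with no TensorRT or TF-TRT types; the function name and harness are illustrative only:

```cpp
#include <array>
#include <iostream>
#include <string>

// Returns an error message if the channel dimension (index 1 in NCHW) is
// dynamic, mirroring the "Channel dimension must be static" check above.
// -1 denotes an unknown dimension, as in TensorRT's nvinfer1::Dims.
std::string CheckChannelDimStatic(const std::array<int, 4>& nchw_dims,
                                  bool use_implicit_batch) {
  if (!use_implicit_batch && nchw_dims[1] == -1) {
    return "Channel dimension must be static";
  }
  return "";  // accepted
}

int main() {
  std::array<int, 4> dynamic_channel = {-1, -1, 28, 28};  // N and C unknown
  std::array<int, 4> static_channel = {-1, 3, 28, 28};    // only N unknown
  std::cout << CheckChannelDimStatic(dynamic_channel, /*use_implicit_batch=*/false)
            << "\n";  // prints the error message
  std::cout << CheckChannelDimStatic(static_channel, /*use_implicit_batch=*/false)
            << "\n";  // prints an empty line (accepted)
  return 0;
}
```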
  // Check parameter types
  auto parameter_type = inputs.at(1).weights().TrtDType();
  if ((parameter_type != nvinfer1::DataType::kFLOAT) &&
|
||||
|
@ -2011,6 +2011,142 @@ TEST_F(OpConverterTest, ConvertConst) {
|
||||
TestConvertConst<DT_UINT64, uint64, int32>(this);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
NodeDef CreateFusedBatchNormOp(DataType tf_type, std::string data_format,
|
||||
bool is_training, float epsilon) {
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto x = ops::Placeholder(s.WithOpName("x"), tf_type);
|
||||
auto scale = ops::Placeholder(s.WithOpName("scale"), tf_type);
|
||||
auto offset = ops::Placeholder(s.WithOpName("offset"), tf_type);
|
||||
auto mean = ops::Placeholder(s.WithOpName("mean"), tf_type);
|
||||
auto variance = ops::Placeholder(s.WithOpName("variance"), tf_type);
|
||||
typename T::Attrs attrs;
|
||||
attrs.data_format_ = data_format;
|
||||
attrs.is_training_ = is_training;
|
||||
if (epsilon > 0) {
|
||||
attrs.epsilon_ = epsilon;
|
||||
} else {
|
||||
EXPECT_GE(epsilon, 0);
|
||||
}
|
||||
return T(s.WithOpName("my_batchnorm"), x, scale, offset, mean, variance,
|
||||
attrs)
|
||||
.operation.node()
|
||||
->def();
|
||||
}
|
||||
|
||||
TEST_P(OpConverterTest1, ConvertFusedBatchNorm) {
|
||||
using OpFunc = std::function<NodeDef(DataType, std::string, bool, float)>;
|
||||
std::vector<OpFunc> get_node_def_vec{
|
||||
CreateFusedBatchNormOp<ops::FusedBatchNorm>,
|
||||
CreateFusedBatchNormOp<ops::FusedBatchNormV2>,
|
||||
CreateFusedBatchNormOp<ops::FusedBatchNormV3>};
|
||||
|
||||
struct TestParam {
|
||||
std::string data_format;
|
||||
int tensor_input_idx; // Index of an input that will be provided as tensor.
|
||||
bool is_training;
|
||||
float epsilon;
|
||||
Status conversion_status;
|
||||
bool keep_channel_unknown;
|
||||
};
|
||||
|
||||
struct NodeInput {
|
||||
std::string name;
|
||||
std::vector<int> dims;
|
||||
std::vector<float> val;
|
||||
};
|
||||
std::vector<NodeInput> node_input{
|
||||
{"x", {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}},
|
||||
{"scale", {3}, {7, 8, 9}},
|
||||
{"offset", {3}, {10, 20, 30}},
|
||||
{"mean", {3}, {1, 2, 3}},
|
||||
{"variance", {3}, {4, 5, 6}}};
|
||||
|
||||
std::vector<float> expected_output{10.0, 13.495633, 23.574135, 27.148273,
|
||||
37.342354, 41.013527, 30.9738, 34.469433,
|
||||
45.018955, 48.59309, 59.369415, 63.04059};
|
||||
for (auto get_node_def : get_node_def_vec) {
|
||||
NodeDef tmp_node_def = get_node_def(tf_type, "NCHW", true, 0);
|
||||
std::string op_name = tmp_node_def.op();
|
||||
std::vector<TestParam> test_param{
|
||||
{"NHWC", 0, false, 0,
|
||||
errors::Unimplemented(StrCat(
|
||||
op_name, " only supports data_format=NCHW, at my_batchnorm"))},
|
||||
{"NCHW", 0, true, 0,
|
||||
errors::Unimplemented(StrCat(
|
||||
op_name, " only supports is_training=false, at my_batchnorm"))},
|
||||
{"NCHW", 1, false, 0,
|
||||
errors::Unimplemented(StrCat("The input \"scale\" for ", op_name,
|
||||
" must be a constant, at my_batchnorm"))},
|
||||
{"NCHW", 2, false, 0,
|
||||
errors::Unimplemented(StrCat("The input \"offset\" for ", op_name,
|
||||
" must be a constant, at my_batchnorm"))},
|
||||
{"NCHW", 3, false, 0,
|
||||
errors::Unimplemented(StrCat("The input \"mean\" for ", op_name,
|
||||
" must be a constant, at my_batchnorm"))},
|
||||
{"NCHW", 4, false, 0,
|
||||
errors::Unimplemented(StrCat("The input \"variance\" for ", op_name,
|
||||
" must be a constant, at my_batchnorm"))},
|
||||
{"NCHW", 0, false, 0.01}}; // The last one is the only test that runs.
|
||||
if (trt_mode == TrtTestMode::kDynamicShape) {
|
||||
test_param.push_back(
|
||||
{"NCHW", 0, false, 0.01,
|
||||
errors::InvalidArgument(
|
||||
"Channel dimension must be static, at my_batchnorm"),
|
||||
true});
|
||||
}
|
||||
for (auto p : test_param) {
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
get_node_def(tf_type, p.data_format, p.is_training, p.epsilon);
|
||||
for (int i = 0; i < node_input.size(); i++) {
|
||||
if (i == 0 || i == p.tensor_input_idx) {
|
||||
// The first input (x) is always added as a tensor, and it has shape
|
||||
// NCHW. The other inputs are per channel values (1D, size C).
|
||||
//
|
||||
// In implicit batch mode, it is not possible to add any of the 1D
|
||||
// inputs as a tensor: the first dim is always treated as batch dim in
|
||||
// implicit batch mode, and that has to agree for all tensors. We have
|
||||
// two input tensors with shapes NCHW and C and in general N != C.
|
||||
// The converter already picked up N from the first input, and reports
|
||||
// an error when we try to add any other tensors with not matching
|
||||
// first dim.
|
||||
//
|
||||
// This restriction does not apply in explicit batch mode: the tensors
|
||||
// can have different first dim. The converter still expects that only
|
||||
// the first arg is a tensor. TODO(tfeher) Check if one can relax this
|
||||
// restriction.
|
||||
Status expected_status =
|
||||
(i != 0 && trt_mode == TrtTestMode::kImplicitBatch)
|
||||
? errors::InvalidArgument(
|
||||
StrCat("Batch size doesn't match for tensor ",
|
||||
node_input[i].name,
|
||||
": Provided batch size does not match "
|
||||
"converter batch size: 3 vs 2"))
|
||||
: Status::OK();
|
||||
std::vector<int> partial_input_shape;
|
||||
if (i == 0 && trt_mode == TrtTestMode::kDynamicShape &&
|
||||
!p.keep_channel_unknown) {
|
||||
// keep channel dim static (known)
|
||||
partial_input_shape.resize(4, -1);
|
||||
partial_input_shape[1] = node_input[i].dims[1];
|
||||
}
|
||||
AddTestTensor(node_input[i].name, node_input[i].dims, tf_type,
|
||||
node_input[i].val, partial_input_shape,
|
||||
expected_status);
|
||||
|
||||
} else {
|
||||
AddTestWeights(node_input[i].name, node_input[i].dims,
|
||||
node_input[i].val, tf_type);
|
||||
}
|
||||
}
|
||||
TestOpConverter("my_batchnorm", node_def, node_input[0].dims,
|
||||
p.conversion_status, Status::OK(),
|
||||
ArrayFloatNear(expected_output));
|
||||
}
|
||||
}
|
||||
} // namespace convert
|
||||
|
||||
TEST_P(OpConverterTest1, ConvertTranspose) {
|
||||
// Get the NodeDef for Transpose.
|
||||
Scope s = Scope::NewRootScope();
|
||||
|
@ -711,15 +711,15 @@ Status SegmentGraph(const Graph* tf_graph,
  std::unordered_set<string> unsupported_ops;
  int num_unsupported_ops = 0;

  // Getting the operations blacklisted for conversion
  string tftrt_op_blacklist_str;
  // Getting the operations denylisted for conversion
  string tftrt_op_denylist_str;
  TF_CHECK_OK(
      ReadStringFromEnvVar("TF_TRT_OP_BLACKLIST", "", &tftrt_op_blacklist_str));
      ReadStringFromEnvVar("TF_TRT_OP_DENYLIST", "", &tftrt_op_denylist_str));

  auto tftrt_op_blacklist = gtl::FlatSet<string>{};  // non-absl ok
  auto tftrt_op_denylist = gtl::FlatSet<string>{};  // non-absl ok

  for (const auto& x : str_util::Split(tftrt_op_blacklist_str, ",")) {
    tftrt_op_blacklist.insert(x);
  for (const auto& x : str_util::Split(tftrt_op_denylist_str, ",")) {
    tftrt_op_denylist.insert(x);
  }
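For context, the denylist consulted here is nothing more than a comma-separated environment variable parsed into a set and checked per op type. A self-contained sketch of that flow using only the standard library (ReadStringFromEnvVar, str_util::Split and gtl::FlatSet are replaced by std:: equivalents, and the main() harness is hypothetical):

```cpp
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>

// Parse a comma-separated env var such as TF_TRT_OP_DENYLIST ("Conv2D,Relu")
// into a set of op type names.
std::unordered_set<std::string> ParseDenylistEnvVar(const char* var_name) {
  std::unordered_set<std::string> denylist;
  const char* raw = std::getenv(var_name);
  if (raw == nullptr) return denylist;
  std::stringstream ss(raw);
  std::string op_type;
  while (std::getline(ss, op_type, ',')) {
    if (!op_type.empty()) denylist.insert(op_type);
  }
  return denylist;
}

int main() {
  // Hypothetical usage: ops in the denylist are excluded from TRT segments.
  auto denylist = ParseDenylistEnvVar("TF_TRT_OP_DENYLIST");
  for (const std::string& op : {"Conv2D", "Relu", "MatMul"}) {
    std::cout << op << (denylist.count(op) ? " -> excluded" : " -> candidate")
              << "\n";
  }
  return 0;
}
```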

  // Parsing each node of the graph
@ -761,13 +761,13 @@ Status SegmentGraph(const Graph* tf_graph,
    const Status status = candidate_fn(node->tf_node());
    if (!status.ok()) {
      exclude_node(status.error_message());
    } else if (tftrt_op_blacklist.count(node->tf_node()->type_string())) {
    } else if (tftrt_op_denylist.count(node->tf_node()->type_string())) {
      // WARNING verbosity since the user explicitly requests this behavior.
      LOG_WARNING_WITH_PREFIX
          << "Blacklisted as TF-TRT candidate, "
          << "Denylisted as TF-TRT candidate, "
          << "(Op type: " << node->tf_node()->type_string() << "), "
          << "(Op name: " << node->name() << ")";
      exclude_node("Blacklisted with the env var TF_TRT_OP_BLACKLIST");
      exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST");
    } else {
      VLOG(2) << "Accepted as a TF-TRT candidate, "
              << "(Op type: " << node->tf_node()->type_string() << "), "
|
||||
|
@ -535,10 +535,10 @@ static void AllocateFlags() {
|
||||
flag_values->xla_gpu_force_conv_nchw(),
|
||||
"For cuDNN convolutions, always NCHW layouts."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_gpu_algorithm_blacklist_path",
|
||||
string_setter_for(&DebugOptions::set_xla_gpu_algorithm_blacklist_path),
|
||||
flag_values->xla_gpu_algorithm_blacklist_path(),
|
||||
"An AlgorithmBlacklist text proto file as a blacklist of convolutions to "
|
||||
"xla_gpu_algorithm_denylist_path",
|
||||
string_setter_for(&DebugOptions::set_xla_gpu_algorithm_denylist_path),
|
||||
flag_values->xla_gpu_algorithm_denylist_path(),
|
||||
"An AlgorithmDenylist text proto file as a denylist of convolutions to "
|
||||
"avoid to use."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_gpu_deterministic_reductions",
|
||||
|
@ -688,9 +688,7 @@ StatusOr<bool> RewriteDynamicConcat(
|
||||
dynamic_size));
|
||||
}
|
||||
}
|
||||
for (HloInstruction* user : prev_users) {
|
||||
TF_RETURN_IF_ERROR(concat->ReplaceUseWith(user, rewritten_concat));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(concat->ReplaceUsesWith(prev_users, rewritten_concat));
|
||||
TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
|
||||
concat, rewritten_concat, {}));
|
||||
return true;
|
||||
|
@ -83,8 +83,8 @@ class DynamicPadderTest : public HloTestBase {
|
||||
return module;
|
||||
}
|
||||
|
||||
StatusOr<bool> RunPadder() {
|
||||
DynamicPadder padder(/*slice_dynamic_output=*/true,
|
||||
StatusOr<bool> RunPadder(bool slice_dynamic_output = false) {
|
||||
DynamicPadder padder(/*slice_dynamic_output=*/slice_dynamic_output,
|
||||
CustomCallDynamicDimensionInference,
|
||||
OpHasDynamismSupport);
|
||||
return padder.Run(module_.get());
|
||||
@ -162,7 +162,7 @@ ENTRY main {
|
||||
|
||||
module_ = GetHloModule(hlo_text);
|
||||
|
||||
TF_ASSERT_OK(RunPadder().status());
|
||||
TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
|
||||
// After rewrite, we should have :
|
||||
//
|
||||
// param
|
||||
@ -218,7 +218,7 @@ ENTRY main {
|
||||
|
||||
module_ = GetHloModule(hlo_text);
|
||||
|
||||
TF_ASSERT_OK(RunPadder().status());
|
||||
TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status());
|
||||
// After rewrite, we should have :
|
||||
//
|
||||
// param
|
||||
@ -654,26 +654,16 @@ XLA_TEST_F(ExecutionTest, DynamicConcat) {
|
||||
const string hlo_text = R"(
|
||||
HloModule DynamicConcat
|
||||
|
||||
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
|
||||
lhs = s32[] parameter(0)
|
||||
rhs = s32[] parameter(1)
|
||||
ROOT add = s32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param_0 = s32[3] parameter(0)
|
||||
param_1 = s32[3] parameter(1)
|
||||
param_2 = s32[3] parameter(2)
|
||||
size = s32[] constant(2)
|
||||
param_padded_0 = s32[3] set-dimension-size(param_0, size), dimensions={0}
|
||||
param_padded_2 = s32[3] set-dimension-size(param_2, size), dimensions={0}
|
||||
%concatenate = s32[9]
|
||||
concatenate(s32[3] param_padded_0, s32[3] param_1, s32[3] param_padded_2),
|
||||
param_padded_0 = s32[<=3] set-dimension-size(param_0, size), dimensions={0}
|
||||
param_padded_2 = s32[<=3] set-dimension-size(param_2, size), dimensions={0}
|
||||
ROOT %concatenate = s32[9]
|
||||
concatenate(s32[<=3] param_padded_0, s32[<=3] param_1, s32[<=3] param_padded_2),
|
||||
dimensions={0}
|
||||
init = s32[] constant(0)
|
||||
ROOT reduce = s32[] reduce(concatenate, init),
|
||||
dimensions={0},
|
||||
to_apply=update_s32
|
||||
}
|
||||
)";
|
||||
|
||||
@ -686,10 +676,10 @@ ENTRY main {
|
||||
LiteralUtil::CreateR1<int32>({6, 7, -1}); // Dynamic operand.
|
||||
auto module = GetHloModule(hlo_text);
|
||||
|
||||
Literal result =
|
||||
PadAndExecute(std::move(module), {&operand_0, &operand_1, &operand_2});
|
||||
|
||||
Literal expected = LiteralUtil::CreateR0<int32>(28);
|
||||
Literal result = PadAndExecute(std::move(module),
|
||||
{&operand_0, &operand_1, &operand_2}, false);
|
||||
result.SetDynamicSize(0, 7);
|
||||
Literal expected = LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6, 7});
|
||||
|
||||
EXPECT_EQ(result, expected);
|
||||
}
|
||||
|
@ -170,7 +170,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||
@ -1676,7 +1675,7 @@ cc_library(
|
||||
tf_cc_test(
|
||||
name = "hlo_algorithm_blacklist_test",
|
||||
srcs = ["hlo_algorithm_blacklist_test.cc"],
|
||||
data = ["data/hlo_algorithm_blacklist.pbtxt"],
|
||||
data = ["data/hlo_algorithm_denylist.pbtxt"],
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":hlo_algorithm_blacklist",
|
||||
|
@ -15,19 +15,19 @@ message ConvInstructionLog {
|
||||
repeated uint64 operand_addresses = 4;
|
||||
}
|
||||
|
||||
message BlacklistedAlgorithm {
|
||||
message DenylistedAlgorithm {
|
||||
int64 id = 1;
|
||||
bool tensor_ops = 2;
|
||||
}
|
||||
|
||||
message AlgorithmBlacklistEntry {
|
||||
message AlgorithmDenylistEntry {
|
||||
string hlo = 1;
|
||||
tensorflow.ComputeCapability cc = 2;
|
||||
tensorflow.CudnnVersion cudnn_version = 3;
|
||||
string blas_version = 5;
|
||||
repeated BlacklistedAlgorithm algos = 4;
|
||||
repeated DenylistedAlgorithm algos = 4;
|
||||
}
|
||||
|
||||
message AlgorithmBlacklist {
|
||||
repeated AlgorithmBlacklistEntry entries = 1;
|
||||
message AlgorithmDenylist {
|
||||
repeated AlgorithmDenylistEntry entries = 1;
|
||||
}
|
||||
|
@ -438,10 +438,9 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
||||
(void)blas->GetVersion(&blas_version);
|
||||
}
|
||||
|
||||
absl::Span<const AlgorithmDesc> blacklisted_algos =
|
||||
GetBlacklistedConvAlgorithms(GetComputeCapability(stream_exec_),
|
||||
GetCudnnVersion(stream_exec_), blas_version,
|
||||
canonical_hlo);
|
||||
absl::Span<const AlgorithmDesc> disabled_algos = GetDisabledConvAlgorithms(
|
||||
GetComputeCapability(stream_exec_), GetCudnnVersion(stream_exec_),
|
||||
blas_version, canonical_hlo);
|
||||
|
||||
for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
|
||||
XLA_SCOPED_LOGGING_TIMER_LEVEL(
|
||||
@ -449,7 +448,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
||||
AlgorithmToString(alg)),
|
||||
2);
|
||||
|
||||
if (absl::c_linear_search(blacklisted_algos, alg)) {
|
||||
if (absl::c_linear_search(disabled_algos, alg)) {
|
||||
LOG(INFO) << "Omitted potentially buggy algorithm "
|
||||
<< AlgorithmToString(alg) << " for conv " << instr->ToString();
|
||||
continue;
|
||||
@ -503,7 +502,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
||||
|
||||
if (!input_output_allocator_redzone_clear ||
|
||||
!scratch_allocator_redzone_clear) {
|
||||
AlgorithmBlacklist proto;
|
||||
AlgorithmDenylist proto;
|
||||
auto entry = proto.add_entries();
|
||||
entry->set_hlo(canonical_hlo);
|
||||
*entry->mutable_cc() = GetComputeCapability(stream_exec_);
|
||||
@ -513,13 +512,12 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
||||
algo->set_id(alg.algo_id());
|
||||
algo->set_tensor_ops(alg.tensor_ops_enabled());
|
||||
|
||||
LOG(ERROR)
|
||||
<< "To blacklist this algorithm for this convolution, "
|
||||
"copy-paste the following "
|
||||
"proto to the blacklist file pointed by XLA_FLAGS "
|
||||
"--xla_gpu_algorithm_blacklist_path="
|
||||
<< GetDebugOptionsFromFlags().xla_gpu_algorithm_blacklist_path()
|
||||
<< " : " << proto.ShortDebugString();
|
||||
LOG(ERROR) << "To denylist this algorithm for this convolution, "
|
||||
"copy-paste the following "
|
||||
"proto to the denylist file pointed by XLA_FLAGS "
|
||||
"--xla_gpu_algorithm_denylist_path="
|
||||
<< GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
|
||||
<< " : " << proto.ShortDebugString();
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
constexpr char kDefaultBlacklist[] = R"pb(
|
||||
constexpr char kDefaultDenylist[] = R"pb(
|
||||
entries {
|
||||
hlo: "(f32[4,32,32,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[4,32,32,32]{2,1,3,0}, f32[5,5,32,32]{1,0,2,3}), window={size=5x5 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\""
|
||||
cc { major: 7 }
|
||||
@ -41,28 +41,26 @@ constexpr char kDefaultBlacklist[] = R"pb(
|
||||
}
|
||||
)pb";
|
||||
|
||||
absl::Span<const stream_executor::dnn::AlgorithmDesc>
|
||||
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
||||
tensorflow::CudnnVersion cudnn_version,
|
||||
const std::string& blas_version,
|
||||
const std::string& hlo) {
|
||||
absl::Span<const stream_executor::dnn::AlgorithmDesc> GetDisabledConvAlgorithms(
|
||||
tensorflow::ComputeCapability cc, tensorflow::CudnnVersion cudnn_version,
|
||||
const std::string& blas_version, const std::string& hlo) {
|
||||
// Key is the tuple of canonicalized hlo, compute capability major/minor,
|
||||
// cudnn version major/minor/patch, blas version.
|
||||
using MapType = absl::flat_hash_map<
|
||||
std::tuple<std::string, int, int, int, int, int, std::string>,
|
||||
std::vector<stream_executor::dnn::AlgorithmDesc>>;
|
||||
|
||||
static MapType* blacklist = [] {
|
||||
static MapType* denylist = [] {
|
||||
MapType* list = new MapType();
|
||||
AlgorithmBlacklist proto;
|
||||
AlgorithmDenylist proto;
|
||||
std::string file_path =
|
||||
GetDebugOptionsFromFlags().xla_gpu_algorithm_blacklist_path();
|
||||
GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path();
|
||||
if (!file_path.empty()) {
|
||||
TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
|
||||
file_path, &proto));
|
||||
} else {
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
std::string(kDefaultBlacklist), &proto));
|
||||
std::string(kDefaultDenylist), &proto));
|
||||
}
|
||||
for (const auto& entry : proto.entries()) {
|
||||
for (const auto& algo : entry.algos()) {
|
||||
@ -77,10 +75,10 @@ GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
||||
return list;
|
||||
}();
|
||||
|
||||
  auto iter = blacklist->find(std::make_tuple(
  auto iter = denylist->find(std::make_tuple(
      hlo, cc.major(), cc.minor(), cudnn_version.major(), cudnn_version.minor(),
      cudnn_version.patch(), std::string(blas_version)));
  if (iter != blacklist->end()) {
  if (iter != denylist->end()) {
    return iter->second;
  }
  return {};
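As the comment a few lines above notes, the map key is the tuple (canonical HLO string, compute-capability major/minor, cuDNN major/minor/patch, BLAS version). A small standalone sketch of that keying scheme with the standard library; AlgoDesc is a simplified stand-in, not the real stream_executor::dnn::AlgorithmDesc:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Simplified stand-in for stream_executor::dnn::AlgorithmDesc.
struct AlgoDesc {
  int id;
  bool tensor_ops;
};

// Key: canonical HLO, CC major/minor, cuDNN major/minor/patch, BLAS version.
using DenylistKey =
    std::tuple<std::string, int, int, int, int, int, std::string>;

int main() {
  std::map<DenylistKey, std::vector<AlgoDesc>> denylist;

  // One denylist entry, analogous to an AlgorithmDenylistEntry proto message.
  DenylistKey key{"custom-call(...convForward...)", 7, 0, 7, 6, 2, "9000"};
  denylist[key] = {{1, true}, {1, false}};

  // Lookup mirrors GetDisabledConvAlgorithms: exact key match, else empty.
  auto it = denylist.find(key);
  std::cout << "disabled algorithms: "
            << (it == denylist.end() ? 0 : it->second.size()) << "\n";
  return 0;
}
```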
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -24,13 +24,11 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
absl::Span<const stream_executor::dnn::AlgorithmDesc>
|
||||
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
||||
tensorflow::CudnnVersion cudnn_version,
|
||||
const std::string& blas_version,
|
||||
const std::string& hlo);
|
||||
absl::Span<const stream_executor::dnn::AlgorithmDesc> GetDisabledConvAlgorithms(
|
||||
tensorflow::ComputeCapability cc, tensorflow::CudnnVersion cudnn_version,
|
||||
const std::string& blas_version, const std::string& hlo);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_
|
||||
|
@ -26,22 +26,22 @@ namespace xla {
|
||||
namespace gpu {
|
||||
namespace {
|
||||
|
||||
class BlacklistTest : public testing::Test {
|
||||
class DenylistTest : public testing::Test {
|
||||
protected:
|
||||
BlacklistTest() {
|
||||
DenylistTest() {
|
||||
tensorflow::setenv(
|
||||
"XLA_FLAGS",
|
||||
absl::StrCat(
|
||||
"--xla_gpu_algorithm_blacklist_path=",
|
||||
"--xla_gpu_algorithm_denylist_path=",
|
||||
tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
|
||||
"tensorflow", "compiler", "xla", "service", "gpu", "data",
|
||||
"hlo_algorithm_blacklist.pbtxt")))
|
||||
"hlo_algorithm_denylist.pbtxt")))
|
||||
.data(),
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(BlacklistTest, DefaultTest) {
|
||||
TEST_F(DenylistTest, DefaultTest) {
|
||||
tensorflow::ComputeCapability cc;
|
||||
cc.set_major(7);
|
||||
cc.set_minor(0);
|
||||
@ -49,7 +49,7 @@ TEST_F(BlacklistTest, DefaultTest) {
|
||||
cudnn_version.set_major(7);
|
||||
cudnn_version.set_minor(6);
|
||||
cudnn_version.set_patch(2);
|
||||
auto list = GetBlacklistedConvAlgorithms(
|
||||
auto list = GetDisabledConvAlgorithms(
|
||||
cc, cudnn_version, /*blas_version=*/"9000",
|
||||
R"((f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}")");
|
||||
ASSERT_EQ(4, list.size());
|
||||
@ -59,7 +59,7 @@ TEST_F(BlacklistTest, DefaultTest) {
|
||||
EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(1, true), list[3]);
|
||||
}
|
||||
|
||||
TEST_F(BlacklistTest, NegativeTest) {
|
||||
TEST_F(DenylistTest, NegativeTest) {
|
||||
tensorflow::ComputeCapability cc;
|
||||
cc.set_major(7);
|
||||
cc.set_minor(0);
|
||||
@ -68,7 +68,7 @@ TEST_F(BlacklistTest, NegativeTest) {
|
||||
cudnn_version.set_minor(6);
|
||||
cudnn_version.set_minor(2);
|
||||
auto list =
|
||||
GetBlacklistedConvAlgorithms(cc, cudnn_version, "9000", R"(invalid hlo)");
|
||||
GetDisabledConvAlgorithms(cc, cudnn_version, "9000", R"(invalid hlo)");
|
||||
ASSERT_EQ(0, list.size());
|
||||
}
|
||||
|
||||
|
@ -231,7 +231,6 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
|
||||
<< " of " << hlo.ToString();
|
||||
llvm_ir::IrArray ir_array(base_ptr,
|
||||
ShapeUtil::GetSubshape(hlo.shape(), shape_index));
|
||||
alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index);
|
||||
|
||||
// The GPU backend emits one kernel per top-level HLO, and LLVM views
|
||||
// execution of one kernel as the "whole program" executed on the GPU.
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
|
||||
namespace xla {
|
||||
@ -42,8 +41,7 @@ class HloToIrBindings {
|
||||
: buffer_assignment_(buffer_assignment),
|
||||
is_nested_(is_nested),
|
||||
b_(b),
|
||||
module_(llvm_module),
|
||||
alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {}
|
||||
module_(llvm_module) {}
|
||||
|
||||
void EmitBasePointersForHlos(
|
||||
absl::Span<const HloInstruction* const> io_hlos,
|
||||
@ -116,8 +114,6 @@ class HloToIrBindings {
|
||||
|
||||
// The address of the memory block that contains all temporary buffers.
|
||||
llvm::Value* temp_buffer_base_ = nullptr;
|
||||
|
||||
llvm_ir::AliasAnalysis alias_analysis_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -1747,6 +1747,25 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
  auto buffers_it = non_constant_buffers.begin();
  for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
    kernel_args[*buffers_it] = arg_it;

    // Annotate all allocations with LLVM's `noalias`.
    // There are three kinds of allocations:
    // * Read-only allocations, aka input parameters that are not aliased with
    //   outputs.
    // * Read-write allocations, including all output buffers, some of which
    //   may alias with input HLO parameters, but aliased HLO buffers are always
    //   assigned with the same allocation.
    // * The temp buffer.
    //
    // Read-only allocations may overlap with each other, but since they are
    // not mutated, they can always be annotated with `noalias` per LLVM
    // semantics.
    //
    // Read-write allocations and the temp buffer don't overlap with any
    // allocations, therefore they can also be annotated with `noalias`.
    kernel->addParamAttr(
        arg_it->getArgNo(),
        llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias));
  }
}
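To make the annotation concrete, here is a standalone sketch (assuming an LLVM development setup from the typed-pointer era this change targets; the module and function names are made up) that builds a two-argument kernel and marks each buffer parameter `noalias`, the same attribute added in the loop above:

```cpp
#include "llvm/IR/Attributes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::LLVMContext context;
  llvm::Module module("noalias_demo", context);

  // declare void @kernel(i8* %alloc0, i8* %alloc1)
  llvm::Type* i8_ptr = llvm::Type::getInt8PtrTy(context);
  llvm::FunctionType* fn_type = llvm::FunctionType::get(
      llvm::Type::getVoidTy(context), {i8_ptr, i8_ptr}, /*isVarArg=*/false);
  llvm::Function* kernel = llvm::Function::Create(
      fn_type, llvm::Function::ExternalLinkage, "kernel", &module);

  // Every buffer argument is either read-only or does not overlap any other
  // allocation, so each parameter can be marked `noalias`.
  for (llvm::Argument& arg : kernel->args()) {
    kernel->addParamAttr(
        arg.getArgNo(),
        llvm::Attribute::get(context, llvm::Attribute::NoAlias));
  }

  // The printed IR includes: declare void @kernel(i8* noalias, i8* noalias)
  module.print(llvm::outs(), nullptr);
  return 0;
}
```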
|
||||
|
||||
|
@ -45,7 +45,7 @@ ENTRY main {
|
||||
)";
|
||||
|
||||
CompileAndVerifyIr(hlo_string, R"(
|
||||
CHECK: @fusion(i8* align 64 dereferenceable(600) %alloc0, i8* align 16 dereferenceable(400) %alloc1, i8* align 64 dereferenceable(864) %temp_buf)
|
||||
CHECK: @fusion(i8* noalias align 64 dereferenceable(600) %alloc0, i8* noalias align 16 dereferenceable(400) %alloc1, i8* noalias align 64 dereferenceable(864) %temp_buf)
|
||||
)");
|
||||
}
|
||||
|
||||
|
@ -51,16 +51,9 @@ TEST_F(GpuNoAliasTest, Concat) {
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
CompileAndVerifyIr(std::move(hlo_module),
|
||||
R"(
|
||||
; CHECK: %[[x_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %x{{.*}}, i32 0
|
||||
; CHECK: load float, float* %[[x_gep]], {{.*}}, !noalias ![[param_noalias:.*]]
|
||||
; CHECK: %[[y_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %y{{.*}}, i32 0
|
||||
; CHECK: load float, float* %[[y_gep]], {{.*}}, !noalias ![[param_noalias]]
|
||||
; CHECK: %[[result_ptr:.*]] = bitcast [2 x [6 x float]]* %fusion{{.*}} to float*
|
||||
; CHECK: %[[result_gep:.*]] = getelementptr inbounds float, float* %[[result_ptr]]
|
||||
; CHECK: store float {{.*}}, float* %[[result_gep]], align 4, !alias.scope ![[param_noalias]]
|
||||
; CHECK: ![[param_noalias]] = !{![[retval_buffer:.*]]}
|
||||
)",
|
||||
R"(CHECK-LABEL: define void @fusion
|
||||
CHECK-SAME: i8* noalias align {{[0-9]*}} dereferenceable({{[0-9]*}}) %[[OUTPUT_ALLOC:[a-z0-9]*]]
|
||||
CHECK: %fusion.raw = {{.*}} %[[OUTPUT_ALLOC]])",
|
||||
/*match_optimized_ir=*/false);
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
// RUN: hlo_to_llvm_ir %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
|
||||
@ -26,7 +26,7 @@
|
||||
// CHECK: ret void
|
||||
// CHECK: scatter_TensorFlowScatterV1.in_bounds-true: ; preds = %[[VAL_24]]
|
||||
// CHECK: %[[VAL_25:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_8]], i32 0, i32 %[[VAL_19]]
|
||||
// CHECK: %[[VAL_26:.*]] = load i32, i32* %[[VAL_25]], align 4, !invariant.load !4, !noalias !5
|
||||
// CHECK: %[[VAL_26:.*]] = load i32, i32* %[[VAL_25]], align 4, !invariant.load !4
|
||||
// CHECK: %[[VAL_27:.*]] = add i32 0, %[[VAL_26]]
|
||||
// CHECK: %[[VAL_28:.*]] = icmp ult i32 %[[VAL_26]], 3
|
||||
// CHECK: %[[VAL_29:.*]] = and i1 true, %[[VAL_28]]
|
||||
@ -37,7 +37,7 @@
|
||||
// CHECK: %[[VAL_31:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_18]]
|
||||
// CHECK: %[[VAL_33:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_11]] to i32*
|
||||
// CHECK: %[[VAL_34:.*]] = getelementptr inbounds i32, i32* %[[VAL_33]], i32 %[[VAL_15]]
|
||||
// CHECK: %[[VAL_35:.*]] = load i32, i32* %[[VAL_34]], align 4, !invariant.load !4, !noalias !5
|
||||
// CHECK: %[[VAL_35:.*]] = load i32, i32* %[[VAL_34]], align 4, !invariant.load !4
|
||||
// CHECK: store i32 %[[VAL_35]], i32* %[[VAL_32]], align 4
|
||||
// CHECK: %[[VAL_36:.*]] = load i32, i32* %[[VAL_32]], align 4
|
||||
// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4
|
||||
@ -48,9 +48,6 @@
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{i32 0, i32 6}
|
||||
// CHECK: !4 = !{}
|
||||
// CHECK: !5 = !{!6}
|
||||
// CHECK: !6 = !{!"buffer: {index:0, offset:0, size:36}", !7}
|
||||
// CHECK: !7 = !{!"XLA global AA domain"}
|
||||
|
||||
|
||||
HloModule TensorFlowScatterV1
|
||||
@ -75,7 +72,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* align 64 dereferenceable(4) %alloc0, i8* align 16 dereferenceable(4) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
|
||||
@ -101,7 +98,7 @@ ENTRY main {
|
||||
// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_59]], %[[VAL_55]]
|
||||
// CHECK: br label %[[VAL_56]]
|
||||
// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_55]]
|
||||
// CHECK: %[[VAL_61:.*]] = load i32, i32* %[[VAL_48]], align 4, !invariant.load !3, !noalias !4
|
||||
// CHECK: %[[VAL_61:.*]] = load i32, i32* %[[VAL_48]], align 4, !invariant.load !3
|
||||
// CHECK: store i32 %[[VAL_61]], i32* %[[VAL_60]], align 4
|
||||
// CHECK: %[[VAL_62:.*]] = load i32, i32* %[[VAL_60]], align 4
|
||||
// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4
|
||||
@ -111,9 +108,6 @@ ENTRY main {
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{}
|
||||
// CHECK: !4 = !{!5}
|
||||
// CHECK: !5 = !{!"buffer: {index:0, offset:0, size:4}", !6}
|
||||
// CHECK: !6 = !{!"XLA global AA domain"}
|
||||
|
||||
HloModule ScatterIntoScalar
|
||||
|
||||
@ -137,7 +131,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
|
||||
@ -164,7 +158,7 @@ ENTRY main {
|
||||
// CHECK: ret void
|
||||
// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-true: ; preds = %[[VAL_89]]
|
||||
// CHECK: %[[VAL_90:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_73]], i32 0, i32 %[[VAL_84]]
|
||||
// CHECK: %[[VAL_91:.*]] = load i32, i32* %[[VAL_90]], align 4, !invariant.load !4, !noalias !5
|
||||
// CHECK: %[[VAL_91:.*]] = load i32, i32* %[[VAL_90]], align 4, !invariant.load !4
|
||||
// CHECK: %[[VAL_92:.*]] = add i32 0, %[[VAL_91]]
|
||||
// CHECK: %[[VAL_93:.*]] = icmp ult i32 %[[VAL_91]], 3
|
||||
// CHECK: %[[VAL_94:.*]] = and i1 true, %[[VAL_93]]
|
||||
@ -175,7 +169,7 @@ ENTRY main {
|
||||
// CHECK: %[[VAL_97:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_67]], i32 0, i32 %[[VAL_92]], i32 %[[VAL_83]]
|
||||
// CHECK: %[[VAL_99:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_76]] to i32*
|
||||
// CHECK: %[[VAL_100:.*]] = getelementptr inbounds i32, i32* %[[VAL_99]], i32 %[[VAL_80]]
|
||||
// CHECK: %[[VAL_101:.*]] = load i32, i32* %[[VAL_100]], align 4, !invariant.load !4, !noalias !5
|
||||
// CHECK: %[[VAL_101:.*]] = load i32, i32* %[[VAL_100]], align 4, !invariant.load !4
|
||||
// CHECK: store i32 %[[VAL_101]], i32* %[[VAL_98]], align 4
|
||||
// CHECK: %[[VAL_102:.*]] = load i32, i32* %[[VAL_98]], align 4
|
||||
// CHECK: %[[VAL_103:.*]] = load i32, i32* %[[VAL_97]], align 4
|
||||
@ -199,15 +193,6 @@ ENTRY main {
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{i32 0, i32 6}
|
||||
// CHECK: !4 = !{}
|
||||
// CHECK: !5 = !{!6}
|
||||
// CHECK: !6 = !{!"buffer: {index:0, offset:0, size:36}", !7}
|
||||
// CHECK: !7 = !{!"XLA global AA domain"}
|
||||
// CHECK: !8 = !{!9}
|
||||
// CHECK: !9 = !{!"buffer: {index:4, offset:0, size:4}", !7}
|
||||
// CHECK: !10 = !{!11}
|
||||
// CHECK: !11 = !{!"buffer: {index:6, offset:0, size:4}", !7}
|
||||
// CHECK: !12 = !{!13}
|
||||
// CHECK: !13 = !{!"buffer: {index:5, offset:0, size:4}", !7}
|
||||
|
||||
HloModule TensorFlowScatter_Mul
|
||||
|
||||
@ -231,7 +216,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* align 64 dereferenceable(16) %alloc0, i8* align 16 dereferenceable(16) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 dereferenceable(4) %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
|
||||
@ -253,7 +238,7 @@ ENTRY main {
|
||||
// CHECK: scatter_ScalarUpdate.in_bounds-after: ; preds = %[[VAL_138:.*]], %[[VAL_139:.*]]
|
||||
// CHECK: ret void
|
||||
// CHECK: scatter_ScalarUpdate.in_bounds-true: ; preds = %[[VAL_139]]
|
||||
// CHECK: %[[VAL_140:.*]] = load i32, i32* %[[VAL_126]], align 4, !invariant.load !3, !noalias !4
|
||||
// CHECK: %[[VAL_140:.*]] = load i32, i32* %[[VAL_126]], align 4, !invariant.load !3
|
||||
// CHECK: %[[VAL_141:.*]] = add i32 0, %[[VAL_140]]
|
||||
// CHECK: %[[VAL_142:.*]] = icmp ult i32 %[[VAL_140]], 4
|
||||
// CHECK: %[[VAL_143:.*]] = and i1 true, %[[VAL_142]]
|
||||
@ -262,7 +247,7 @@ ENTRY main {
|
||||
// CHECK: br label %[[VAL_137]]
|
||||
// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_136]]
|
||||
// CHECK: %[[VAL_145:.*]] = getelementptr inbounds [4 x i32], [4 x i32]* %[[VAL_120]], i32 0, i32 %[[VAL_141]]
|
||||
// CHECK: %[[VAL_147:.*]] = load i32, i32* %[[VAL_129]], align 4, !invariant.load !3, !noalias !4
|
||||
// CHECK: %[[VAL_147:.*]] = load i32, i32* %[[VAL_129]], align 4, !invariant.load !3
|
||||
// CHECK: store i32 %[[VAL_147]], i32* %[[VAL_146]], align 4
|
||||
// CHECK: %[[VAL_148:.*]] = load i32, i32* %[[VAL_146]], align 4
|
||||
// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4
|
||||
@ -272,9 +257,6 @@ ENTRY main {
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{}
|
||||
// CHECK: !4 = !{!5}
|
||||
// CHECK: !5 = !{!"buffer: {index:0, offset:0, size:16}", !6}
|
||||
// CHECK: !6 = !{!"XLA global AA domain"}
|
||||
|
||||
HloModule ScalarUpdate
|
||||
|
||||
|
@ -2189,6 +2189,27 @@ Status HloInstruction::ReplaceOperandWithDifferentShape(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloInstruction::ReplaceUsesWith(absl::Span<HloInstruction* const> users,
|
||||
HloInstruction* new_producer) {
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
|
||||
<< shape() << " is not compatible with " << new_producer->shape();
|
||||
return ReplaceAllUsesWithDifferentShape(users, new_producer);
|
||||
}
|
||||
|
||||
Status HloInstruction::ReplaceAllUsesWithDifferentShape(
|
||||
absl::Span<HloInstruction* const> users, HloInstruction* new_producer) {
|
||||
for (HloInstruction* user : users) {
|
||||
TF_RETURN_IF_ERROR(ReplaceUseWith(user, new_producer));
|
||||
}
|
||||
|
||||
if (parent_ && parent_->root_instruction() == this) {
|
||||
parent_->set_root_instruction(new_producer,
|
||||
/*accept_different_shape=*/true);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
|
||||
|
@ -1201,6 +1201,12 @@ class HloInstruction {
|
||||
// Same as ReplaceAllUsesWith, but new_producer can have a different shape.
|
||||
Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);
|
||||
|
||||
// Same as ReplaceAllUsesWith, but only replace given set of users.
|
||||
Status ReplaceUsesWith(absl::Span<HloInstruction* const> users,
|
||||
HloInstruction* new_producer);
|
||||
Status ReplaceAllUsesWithDifferentShape(
|
||||
absl::Span<HloInstruction* const> users, HloInstruction* new_producer);
|
||||
|
||||
// Performs a postorder DFS visit using this node as the root. If
|
||||
// call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
|
||||
// complete. If ignore_control_predecessors is true, instructions only
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
@ -650,30 +651,28 @@ bool CompareComputationsByContent(HloComputation* a, HloComputation* b) {
}  // anonymous namespace

std::vector<HloComputation*> HloModule::MakeComputationSorted() const {
  std::vector<HloComputation*> result;
  result.reserve(computations_.size());
  for (const auto& computation : computations_) {
    result.push_back(computation.get());
  std::vector<HloComputation*> result = MakeComputationPostOrder();
  if (config().content_aware_computation_sorting()) {
    absl::c_sort(result, CompareComputationsByContent);
  }
  std::sort(result.begin(), result.end(), CompareComputationsByContent);
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result;
  for (auto* c : computations()) {
    if (c->IsFusionComputation()) {
      continue;
    }
    result.push_back(c);
  }
  std::vector<HloComputation*> result = MakeComputationPostOrder();
  result.erase(std::remove_if(
                   result.begin(), result.end(),
                   [](HloComputation* c) { return c->IsFusionComputation(); }),
               result.end());
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted()
    const {
  auto result = MakeNonfusionComputations();
  std::sort(result.begin(), result.end(), CompareComputationsByContent);
  if (config().content_aware_computation_sorting()) {
    absl::c_sort(result, CompareComputationsByContent);
  }
  return result;
}
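The rewrite above leans on two standard patterns: the erase-remove idiom to drop fusion computations and an optional, config-gated sort. A standalone sketch of both on a plain vector (the element type and predicate are placeholders, not XLA types):

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Stand-in for the post-order computation list.
  std::vector<std::string> names = {"fused_add", "main", "while_body",
                                    "fused_mul"};

  // Erase-remove idiom: drop "fusion" computations, keeping relative order.
  names.erase(std::remove_if(names.begin(), names.end(),
                             [](const std::string& n) {
                               return n.rfind("fused_", 0) == 0;
                             }),
              names.end());

  // Optional content-aware sort, gated by a flag as in the HloModule change.
  bool content_aware_computation_sorting = true;
  if (content_aware_computation_sorting) {
    std::sort(names.begin(), names.end());
  }

  for (const auto& n : names) std::cout << n << "\n";  // main, while_body
  return 0;
}
```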
|
||||
|
||||
|
@ -188,6 +188,14 @@ class HloModuleConfig {
|
||||
alias_passthrough_params_ = alias_passthrough_params;
|
||||
}
|
||||
|
||||
bool content_aware_computation_sorting() const {
|
||||
return content_aware_computation_sorting_;
|
||||
}
|
||||
void set_content_aware_computation_sorting(
|
||||
bool content_aware_computation_sorting) {
|
||||
content_aware_computation_sorting_ = content_aware_computation_sorting;
|
||||
}
|
||||
|
||||
FusionConfigCollection fusion_config_collection() const {
|
||||
return fusion_config_collection_;
|
||||
}
|
||||
@ -251,6 +259,8 @@ class HloModuleConfig {
|
||||
|
||||
bool alias_passthrough_params_ = false;
|
||||
|
||||
bool content_aware_computation_sorting_ = false;
|
||||
|
||||
FusionConfigCollection fusion_config_collection_ =
|
||||
FusionConfigCollection::kOff;
|
||||
|
||||
|
@ -121,9 +121,9 @@ struct Item {
|
||||
bool placed = false;
|
||||
|
||||
// To avoid an infinite loop rematerializing the same set of
|
||||
// instructions ad infinitum, keep a blacklist of instructions
|
||||
// instructions ad infinitum, keep a denylist of instructions
|
||||
// which should not be rematerialized.
|
||||
bool blacklisted = false;
|
||||
bool denylisted = false;
|
||||
|
||||
// The buffers defined by this instruction.
|
||||
BufferIdList buffers_defined;
|
||||
@ -292,8 +292,8 @@ class InstructionList {
|
||||
InsertBeforeInstructions(to_insert, {max_position_item->next});
|
||||
}
|
||||
|
||||
void Blacklist(const HloInstruction* inst) {
|
||||
GetItem(inst)->blacklisted = true;
|
||||
void Denylist(const HloInstruction* inst) {
|
||||
GetItem(inst)->denylisted = true;
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1158,13 +1158,13 @@ std::vector<Item*> GetInitialBlock(const InstructionList& instruction_list,
|
||||
return item_block;
|
||||
}
|
||||
|
||||
// Returns whether any instruction in 'block' is blacklisted or
|
||||
// Returns whether any instruction in 'block' is denylisted or
|
||||
// non-rematerializable.
|
||||
bool AnyBlacklistedOrNonRematerializable(
|
||||
bool AnyDenylistedOrNonRematerializable(
|
||||
const std::vector<Item*>& block,
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
|
||||
for (auto* item : block) {
|
||||
if (item->blacklisted) {
|
||||
if (item->denylisted) {
|
||||
return true;
|
||||
}
|
||||
if (!CanBeRematerialized(item->instruction, rematerializable_map)) {
|
||||
@ -1195,10 +1195,10 @@ MemoryUsageTracker::PickRematerializationCandidates(
|
||||
// instructions.
|
||||
break;
|
||||
}
|
||||
// If any item in the starting block are blacklisted or non-rematable, then
|
||||
// If any item in the starting block are denylisted or non-rematable, then
|
||||
// break and move on to next start_item (we can actually move to the last
|
||||
// invalid item in this block, but let's ignore that optimization for now).
|
||||
if (AnyBlacklistedOrNonRematerializable(block, rematerializable_map)) {
|
||||
if (AnyDenylistedOrNonRematerializable(block, rematerializable_map)) {
|
||||
continue;
|
||||
}
|
||||
while (block.size() <= max_block_size) {
|
||||
@ -1289,8 +1289,8 @@ MemoryUsageTracker::PickRematerializationCandidates(
|
||||
// Time to update the block to include the next instruction.
|
||||
auto* last_item = block[block.size() - 1];
|
||||
auto* next_item = instruction_list.next(last_item);
|
||||
if (next_item == nullptr || next_item->blacklisted ||
|
||||
!next_item->placed || next_item == in_progress_item_ ||
|
||||
if (next_item == nullptr || next_item->denylisted || !next_item->placed ||
|
||||
next_item == in_progress_item_ ||
|
||||
!CanBeRematerialized(next_item->instruction, rematerializable_map)) {
|
||||
break;
|
||||
}
|
||||
@ -1404,7 +1404,7 @@ StatusOr<int64> RematerializeInstructions(
|
||||
// instruction it was a copying of. Now 'remat' is a rematerialization
|
||||
// of 'best' and kills 'best'. Stop rematerializing this instruction
|
||||
// to avoid an infinite loop.
|
||||
instruction_list->Blacklist(remat);
|
||||
instruction_list->Denylist(remat);
|
||||
}
|
||||
remat_move_instructions->insert(remat);
|
||||
} else {
|
||||
@ -1460,8 +1460,8 @@ StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
|
||||
place_before.push_back(instruction_list->GetItem(user));
|
||||
}
|
||||
|
||||
instruction_list->Blacklist(compressed_item->instruction);
|
||||
instruction_list->Blacklist(uncompressed_item->instruction);
|
||||
instruction_list->Denylist(compressed_item->instruction);
|
||||
instruction_list->Denylist(uncompressed_item->instruction);
|
||||
|
||||
instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
|
||||
|
||||
@ -1583,7 +1583,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
// rematerialization is added to 'remat_move_instructions' (the
|
||||
// rematerialization is essentially a move). If the next rematerialization of
|
||||
// the instruction is also a move then the rematerialization is added to the
|
||||
// blacklist.
|
||||
// denylist.
|
||||
absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
|
||||
|
||||
// The map from instructions to their rematerializable status.
|
||||
|
@ -17,6 +17,8 @@ package_group(
|
||||
cc_library(
|
||||
name = "spmd_partitioner",
|
||||
srcs = [
|
||||
"convolution_handler.cc",
|
||||
"dot_handler.cc",
|
||||
"spmd_partitioner.cc",
|
||||
"spmd_partitioner_util.cc",
|
||||
],
|
||||
|
695
tensorflow/compiler/xla/service/spmd/convolution_handler.cc
Normal file
@ -0,0 +1,695 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
|
||||
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/numbers.h"
|
||||
|
||||
namespace xla {
|
||||
namespace spmd {
|
||||
|
||||
Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
|
||||
HloInstruction* hlo) {
|
||||
TF_RET_CHECK(hlo->opcode() == HloOpcode::kConvolution);
|
||||
|
||||
auto lhs = GetPartitionedHlo(hlo->operand(0));
|
||||
auto rhs = GetPartitionedHlo(hlo->operand(1));
|
||||
TF_RET_CHECK(!lhs.sharding().IsTileMaximal() &&
|
||||
!rhs.sharding().IsTileMaximal());
|
||||
|
||||
const auto& dnums = hlo->convolution_dimension_numbers();
|
||||
|
||||
// Check if the operand shardings are aligned. Also we currently don't
|
||||
// support partitioning non-spatial dimensions.
|
||||
std::vector<int64> rhs_to_lhs_indices(hlo->shape().rank());
|
||||
rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
|
||||
dnums.input_batch_dimension();
|
||||
rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
|
||||
dnums.input_feature_dimension();
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
|
||||
dnums.input_spatial_dimensions(i);
|
||||
}
|
||||
std::vector<int64> lhs_to_rhs_indices(hlo->shape().rank());
|
||||
for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
|
||||
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
|
||||
}
|
||||
|
||||
Window window = hlo->window();
|
||||
std::vector<int64> reversed_rhs_dims;
|
||||
for (int64 i = 0; i < window.dimensions_size(); ++i) {
|
||||
if (window.dimensions(i).window_reversal()) {
|
||||
reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i));
|
||||
}
|
||||
}
|
||||
if (!reversed_rhs_dims.empty()) {
|
||||
// Make the reversed dims left-padded to prepare for window reversal.
|
||||
auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims);
|
||||
if (left_padded_rhs == nullptr) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
left_padded_rhs->set_sharding(rhs.sharding());
|
||||
rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state());
|
||||
}
|
||||
// Consider window reversal when resharding RHS or LHS. Note: this will not
|
||||
// reverse the data in the shard. We use window reversal to do that.
|
||||
auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding(
|
||||
hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices),
|
||||
reversed_rhs_dims);
|
||||
auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding(
|
||||
hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims),
|
||||
lhs_to_rhs_indices);
|
||||
|
||||
auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
|
||||
const HloSharding& rhs_sharding) {
|
||||
return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) !=
|
||||
1 ||
|
||||
rhs_sharding.tile_assignment().dim(
|
||||
dnums.kernel_output_feature_dimension()) != 1;
|
||||
};
|
||||
|
||||
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(hlo->shape().element_type())));
|
||||
if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) {
|
||||
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
|
||||
rhs = rhs.PadWithValue(zero, reversed_rhs_dims);
|
||||
} else {
|
||||
if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
lhs = lhs.PadWithValue(zero);
|
||||
rhs =
|
||||
rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
|
||||
}
|
||||
|
||||
// Reshard LHS by exchanging halo such that each shard computes the partial
// sum of the full shape result, and add AllReduce.
//
// The size of halo on each dimension can be calculated from the projection
// onto the LHS that each RHS shard i needs to read. RHS and LHS below refer
// to the shard size of RHS and LHS, WC is the number of windows, and D is the
// window dilation.
//
// * offset(i): RHS * D * i - low_padding
// * limit(i): {(RHS - 1) * D + 1} * (i + 1) + (WC - 1) * stride - low_padding
//
// Since shard i has LHS of range [i * LHS, (i + 1) * LHS)
// * left-halo: i * LHS - offset(i)
//   = (LHS - RHS * D) * i + low_padding
// * right-halo: limit(i) - (i + 1) * LHS
//   = [{(RHS - 1) * D + 1} - LHS] * (i + 1) + (WC - 1) * stride - low_padding
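// Worked example (hypothetical sizes, for illustration only): with per-shard
// sizes LHS = 4 and RHS = 2, D = 1, stride = 1, low_padding = 1 and WC = 6,
// shard i = 1 needs
//   left-halo  = (4 - 2 * 1) * 1 + 1                 = 3
//   right-halo = (2 - 4) * (1 + 1) + (6 - 1) * 1 - 1 = 0
// i.e. it reads three extra LHS elements from its left neighbor and none
// from its right neighbor.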
std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
|
||||
std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
|
||||
std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
|
||||
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
|
||||
int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
|
||||
auto wd = window.dimensions(i);
|
||||
if (wd.base_dilation() != 1) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
|
||||
int64 lhs_shard_size =
|
||||
CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
|
||||
int64 rhs_shard_size =
|
||||
CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
|
||||
shard_counts[i] = shard_count;
|
||||
lhs_shard_sizes[i] = lhs_shard_size;
|
||||
rhs_shard_sizes[i] = rhs_shard_size;
|
||||
}
|
||||
|
||||
std::vector<OffsetCalculation> left_halo_size_functions(hlo->shape().rank());
|
||||
std::vector<OffsetCalculation> right_halo_size_functions(hlo->shape().rank());
|
||||
Window new_window = window;
|
||||
|
||||
auto partition_ordinals =
|
||||
MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_);
|
||||
HloInstruction* lhs_with_halo = lhs.hlo();
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
|
||||
int64 lhs_shard_size = lhs_shard_sizes[i];
|
||||
int64 rhs_shard_size = rhs_shard_sizes[i];
|
||||
|
||||
if (shard_counts[i] == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Calculate the left and right halo sizes as described in the comments
|
||||
// above.
|
||||
auto wd = window.dimensions(i);
|
||||
int64 padding_low = wd.padding_low();
|
||||
int64 padding_high = wd.padding_high();
|
||||
int64 base = lhs.base_shape().dimensions(lhs_dimension);
|
||||
int64 window_count = 1 + (padding_low + padding_high + base -
|
||||
(1 + (wd.size() - 1) * wd.window_dilation())) /
|
||||
wd.stride();
|
||||
int64 rhs_shard_size_dilated =
|
||||
(rhs_shard_size - 1) * wd.window_dilation() + 1;
|
||||
|
||||
left_halo_size_functions[lhs_dimension] =
|
||||
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
|
||||
lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low,
|
||||
1));
|
||||
right_halo_size_functions[lhs_dimension] =
|
||||
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
|
||||
rhs_shard_size_dilated - lhs_shard_size,
|
||||
rhs_shard_size_dilated - lhs_shard_size +
|
||||
wd.stride() * (window_count - 1) - padding_low,
|
||||
1));
|
||||
|
||||
// Exchange halo and concatenate.
|
||||
int64 dim = dnums.input_spatial_dimensions(i);
|
||||
int64 explicit_left_padding_on_full_shape = padding_low;
|
||||
int64 shard_size_with_halo =
|
||||
wd.stride() * (window_count - 1) + rhs_shard_size_dilated;
|
||||
|
||||
new_window.mutable_dimensions(i)->set_padding_low(0);
|
||||
new_window.mutable_dimensions(i)->set_padding_high(0);
|
||||
new_window.mutable_dimensions(i)->set_size(rhs_shard_size);
|
||||
|
||||
// offset_on_padded_shape and padded_full_shape_size are needed only if
// we want to mask out-of-range values in ExchangeHaloAndGetValidData().
// Since the default value for the collective-permute is zero and we also
// call PadWithValue() on both operands at the beginning, we don't need to
// mask here.
|
||||
//
|
||||
// TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
|
||||
// if it's always safe.
|
||||
auto offset_on_padded_shape =
|
||||
OffsetCalculation(MultiplyAddDivideOffsetCalculation());
|
||||
int64 padded_full_shape_size = 0;
|
||||
auto concat = ExchangeHaloAndGetValidData(
|
||||
lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim],
|
||||
right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
|
||||
padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(),
|
||||
offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), zero,
|
||||
partition_ordinals[dim], collective_ops_creator_, next_channel_id_, &b_,
|
||||
/*mask_invalid_region=*/false);
|
||||
if (!concat) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
lhs_with_halo = *concat;
|
||||
}
|
||||
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
auto conv = b_.AddInstruction(HloInstruction::CreateConvolve(
|
||||
hlo->shape(), lhs_with_halo, rhs.hlo(), hlo->feature_group_count(),
|
||||
hlo->batch_group_count(), new_window,
|
||||
hlo->convolution_dimension_numbers(), hlo->precision_config()));
|
||||
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
|
||||
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
|
||||
NewChannel());
|
||||
ar->set_sharding(HloSharding::Replicate());
|
||||
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
|
||||
.Reshard(hlo->sharding())
|
||||
.hlo();
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
|
||||
auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo);
|
||||
if (dot_dnums) {
|
||||
// Use HandleDotHelper() for convs that are actually einsums.
|
||||
spmd::DotGeneralDimsMapping mapping;
|
||||
for (const auto& dims : dot_dnums->batch_dims) {
|
||||
mapping.batch_dims.emplace_back();
|
||||
mapping.batch_dims.back().lhs = dims.lhs;
|
||||
mapping.batch_dims.back().rhs = dims.rhs;
|
||||
mapping.batch_dims.back().output = dims.output;
|
||||
}
|
||||
for (const auto& dims : dot_dnums->contracting_dims) {
|
||||
mapping.contracting_dims.emplace_back();
|
||||
mapping.contracting_dims.back().lhs = dims.lhs;
|
||||
mapping.contracting_dims.back().rhs = dims.rhs;
|
||||
mapping.contracting_dims.back().output = dims.output;
|
||||
}
|
||||
for (const auto& dims : dot_dnums->lhs_non_contracting_dims) {
|
||||
mapping.lhs_non_contracting_dims.emplace_back();
|
||||
mapping.lhs_non_contracting_dims.back().lhs = dims.lhs;
|
||||
mapping.lhs_non_contracting_dims.back().rhs = dims.rhs;
|
||||
mapping.lhs_non_contracting_dims.back().output = dims.output;
|
||||
}
|
||||
for (const auto& dims : dot_dnums->rhs_non_contracting_dims) {
|
||||
mapping.rhs_non_contracting_dims.emplace_back();
|
||||
mapping.rhs_non_contracting_dims.back().lhs = dims.lhs;
|
||||
mapping.rhs_non_contracting_dims.back().rhs = dims.rhs;
|
||||
mapping.rhs_non_contracting_dims.back().output = dims.output;
|
||||
}
|
||||
auto create_sharded_conv =
|
||||
[&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo,
|
||||
spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sharded_conv,
|
||||
dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution(
|
||||
*hlo, *dot_dnums, lhs_hlo, rhs_hlo));
|
||||
return b->AddInstruction(std::move(sharded_conv));
|
||||
};
|
||||
return HandleDotHelper(hlo, mapping, create_sharded_conv);
|
||||
}
|
||||
|
||||
auto lhs = GetPartitionedHlo(hlo->operand(0));
|
||||
auto rhs = GetPartitionedHlo(hlo->operand(1));
|
||||
const HloSharding& sharding = hlo->sharding();
|
||||
const auto& dnums = hlo->convolution_dimension_numbers();
|
||||
std::vector<int64> rhs_to_lhs_indices(hlo->shape().rank());
|
||||
rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
|
||||
dnums.input_batch_dimension();
|
||||
rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
|
||||
dnums.input_feature_dimension();
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
|
||||
dnums.input_spatial_dimensions(i);
|
||||
}
|
||||
std::vector<int64> lhs_to_rhs_indices(hlo->shape().rank());
|
||||
for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
|
||||
lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
|
||||
}
|
||||
auto aligned_rhs_sharding =
|
||||
hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
|
||||
auto aligned_lhs_sharding =
|
||||
hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
|
||||
|
||||
// Handling cases where all the partitioned dimensions are parallel
|
||||
// dimensions.
|
||||
int64 lhs_parallel_dim_partitions = 1;
|
||||
int64 rhs_parallel_dim_partitions = 1;
|
||||
std::vector<int64> parallel_spatial_dims;
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
int64 lhs_dim = dnums.input_spatial_dimensions(i);
|
||||
int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
|
||||
const auto& wd = hlo->window().dimensions(i);
|
||||
int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
|
||||
// Only windows without reversal are supported right now.
|
||||
if (!wd.window_reversal() &&
|
||||
dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
|
||||
parallel_spatial_dims.emplace_back(i);
|
||||
lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim);
|
||||
rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim);
|
||||
}
|
||||
}
|
||||
bool lhs_partition_dims_are_parallel =
|
||||
(lhs_parallel_dim_partitions == num_partitions_);
|
||||
bool rhs_partition_dims_are_parallel =
|
||||
(rhs_parallel_dim_partitions == num_partitions_);
|
||||
|
||||
// If there is a parallel dim and all the partitioned dimensions are parallel
|
||||
// dimensions in either LHS or RHS, simply create partitioned convolutions.
|
||||
if (!parallel_spatial_dims.empty() &&
|
||||
(lhs_partition_dims_are_parallel || rhs_partition_dims_are_parallel)) {
|
||||
// Reshard LHS or RHS to partition at parallel dimensions as the other
|
||||
// operand.
|
||||
if (lhs_partition_dims_are_parallel) {
|
||||
rhs = rhs.Reshard(aligned_rhs_sharding);
|
||||
} else {
|
||||
lhs = lhs.Reshard(aligned_lhs_sharding);
|
||||
}
|
||||
auto lhs_shard_shape =
|
||||
MakePartitionedShape(lhs.base_shape(), lhs.sharding());
|
||||
auto rhs_shard_shape =
|
||||
MakePartitionedShape(rhs.base_shape(), rhs.sharding());
|
||||
// Update convolution window.
|
||||
auto new_window = hlo->window();
|
||||
for (const auto& spatial_dim : parallel_spatial_dims) {
|
||||
auto wd = new_window.mutable_dimensions(spatial_dim);
|
||||
wd->set_size(lhs_shard_shape.dimensions(
|
||||
dnums.input_spatial_dimensions(spatial_dim)));
|
||||
wd->set_stride(std::max<int64>(1, wd->size() - 1));
|
||||
wd->set_base_dilation(wd->size());
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape sharded_conv_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
lhs_shard_shape, rhs_shard_shape, hlo->feature_group_count(),
|
||||
hlo->batch_group_count(), new_window, dnums));
|
||||
*sharded_conv_shape.mutable_layout() = hlo->shape().layout();
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve(
|
||||
sharded_conv_shape, lhs.hlo(), rhs.hlo(), hlo->feature_group_count(),
|
||||
hlo->batch_group_count(), new_window, dnums,
|
||||
hlo->precision_config()));
|
||||
sharded_conv->set_sharding(hlo->sharding());
|
||||
return PartitionedHlo(sharded_conv, hlo->shape(), MakePartitioningState())
|
||||
.Reshard(hlo->sharding())
|
||||
.hlo();
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Handling cases where both operands' shardings are aligned. We check that
// the LHS batch dimension is not partitioned because it is mapped to the
// output feature dimension in aligned_rhs_sharding, which is not the same
// dimension.
|
||||
if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) {
|
||||
if (options_.conv_halo_exchange_always_on_lhs) {
|
||||
return HandleConvolutionTiledLhsAndRhs(hlo);
|
||||
} else {
|
||||
// Reshard RHS so that each shard computes the partial sum of the full
|
||||
// shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs()
|
||||
// that reshards LHS.
|
||||
//
|
||||
// The size of halo on each dimension can be calculated from the
// projection onto the RHS that shard i needs to read. RHS and LHS below
// refer to the shard size of RHS and LHS, WC is the number of windows,
// and D is the window dilation.
//
// * offset(i): LHS * i + low_padding - (WC - 1) * stride
// * limit(i): LHS * (i + 1) + low_padding
//
// Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D)
// * left-halo: i * RHS * D - offset(i)
//   = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding
// * right-halo: limit(i) - (i + 1) * RHS * D
//   = (i + 1) * (LHS - RHS * D) + low_padding
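// Worked example (hypothetical sizes, for illustration only): with per-shard
// sizes LHS = 4 and RHS = 2, D = 1, stride = 1, low_padding = 1 and WC = 6,
// shard i = 1 needs
//   left-halo  = 1 * (2 * 1 - 4) + (6 - 1) * 1 - 1 = 2
//   right-halo = (1 + 1) * (4 - 2 * 1) + 1         = 5
// elements of the dilated RHS from its left and right neighbors.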
|
||||
auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
|
||||
const HloSharding& rhs_sharding) {
|
||||
// We currently don't support partitioning input batch or output feature
|
||||
// dimensions.
|
||||
return lhs_sharding.tile_assignment().dim(
|
||||
dnums.input_batch_dimension()) != 1 ||
|
||||
rhs_sharding.tile_assignment().dim(
|
||||
dnums.kernel_output_feature_dimension()) != 1;
|
||||
};
|
||||
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(hlo->shape().element_type())));
|
||||
if (ShapeSizeInBytes(lhs.base_shape()) <
|
||||
ShapeSizeInBytes(rhs.base_shape())) {
|
||||
if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
|
||||
rhs = rhs.PadWithValue(zero);
|
||||
} else {
|
||||
if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
lhs = lhs.PadWithValue(zero);
|
||||
rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero);
|
||||
}
|
||||
|
||||
Window window = hlo->window();
|
||||
std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
|
||||
std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
|
||||
std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
|
||||
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
|
||||
int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
|
||||
auto wd = window.dimensions(i);
|
||||
if (wd.base_dilation() != 1 || wd.window_reversal()) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
|
||||
int64 lhs_shard_size = CeilOfRatio(
|
||||
lhs.base_shape().dimensions(lhs_dimension), shard_count);
|
||||
int64 rhs_shard_size = CeilOfRatio(
|
||||
rhs.base_shape().dimensions(rhs_dimension), shard_count);
|
||||
shard_counts[i] = shard_count;
|
||||
lhs_shard_sizes[i] = lhs_shard_size;
|
||||
rhs_shard_sizes[i] = rhs_shard_size;
|
||||
}
|
||||
|
||||
std::vector<OffsetCalculation> left_halo_size_functions(
|
||||
hlo->shape().rank());
|
||||
std::vector<OffsetCalculation> right_halo_size_functions(
|
||||
hlo->shape().rank());
|
||||
Window new_window = window;
|
||||
|
||||
// Data structures needed for Pad and DynamicSlice on LHS if needed.
|
||||
bool need_dynamic_slice_lhs = false;
|
||||
auto partition_ordinals =
|
||||
MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_);
|
||||
std::vector<int64> zero_padding(hlo->shape().rank());
|
||||
PaddingConfig pad_config =
|
||||
window_util::MakeSymmetricPadding(zero_padding);
|
||||
auto zero_s32 = b_.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
|
||||
std::vector<HloInstruction*> dynamic_slice_start_indices(
|
||||
hlo->shape().rank(), zero_s32);
|
||||
Shape dynamic_slice_shape = lhs.hlo()->shape();
|
||||
Shape pad_shape = lhs.hlo()->shape();
|
||||
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
int64 lhs_dimension = dnums.input_spatial_dimensions(i);
|
||||
int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
|
||||
int64 lhs_shard_size = lhs_shard_sizes[i];
|
||||
int64 rhs_shard_size = rhs_shard_sizes[i];
|
||||
|
||||
if (shard_counts[i] == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Calculate the left and right halo sizes as described in the comments
// above. It calculates the halo sizes with dilation, so we apply
// CeilOfRatio({left,right}_halo_size, window_dilation).
|
||||
auto wd = window.dimensions(i);
|
||||
int64 padding_low = wd.padding_low();
|
||||
int64 padding_high = wd.padding_high();
|
||||
int64 base = lhs.base_shape().dimensions(lhs_dimension);
|
||||
int64 window_count =
|
||||
1 + (padding_low + padding_high + base -
|
||||
(1 + (wd.size() - 1) * wd.window_dilation())) /
|
||||
wd.stride();
|
||||
left_halo_size_functions[rhs_dimension] =
|
||||
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
|
||||
rhs_shard_size * wd.window_dilation() - lhs_shard_size,
|
||||
(window_count - 1) * wd.stride() - padding_low +
|
||||
wd.window_dilation() - 1,
|
||||
wd.window_dilation()));
|
||||
right_halo_size_functions[rhs_dimension] =
|
||||
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
|
||||
lhs_shard_size - rhs_shard_size * wd.window_dilation(),
|
||||
lhs_shard_size - rhs_shard_size * wd.window_dilation() +
|
||||
padding_low + wd.window_dilation() - 1,
|
||||
wd.window_dilation()));
|
||||
|
||||
// New RHS window size includes the maximum of both left and right
|
||||
// halos.
|
||||
int64 halo_size = left_halo_size_functions[rhs_dimension].MaxInRange(
|
||||
1, shard_counts[i]) +
|
||||
right_halo_size_functions[rhs_dimension].MaxInRange(
|
||||
0, shard_counts[i] - 1);
|
||||
int64 new_window_size =
|
||||
rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size;
|
||||
|
||||
// The amount of new low padding could be dynamic (e.g., window_dilation
|
||||
// != 1), which requires pad (to the maximum) and dynamic slice on LHS.
|
||||
//
|
||||
// If we consider the first window, the offset of the dilated RHS that
|
||||
// aligns with the first valid LHS element for shard i is 'padding_low +
|
||||
// LHS * i'. When the left halo is added to RHS, the offset of the first
|
||||
// RHS element is (RHS * i - left_halo) * window_dilation. The
|
||||
// difference between the two values is the amount of padding_low we
|
||||
// need on LHS.
|
||||
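// Spelling out that difference (added for illustration only):
// new_padding_low(i) = left_halo(i) * window_dilation
//                      - (RHS * window_dilation - LHS) * i + padding_low,
// which is what new_padding_low_function below computes.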
auto new_padding_low_function =
|
||||
OffsetCalculation(
|
||||
HloOpcode::kMultiply, left_halo_size_functions[rhs_dimension],
|
||||
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
|
||||
0, wd.window_dilation(), 1))) -
|
||||
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
|
||||
rhs_shard_size * wd.window_dilation() - lhs_shard_size,
|
||||
-padding_low, 1));
|
||||
|
||||
int64 new_padding_low_max =
|
||||
new_padding_low_function.MaxInRange(0, shard_counts[i]);
|
||||
int64 new_padding_low = new_padding_low_max;
|
||||
int64 new_padding_high = window_count * wd.stride() +
|
||||
(new_window_size - 1) * wd.window_dilation() -
|
||||
new_padding_low - lhs_shard_size;
|
||||
|
||||
// We do pad/dynamic-slice only when the padding is dynamic.
|
||||
if (!new_padding_low_function.IsConstant()) {
|
||||
need_dynamic_slice_lhs = true;
|
||||
new_padding_low = 0;
|
||||
pad_config.mutable_dimensions(lhs_dimension)
|
||||
->set_edge_padding_low(new_padding_low_max);
|
||||
pad_config.mutable_dimensions(lhs_dimension)
|
||||
->set_edge_padding_high(new_padding_low_max);
|
||||
pad_shape.set_dimensions(lhs_dimension,
|
||||
lhs_shard_size + 2 * new_padding_low_max);
|
||||
dynamic_slice_start_indices[lhs_dimension] =
|
||||
(OffsetCalculation(MultiplyAddDivideOffsetCalculation(
|
||||
0, new_padding_low_max, 1)) -
|
||||
new_padding_low_function)
|
||||
.Calculate(partition_ordinals[lhs_dimension], &b_);
|
||||
dynamic_slice_shape.set_dimensions(
|
||||
lhs_dimension, lhs_shard_size + new_padding_low_max);
|
||||
}
|
||||
|
||||
// Since the convolution RHS operand size increased with halos, adjust
|
||||
// the window config accordingly.
|
||||
new_window.mutable_dimensions(i)->set_padding_low(new_padding_low);
|
||||
new_window.mutable_dimensions(i)->set_padding_high(new_padding_high);
|
||||
new_window.mutable_dimensions(i)->set_size(
|
||||
rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size);
|
||||
}
|
||||
|
||||
HloInstruction* conv_lhs = lhs.hlo();
|
||||
if (need_dynamic_slice_lhs) {
|
||||
auto pad = b_.AddInstruction(
|
||||
HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config));
|
||||
conv_lhs = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
|
||||
dynamic_slice_shape, pad, dynamic_slice_start_indices,
|
||||
dynamic_slice_shape.dimensions()));
|
||||
}
|
||||
|
||||
// Exchange halo and concatenate.
|
||||
HloInstruction* rhs_with_halo = rhs.hlo();
|
||||
for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
|
||||
int64 dim = dnums.kernel_spatial_dimensions(i);
|
||||
int64 explicit_left_padding_on_full_shape =
|
||||
left_halo_size_functions[dim].Calculate(0);
|
||||
int64 shard_size_with_halo = new_window.dimensions(i).size();
|
||||
|
||||
// offset_on_padded_shape and padded_full_shape_size are needed only if
// we want to mask out-of-range values in ExchangeHaloAndGetValidData().
// Since the default value for the collective-permute is zero and we also
// call PadWithValue() on both operands at the beginning, we don't need to
// mask here.
|
||||
//
|
||||
// TODO(hyoulkee): Consider removing one of the two PadWithValue() calls
|
||||
// if it's always safe.
|
||||
auto offset_on_padded_shape =
|
||||
OffsetCalculation(MultiplyAddDivideOffsetCalculation(
|
||||
rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) -
|
||||
left_halo_size_functions[dim];
|
||||
int64 padded_full_shape_size =
|
||||
offset_on_padded_shape.Calculate(shard_counts[i] - 1) +
|
||||
new_window.dimensions(i).size();
|
||||
auto concat = ExchangeHaloAndGetValidData(
|
||||
rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim],
|
||||
right_halo_size_functions[dim], explicit_left_padding_on_full_shape,
|
||||
padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(),
|
||||
offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_),
|
||||
zero, partition_ordinals[dim], collective_ops_creator_,
|
||||
next_channel_id_, &b_, /*mask_invalid_region=*/false);
|
||||
if (!concat) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
rhs_with_halo = *concat;
|
||||
}
|
||||
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
auto conv = b_.AddInstruction(HloInstruction::CreateConvolve(
|
||||
hlo->shape(), conv_lhs, rhs_with_halo, hlo->feature_group_count(),
|
||||
hlo->batch_group_count(), new_window, dnums,
|
||||
hlo->precision_config()));
|
||||
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
|
||||
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
|
||||
NewChannel());
|
||||
ar->set_sharding(HloSharding::Replicate());
|
||||
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
|
||||
.Reshard(hlo->sharding())
|
||||
.hlo();
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
if (!sharding.IsTileMaximal()) {
|
||||
// We don't currently support sharding on output feature dimension.
|
||||
if (sharding.tile_assignment().dim(dnums.output_feature_dimension()) > 1) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
|
||||
// Check if the operand and the output sharding are aligned.
|
||||
std::vector<int64> input_to_output_indices(hlo->shape().rank());
|
||||
input_to_output_indices[dnums.input_batch_dimension()] =
|
||||
dnums.output_batch_dimension();
|
||||
input_to_output_indices[dnums.input_feature_dimension()] =
|
||||
dnums.output_feature_dimension();
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
input_to_output_indices[dnums.input_spatial_dimensions(i)] =
|
||||
dnums.output_spatial_dimensions(i);
|
||||
}
|
||||
auto target_operand_sharding =
|
||||
hlo_sharding_util::TransposeSharding(sharding, input_to_output_indices);
|
||||
lhs = lhs.Reshard(target_operand_sharding);
|
||||
|
||||
// Replicate the RHS.
|
||||
rhs = rhs.Reshard(HloSharding::Replicate());
|
||||
|
||||
// Convolution window config does not include batch and feature dimensions,
|
||||
// whereas ReshardAsWindowedInput() expects the same number of window
|
||||
// dimensions as the rank of the operand. So add two more trivial
|
||||
// dimensions.
|
||||
std::vector<int64> ones(hlo->shape().rank(), 1);
|
||||
auto operand_window = window_util::MakeWindow(ones);
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
*operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) =
|
||||
hlo->window().dimensions(i);
|
||||
}
|
||||
|
||||
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(hlo->shape().element_type())));
|
||||
auto resharded_operand_and_window = lhs.ReshardAsWindowedInput(
|
||||
operand_window, target_operand_sharding, zero);
|
||||
if (!resharded_operand_and_window.has_value()) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
Window new_window;
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
*new_window.add_dimensions() =
|
||||
resharded_operand_and_window->shard_window.dimensions(
|
||||
dnums.input_spatial_dimensions(i));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape sharded_conv_shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
resharded_operand_and_window->sharded_input->shape(),
|
||||
rhs.hlo()->shape(), hlo->feature_group_count(),
|
||||
hlo->batch_group_count(), new_window, dnums));
|
||||
auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
|
||||
*sharded_conv_shape.mutable_layout() = shard_shape.layout();
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve(
|
||||
sharded_conv_shape, resharded_operand_and_window->sharded_input,
|
||||
rhs.hlo(), hlo->feature_group_count(), hlo->batch_group_count(),
|
||||
new_window, dnums, hlo->precision_config()));
|
||||
if (!resharded_operand_and_window->dynamic_slice_index_on_output
|
||||
.has_value()) {
|
||||
CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape()));
|
||||
return sharded_conv;
|
||||
}
|
||||
return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
|
||||
shard_shape, sharded_conv,
|
||||
*resharded_operand_and_window->dynamic_slice_index_on_output,
|
||||
shard_shape.dimensions()));
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
|
||||
} // namespace spmd
|
||||
} // namespace xla
|
1211
tensorflow/compiler/xla/service/spmd/dot_handler.cc
Normal file
File diff suppressed because it is too large
@ -885,5 +885,46 @@ int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
|
||||
return sharding.tile_assignment().dim(dim);
|
||||
}
|
||||
|
||||
absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
const HloSharding& source, const HloSharding& target) {
if (source.IsTileMaximal() || target.IsTileMaximal() ||
source.tile_assignment().num_dimensions() !=
target.tile_assignment().num_dimensions()) {
return absl::nullopt;
}
int64 source_dim = -1;
int64 target_dim = -1;
for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
if (source.tile_assignment().dim(i) > 1 &&
target.tile_assignment().dim(i) == 1) {
if (source_dim != -1) {
return absl::nullopt;
}
source_dim = i;
} else if (source.tile_assignment().dim(i) == 1 &&
target.tile_assignment().dim(i) > 1) {
if (target_dim != -1) {
return absl::nullopt;
}
target_dim = i;
} else if (source.tile_assignment().dim(i) !=
target.tile_assignment().dim(i)) {
return absl::nullopt;
}
}
if (source_dim == -1 || target_dim == -1 || source_dim == target_dim) {
return absl::nullopt;
}
return std::pair<int64, int64>(source_dim, target_dim);
}
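// Example (hypothetical tile assignments, for illustration only): resharding
// from tile assignment dimensions [4,1,2] to [1,4,2] moves the partitioning
// from dim 0 to dim 1, so this returns the pair (0, 1). Resharding [4,1] to
// [2,2] changes a dimension's shard count rather than moving it, so it
// returns absl::nullopt.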
|
||||
|
||||
bool CanReshardWithCollectivePermute(const HloSharding& source,
|
||||
const HloSharding& target) {
|
||||
return !source.IsTileMaximal() && !target.IsTileMaximal() &&
|
||||
source.tile_assignment().dimensions() ==
|
||||
target.tile_assignment().dimensions() &&
|
||||
source.tile_assignment() != target.tile_assignment();
|
||||
}
|
||||
|
||||
} // namespace spmd
|
||||
} // namespace xla
|
||||
|
@ -265,6 +265,15 @@ HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
|
||||
// Returns the number of shards along the given dimension.
int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);

// Returns the pair of source and target dimensions if the resharding can be
// done via all-to-all.
|
||||
absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
|
||||
const HloSharding& source, const HloSharding& target);
|
||||
|
||||
// Returns whether the resharding can be done via collective-permute.
|
||||
bool CanReshardWithCollectivePermute(const HloSharding& source,
|
||||
const HloSharding& target);
|
||||
|
||||
} // namespace spmd
|
||||
} // namespace xla
|
||||
|
||||
|
@ -270,8 +270,8 @@ message DebugOptions {
|
||||
// Paths to files with ptx code.
|
||||
repeated string xla_gpu_ptx_file = 127;
|
||||
|
||||
// Blacklist for cuDNN convolutions.
|
||||
string xla_gpu_algorithm_blacklist_path = 128;
|
||||
// Denylist for cuDNN convolutions.
|
||||
string xla_gpu_algorithm_denylist_path = 128;
|
||||
|
||||
// Guarantee run-to-run determinism from reductions on XLA:GPU.
|
||||
bool xla_gpu_deterministic_reductions = 130;
|
||||
|
@ -163,6 +163,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:platform_port",
|
||||
"//tensorflow/core/util:abstract_stack_trace",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
|
@ -306,6 +306,7 @@ Status EagerOperation::Reset(
|
||||
}
|
||||
attrs_.Reset(op);
|
||||
use_xla_ = false;
|
||||
stack_trace_.reset();
|
||||
is_function_ = is_function;
|
||||
cancellation_manager_ = nullptr;
|
||||
executor_ = executor ? executor : &ctx_.Executor();
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/util/abstract_stack_trace.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -120,6 +121,14 @@ class EagerOperation : public ImmediateExecutionOperation {
|
||||
|
||||
Status SetUseXla(bool enable) override;
|
||||
|
||||
void SetStackTrace(AbstractStackTrace stack_trace) override {
|
||||
stack_trace_ = stack_trace;
|
||||
}
|
||||
|
||||
absl::optional<AbstractStackTrace> GetStackTrace() override {
|
||||
return stack_trace_;
|
||||
}
|
||||
|
||||
Status Reset(const char* op, const char* device_name, bool remote,
|
||||
EagerExecutor* executor,
|
||||
const absl::optional<EagerRemoteFunctionParams>
|
||||
@ -218,6 +227,7 @@ class EagerOperation : public ImmediateExecutionOperation {
|
||||
VariantDevice device_;
|
||||
|
||||
bool use_xla_ = false;
|
||||
absl::optional<AbstractStackTrace> stack_trace_;
|
||||
bool is_function_; // Conceptually const, but can't be because of Reset
|
||||
bool colocation_exempt_;
|
||||
CancellationManager* cancellation_manager_ = nullptr; // Not owned.
|
||||
|
@ -634,7 +634,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
auto node = absl::make_unique<AsyncExecuteNode>(
|
||||
&ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
|
||||
graph_collector, op->GetCancellationManager(),
|
||||
absl::Span<TensorHandle*>(retvals, num_outputs));
|
||||
absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
|
||||
// Release the inputs from the eager operation since the AsyncExecuteNode
|
||||
// would have taken ownership. This allows the inputs to be forwarded if
|
||||
// possible.
|
||||
|
@ -150,14 +150,16 @@ class AsyncExecuteNode : public EagerNode {
|
||||
core::RefCountPtr<KernelAndDevice> kernel,
|
||||
GraphCollector* graph_collector,
|
||||
CancellationManager* cancellation_manager,
|
||||
absl::Span<TensorHandle*> retvals)
|
||||
absl::Span<TensorHandle*> retvals,
|
||||
absl::optional<AbstractStackTrace> stack_trace)
|
||||
: EagerNode(),
|
||||
ctx_(ctx),
|
||||
inputs_(inputs),
|
||||
remote_func_params_(remote_func_params),
|
||||
kernel_(std::move(kernel)),
|
||||
graph_collector_(graph_collector),
|
||||
cancellation_manager_(cancellation_manager) {
|
||||
cancellation_manager_(cancellation_manager),
|
||||
stack_trace_(stack_trace) {
|
||||
// Copy the output handles, since the container for them might get
|
||||
// destroyed.
|
||||
for (auto handle : retvals) {
|
||||
@ -194,10 +196,14 @@ class AsyncExecuteNode : public EagerNode {
|
||||
}
|
||||
++i;
|
||||
}
|
||||
const Status status = EagerKernelExecute(
|
||||
Status status = EagerKernelExecute(
|
||||
ctx_, inputs_, remote_func_params_, kernel_, graph_collector_,
|
||||
cancellation_manager_, absl::MakeSpan(retvals_));
|
||||
if (!status.ok()) {
|
||||
if (stack_trace_.has_value()) {
|
||||
status = Status(status.code(), status.error_message(),
|
||||
stack_trace_->ToStackFrames());
|
||||
}
|
||||
Abort(status);
|
||||
return status;
|
||||
}
|
||||
@ -227,6 +233,7 @@ class AsyncExecuteNode : public EagerNode {
|
||||
core::RefCountPtr<KernelAndDevice> kernel_;
|
||||
GraphCollector* graph_collector_;
|
||||
CancellationManager* const cancellation_manager_;
|
||||
absl::optional<AbstractStackTrace> stack_trace_;
|
||||
absl::InlinedVector<TensorHandle*, 2> retvals_;
|
||||
};
|
||||
|
||||
|
@ -159,6 +159,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//tensorflow/core/platform:tf32_utils",
|
||||
"//tensorflow/core/profiler/lib:annotated_traceme",
|
||||
"//tensorflow/core/profiler/lib:scoped_annotation",
|
||||
"//third_party/eigen3",
|
||||
|
@ -277,6 +277,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
|
||||
// RendezvousMgr already aborted, shouldn't send RPC call any more
|
||||
if (!call->status().ok()) {
|
||||
DeregisterCall(call);
|
||||
// NOTE: `*sess` can potentially be deleted before we return from
|
||||
// `call->done()(...)`, so we must release the worker before calling the
|
||||
// callback.
|
||||
|
@ -972,9 +972,9 @@ double Node::OutputTime(absl::flat_hash_map<string, double>* input_times,
|
||||
return output_times[long_name()];
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> Node::Snapshot(std::shared_ptr<Node> output) const {
|
||||
std::shared_ptr<Node> Node::Snapshot() const {
|
||||
NodePairList node_pairs;
|
||||
auto result = SnapshotHelper(output, &node_pairs);
|
||||
auto result = SnapshotHelper(nullptr, &node_pairs);
|
||||
|
||||
while (!node_pairs.empty()) {
|
||||
auto node_pair = node_pairs.front();
|
||||
@ -1346,7 +1346,7 @@ void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget) {
|
||||
std::shared_ptr<Node> snapshot;
|
||||
{
|
||||
tf_shared_lock lock(mu_);
|
||||
snapshot = output_->Snapshot(nullptr);
|
||||
snapshot = output_->Snapshot();
|
||||
}
|
||||
VLOG(2) << "Starting optimization of tunable parameters with GradientDescent";
|
||||
auto parameters = CollectTunableParameters(snapshot);
|
||||
@ -1422,7 +1422,7 @@ void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget) {
|
||||
std::shared_ptr<Node> snapshot;
|
||||
{
|
||||
tf_shared_lock lock(mu_);
|
||||
snapshot = output_->Snapshot(nullptr);
|
||||
snapshot = output_->Snapshot();
|
||||
}
|
||||
VLOG(2) << "Starting optimization of tunable parameters with HillClimb";
|
||||
const double processing_time = TotalProcessingTime(snapshot);
|
||||
|
@ -339,8 +339,7 @@ class Node {
|
||||
//
|
||||
// The purpose for this method is to allow the model optimization logic to
|
||||
// operate over immutable state while allowing concurrent model updates.
|
||||
std::shared_ptr<Node> Snapshot(std::shared_ptr<Node> output) const
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
std::shared_ptr<Node> Snapshot() const TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Returns the per-element processing time spent in this node.
|
||||
double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_);
|
||||
|
@ -756,7 +756,7 @@ TEST(SnapshotTest, Model) {
|
||||
cur_node = cur_node->inputs().front();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> root_copy = root->Snapshot(nullptr);
|
||||
std::shared_ptr<Node> root_copy = root->Snapshot();
|
||||
cur_node = root;
|
||||
std::shared_ptr<Node> cur_node_copy = root_copy;
|
||||
|
||||
|
@ -293,7 +293,7 @@ class NodeTypeAttrMap {
|
||||
}
|
||||
// Note that the mappings generated here include inputs/outputs with fixed
|
||||
// types. This makes the mappings complete (all inputs and outputs are
|
||||
// included), and allows the graph rewriter to propagate black paint
|
||||
// included), and allows the graph rewriter to propagate deny paint
|
||||
// from/through ops with fixed types.
|
||||
io2type_entry.first.reserve(input_arg_inds.size());
|
||||
for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
|
||||
@ -843,10 +843,10 @@ DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
|
||||
}
|
||||
|
||||
Status ValidateLists(const gtl::FlatSet<string>& allow_list,
|
||||
const gtl::FlatSet<string>& black_list,
|
||||
const gtl::FlatSet<string>& gray_list,
|
||||
const gtl::FlatSet<string>& deny_list,
|
||||
const gtl::FlatSet<string>& infer_list,
|
||||
const gtl::FlatSet<string>& clear_list) {
|
||||
std::vector<gtl::FlatSet<string>> lists{allow_list, black_list, gray_list,
|
||||
std::vector<gtl::FlatSet<string>> lists{allow_list, deny_list, infer_list,
|
||||
clear_list};
|
||||
std::multiset<string> counts;
|
||||
for (const auto& list : lists) {
|
||||
@ -967,23 +967,23 @@ class AutoMixedPrecisionImpl {
|
||||
bool SupportsF16(const NodeTypeId& node_type) const;
|
||||
const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
|
||||
bool IsSourceOrSinkOp(const string& op) const;
|
||||
void FindFloat32TensorListOpClustersAndBlacklistUnsafe(
|
||||
void FindFloat32TensorListOpClustersAndDenylistUnsafe(
|
||||
std::vector<absl::flat_hash_set<const NodeDef*>>* clusters,
|
||||
absl::flat_hash_set<int>* black_set) const;
|
||||
absl::flat_hash_set<int>* deny_set) const;
|
||||
void FindTensorListImplicitFloat32Edges(
|
||||
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
||||
std::vector<NodeTypeIdEdge>* implicit_data_edges) const;
|
||||
void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
|
||||
void PropagateBlackFwdThroughClearAndGray(
|
||||
absl::flat_hash_set<int>* black_set) const;
|
||||
void PropagateDenyFwdThroughClearAndInfer(
|
||||
absl::flat_hash_set<int>* deny_set) const;
|
||||
void ForceColorMatchBetweenTensorListOps(
|
||||
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
||||
absl::flat_hash_set<int>* allow_set,
|
||||
absl::flat_hash_set<int>* black_set) const;
|
||||
void AddClearAndGrayToAllowIfBetweenAllow(
|
||||
const absl::flat_hash_set<int>& black_set,
|
||||
absl::flat_hash_set<int>* deny_set) const;
|
||||
void AddClearAndInferToAllowIfBetweenAllow(
|
||||
const absl::flat_hash_set<int>& deny_set,
|
||||
absl::flat_hash_set<int>* allow_set) const;
|
||||
void PropagateAllowThroughClear(const absl::flat_hash_set<int>& black_set,
|
||||
void PropagateAllowThroughClear(const absl::flat_hash_set<int>& deny_set,
|
||||
absl::flat_hash_set<int>* allow_set) const;
|
||||
Status ForceColorMatchOnRecurrentEdges(
|
||||
absl::flat_hash_set<int>* allow_set) const;
|
||||
@ -1006,8 +1006,8 @@ class AutoMixedPrecisionImpl {
|
||||
bool force_all_fp16_;
|
||||
AutoMixedPrecisionMode mode_;
|
||||
gtl::FlatSet<string> f16_allowlist_;
|
||||
gtl::FlatSet<string> f16_blacklist_;
|
||||
gtl::FlatSet<string> f16_graylist_;
|
||||
gtl::FlatSet<string> f16_denylist_;
|
||||
gtl::FlatSet<string> f16_inferlist_;
|
||||
gtl::FlatSet<string> f16_clearlist_;
|
||||
absl::flat_hash_set<const NodeDef*> should_process_nodes_;
|
||||
DataType target_dtype_; // Either DT_HALF or DT_BFLOAT16
|
||||
@ -1083,12 +1083,12 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
|
||||
for (const auto& x : mp_lists->AllowList()) {
|
||||
f << x << "\n";
|
||||
}
|
||||
f << "\nBlackList:\n";
|
||||
for (const auto& x : mp_lists->BlackList()) {
|
||||
f << "\nDenyList:\n";
|
||||
for (const auto& x : mp_lists->DenyList()) {
|
||||
f << x << "\n";
|
||||
}
|
||||
f << "\nGrayList:\n";
|
||||
for (const auto& x : mp_lists->GrayList()) {
|
||||
f << "\nInferList:\n";
|
||||
for (const auto& x : mp_lists->InferList()) {
|
||||
f << x << "\n";
|
||||
}
|
||||
f << "\nClearList:\n";
|
||||
@ -1255,11 +1255,11 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
||||
std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
|
||||
get_mixed_precision_lists();
|
||||
f16_allowlist_ = mp_lists->AllowList();
|
||||
f16_blacklist_ = mp_lists->BlackList();
|
||||
f16_graylist_ = mp_lists->GrayList();
|
||||
f16_denylist_ = mp_lists->DenyList();
|
||||
f16_inferlist_ = mp_lists->InferList();
|
||||
f16_clearlist_ = mp_lists->ClearList();
|
||||
TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_blacklist_,
|
||||
f16_graylist_, f16_clearlist_));
|
||||
TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
|
||||
f16_inferlist_, f16_clearlist_));
|
||||
|
||||
size_t timestamp = Env::Default()->NowMicros() / 1000;
|
||||
TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
|
||||
@ -1294,11 +1294,11 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));
|
||||
|
||||
absl::flat_hash_set<int> black_set;
|
||||
absl::flat_hash_set<int> deny_set;
|
||||
|
||||
std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
|
||||
FindFloat32TensorListOpClustersAndBlacklistUnsafe(&tensor_list_clusters,
|
||||
&black_set);
|
||||
FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
|
||||
&deny_set);
|
||||
std::vector<NodeTypeIdEdge> ephemeral_edges;
|
||||
for (const auto& cluster : tensor_list_clusters) {
|
||||
VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
|
||||
@ -1320,14 +1320,14 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
||||
// This is done under the assumption that allowlist ops are always
|
||||
// numerically-safe in f16 and that they are the most important ops for
|
||||
// improving performance.
|
||||
// 2) Add nodes to the black_set iff they are numerically-dangerous (aka
|
||||
// "blacklist" ops) or they are on a forward path from a blacklist node to
|
||||
// a black/gray node (including the node at the end of the path) through
|
||||
// non-numerically-dangerous ops (aka "greylist" and "clearlist" ops).
|
||||
// 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
|
||||
// "denylist" ops) or they are on a forward path from a denylist node to
|
||||
// a deny/infer node (including the node at the end of the path) through
|
||||
// non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
|
||||
// This is done to prevent numerically-dangerous ops and their downstream
|
||||
// effects from being changed to f16, which would risk breaking the
|
||||
// numerical accuracy of the model.
|
||||
// 3) For all remaining nodes that are not considered dangerous (greylist
|
||||
// 3) For all remaining nodes that are not considered dangerous (inferlist
|
||||
// and clearlist ops), find those that are between (i.e., both upstream
|
||||
// and downstream of) allow nodes, and add them to the allow_set.
|
||||
// This is done to avoid unnecessary casts between allowlist ops.
|
||||
@ -1346,29 +1346,29 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VLOG(2) << "Beginning pass 2 to propagate black forwards from blacklist ops "
|
||||
"through clear/graylist ops";
|
||||
PropagateBlackFwdThroughClearAndGray(&black_set);
|
||||
VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
|
||||
"through clear/inferlist ops";
|
||||
PropagateDenyFwdThroughClearAndInfer(&deny_set);
|
||||
VLOG(2) << "Finished pass 2";
|
||||
|
||||
VLOG(2) << "Forcing color match between data structure ops";
|
||||
for (const auto& cluster : tensor_list_clusters) {
|
||||
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set);
|
||||
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
|
||||
}
|
||||
|
||||
VLOG(2) << "Beginning pass 3 to set clear and gray nodes to allow if they "
|
||||
VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
|
||||
"are between allow ops";
|
||||
AddClearAndGrayToAllowIfBetweenAllow(black_set, &allow_set);
|
||||
AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
|
||||
VLOG(2) << "Finished pass 3";
|
||||
|
||||
VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
|
||||
"clearlist ops";
|
||||
PropagateAllowThroughClear(black_set, &allow_set);
|
||||
PropagateAllowThroughClear(deny_set, &allow_set);
|
||||
VLOG(2) << "Finished pass 4";
|
||||
|
||||
VLOG(2) << "Forcing color match between data structure ops";
|
||||
for (const auto& cluster : tensor_list_clusters) {
|
||||
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set);
|
||||
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
|
||||
}
|
||||
|
||||
VLOG(2) << "Forcing color match on loop edges";
|
||||
@ -1426,11 +1426,11 @@ bool AutoMixedPrecisionImpl::IsSourceOrSinkOp(const string& op) const {
|
||||
// Finds all clusters of float32 Tensor List nodes that are connected via their
|
||||
// handle edges. Unsafe clusters (those with unprocessable nodes, or with edges
|
||||
// that cross untraversable boundaries via _Arg, _Ret, PartitionedCall etc.
|
||||
// nodes) are added to black_set. The caller should paint all nodes in a cluster
|
||||
// nodes) are added to deny_set. The caller should paint all nodes in a cluster
|
||||
// the same color, as they may all refer to the same Tensor List.
|
||||
void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndBlacklistUnsafe(
|
||||
void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndDenylistUnsafe(
|
||||
std::vector<absl::flat_hash_set<const NodeDef*>>* tensor_list_clusters,
|
||||
absl::flat_hash_set<int>* black_set) const {
|
||||
absl::flat_hash_set<int>* deny_set) const {
|
||||
absl::flat_hash_set<const NodeDef*> tensor_list_prop_set;
|
||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||
@ -1463,7 +1463,7 @@ void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndBlacklistUnsafe(
|
||||
cluster.insert(node);
|
||||
if (!ShouldProcess(*node)) {
|
||||
// The cluster contains an un-processable node.
|
||||
black_set->insert(root_fp32_idx);
|
||||
deny_set->insert(root_fp32_idx);
|
||||
}
|
||||
// TODO(benbarsdell): In a theoretical pathological
|
||||
// case of a Tensor List of Tensor List handles, the
|
||||
@ -1471,7 +1471,7 @@ void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndBlacklistUnsafe(
|
||||
// sink.
|
||||
} else if (IsSourceOrSinkOp(node->op())) {
|
||||
// The cluster crosses an untraversable boundary.
|
||||
black_set->insert(root_fp32_idx);
|
||||
deny_set->insert(root_fp32_idx);
|
||||
}
|
||||
}));
|
||||
tensor_list_clusters->push_back(cluster);
|
||||
@ -1534,21 +1534,21 @@ void AutoMixedPrecisionImpl::AddAllowlistOps(
|
||||
}
|
||||
}
|
||||
|
||||
// Adds nodes to black_set iff they are on the blacklist or they are on a
|
||||
// forward path from a blacklist node to a black/gray node (including the node
|
||||
// at the end of the path) through clear and gray nodes.
|
||||
// E.g., black -> gray -> clear -> gray -> clear -> allow -> gray
|
||||
// becomes: black -> black -> black -> black -> clear -> allow -> gray.
|
||||
void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
|
||||
absl::flat_hash_set<int>* black_set) const {
|
||||
// Adds nodes to deny_set iff they are on the denylist or they are on a
|
||||
// forward path from a denylist node to a deny/infer node (including the node
|
||||
// at the end of the path) through clear and infer nodes.
|
||||
// E.g., deny -> infer -> clear -> infer -> clear -> allow -> infer
|
||||
// becomes: deny -> deny -> deny -> deny -> clear -> allow -> infer.
|
||||
void AutoMixedPrecisionImpl::PropagateDenyFwdThroughClearAndInfer(
|
||||
absl::flat_hash_set<int>* deny_set) const {
|
||||
if (force_all_fp16_) return;
|
||||
|
||||
// Find clear nodes that are upstream of black or gray.
|
||||
absl::flat_hash_set<int> upstream_of_black_or_gray_set;
|
||||
// Find clear nodes that are upstream of deny or infer.
|
||||
absl::flat_hash_set<int> upstream_of_deny_or_infer_set;
|
||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||
if (!(f16_blacklist_.count(root.node->op()) ||
|
||||
f16_graylist_.count(root.node->op()))) {
|
||||
if (!(f16_denylist_.count(root.node->op()) ||
|
||||
f16_inferlist_.count(root.node->op()))) {
|
||||
continue;
|
||||
}
|
||||
DfsTypeTraversal(graph_type_view_, {&root},
|
||||
@ -1556,42 +1556,42 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
|
||||
DfsTypePredicates::Enter([&](int idx) -> bool {
|
||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||
return idx == root_idx ||
|
||||
(!upstream_of_black_or_gray_set.count(idx) &&
|
||||
(!upstream_of_deny_or_infer_set.count(idx) &&
|
||||
f16_clearlist_.count(item.node->op()));
|
||||
}),
|
||||
DfsTypeCallbacks::PreOrder([&](int idx) {
|
||||
upstream_of_black_or_gray_set.insert(idx);
|
||||
upstream_of_deny_or_infer_set.insert(idx);
|
||||
}));
|
||||
}
|
||||
|
||||
// Propagate black forward through nodes in upstream_of_black_or_gray_set.
|
||||
// Propagate deny forward through nodes in upstream_of_deny_or_infer_set.
|
||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||
if (black_set->count(root_idx) || !f16_blacklist_.count(root.node->op())) {
|
||||
if (deny_set->count(root_idx) || !f16_denylist_.count(root.node->op())) {
|
||||
continue;
|
||||
}
|
||||
DfsTypeTraversal(
|
||||
graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
|
||||
DfsTypePredicates::Enter([&](int idx) -> bool {
|
||||
return idx == root_idx || (!black_set->count(idx) &&
|
||||
upstream_of_black_or_gray_set.count(idx));
|
||||
return idx == root_idx || (!deny_set->count(idx) &&
|
||||
upstream_of_deny_or_infer_set.count(idx));
|
||||
}),
|
||||
DfsTypeCallbacks::PreOrder([&](int idx) {
|
||||
bool inserted = black_set->insert(idx).second;
|
||||
bool inserted = deny_set->insert(idx).second;
|
||||
if (VLOG_IS_ON(2) && inserted) {
|
||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||
VLOG(2) << "Painting type " << item.type_attr.DebugString()
|
||||
<< " of " << item.node->op() << " node "
|
||||
<< item.node->name() << " BLACK";
|
||||
<< item.node->name() << " DENY";
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
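To make the two-pass painting rule above easier to follow, here is a small self-contained C++ sketch. It is illustrative only, not TensorFlow code: the Paint enum, the linear-chain graph, and the simplified backward/forward walks are assumptions standing in for the DFS over graph_type_view_. Run on the chain from the E.g. comment, it reproduces deny -> deny -> deny -> deny -> clear -> allow -> infer.

// toy_deny_propagation.cc -- illustrative only, not part of this diff.
#include <cstdio>
#include <set>
#include <vector>

enum class Paint { kDeny, kInfer, kClear, kAllow };

int main() {
  // Linear chain matching the E.g. in the comment:
  // deny -> infer -> clear -> infer -> clear -> allow -> infer
  std::vector<Paint> color = {Paint::kDeny,  Paint::kInfer, Paint::kClear,
                              Paint::kInfer, Paint::kClear, Paint::kAllow,
                              Paint::kInfer};
  const int n = static_cast<int>(color.size());

  // Pass 1: collect clear nodes (and the deny/infer roots themselves) that lie
  // on a backward path of clear nodes starting at a deny or infer node.
  std::set<int> upstream_of_deny_or_infer;
  for (int root = 0; root < n; ++root) {
    if (color[root] != Paint::kDeny && color[root] != Paint::kInfer) continue;
    upstream_of_deny_or_infer.insert(root);
    for (int i = root - 1; i >= 0 && color[i] == Paint::kClear; --i) {
      upstream_of_deny_or_infer.insert(i);
    }
  }

  // Pass 2: from each deny node, walk forward through nodes in that set,
  // painting everything reached as deny.
  std::set<int> deny_set;
  for (int root = 0; root < n; ++root) {
    if (color[root] != Paint::kDeny) continue;
    deny_set.insert(root);
    for (int i = root + 1; i < n && upstream_of_deny_or_infer.count(i); ++i) {
      deny_set.insert(i);
    }
  }

  for (int i = 0; i < n; ++i) {
    std::printf("node %d: %s\n", i, deny_set.count(i) ? "DENY" : "unchanged");
  }
  // Prints DENY for nodes 0..3 and "unchanged" for 4..6, matching the comment.
  return 0;
}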
|
||||
void AutoMixedPrecisionImpl::AddClearAndGrayToAllowIfBetweenAllow(
|
||||
const absl::flat_hash_set<int>& black_set,
|
||||
void AutoMixedPrecisionImpl::AddClearAndInferToAllowIfBetweenAllow(
|
||||
const absl::flat_hash_set<int>& deny_set,
|
||||
absl::flat_hash_set<int>* allow_set) const {
|
||||
// Find clear/graylist ops that are downstream of allow ops.
|
||||
// Find clear/inferlist ops that are downstream of allow ops.
|
||||
absl::flat_hash_set<int> downstream_of_allow_set;
|
||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||
@ -1605,13 +1605,13 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToAllowIfBetweenAllow(
|
||||
return idx == root_idx ||
|
||||
(!downstream_of_allow_set.count(idx) &&
|
||||
!f16_allowlist_.count(item.node->op()) &&
|
||||
!black_set.count(idx) && ShouldProcess(*item.node) &&
|
||||
!deny_set.count(idx) && ShouldProcess(*item.node) &&
|
||||
// TODO(benbarsdell): Consider allowing propagation through
|
||||
// ops that are already float16 in order to reduce the number
|
||||
// of casts.
|
||||
IsFloat32(item) && SupportsF16(item) &&
|
||||
(f16_clearlist_.count(item.node->op()) ||
|
||||
f16_graylist_.count(item.node->op())));
|
||||
f16_inferlist_.count(item.node->op())));
|
||||
}),
|
||||
DfsTypeCallbacks::PreOrder(
|
||||
[&](int idx) { downstream_of_allow_set.insert(idx); }));
|
||||
@ -1645,7 +1645,7 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToAllowIfBetweenAllow(
|
||||
}
|
||||
|
||||
void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
|
||||
const absl::flat_hash_set<int>& black_set,
|
||||
const absl::flat_hash_set<int>& deny_set,
|
||||
absl::flat_hash_set<int>* allow_set) const {
|
||||
// Propagate allow from allow nodes through clearlist ops.
|
||||
absl::flat_hash_set<int> clear_prop_set;
|
||||
@ -1661,7 +1661,7 @@ void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
|
||||
DfsTypePredicates::Enter([&](int idx) -> bool {
|
||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||
return idx == root_idx ||
|
||||
(!allow_set->count(idx) && !black_set.count(idx) &&
|
||||
(!allow_set->count(idx) && !deny_set.count(idx) &&
|
||||
ShouldProcess(*item.node) && IsFloat32(item) &&
|
||||
SupportsF16(item) &&
|
||||
(f16_clearlist_.count(item.node->op())) &&
|
||||
@ -1727,14 +1727,14 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
||||
if (allow_set->erase(merge_idx)) {
|
||||
VLOG(2) << "Painting type T of Merge node "
|
||||
<< graph_type_view_.GetNode(merge_idx)->node->name()
|
||||
<< " BLACK to match the color of its sibling Merge nodes "
|
||||
<< " DENY to match the color of its sibling Merge nodes "
|
||||
"with common NextIteration node "
|
||||
<< node.name();
|
||||
}
|
||||
}
|
||||
if (allow_set->erase(nextiter_idx)) {
|
||||
VLOG(2) << "Painting type T of NextIteration node " << node.name()
|
||||
<< " BLACK to match the color of its output Merge node(s)";
|
||||
<< " DENY to match the color of its output Merge node(s)";
|
||||
}
|
||||
} else {
|
||||
if (allow_set->insert(nextiter_idx).second) {
|
||||
@ -1751,8 +1751,8 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
||||
void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
|
||||
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
||||
absl::flat_hash_set<int>* allow_set,
|
||||
absl::flat_hash_set<int>* black_set) const {
|
||||
bool any_black = false;
|
||||
absl::flat_hash_set<int>* deny_set) const {
|
||||
bool any_deny = false;
|
||||
bool any_allow = false;
|
||||
std::vector<int> node_type_idxs;
|
||||
node_type_idxs.reserve(tensor_list_nodes.size());
|
||||
@ -1766,24 +1766,24 @@ void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
|
||||
node_type_idxs.push_back(maybe_node_type_idx.value());
|
||||
}
|
||||
for (int node_type_idx : node_type_idxs) {
|
||||
if (black_set->count(node_type_idx)) {
|
||||
any_black = true;
|
||||
if (deny_set->count(node_type_idx)) {
|
||||
any_deny = true;
|
||||
break;
|
||||
} else if (allow_set->count(node_type_idx)) {
|
||||
any_allow = true;
|
||||
}
|
||||
}
|
||||
  if (!any_black && !any_allow) return;
  if (!any_deny && !any_allow) return;
  for (int node_type_idx : node_type_idxs) {
    const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
    VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
            << node_type.node->op() << " node " << node_type.node->name() << " "
            << (any_black ? "BLACK" : "ALLOW")
            << (any_deny ? "DENY" : "ALLOW")
            << " because at least one of its siblings is "
            << (any_black ? "BLACK" : "ALLOW");
    if (any_black) {
            << (any_deny ? "DENY" : "ALLOW");
    if (any_deny) {
      allow_set->erase(node_type_idx);
      black_set->insert(node_type_idx);
      deny_set->insert(node_type_idx);
    } else {
      allow_set->insert(node_type_idx);
    }
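The sibling color matching above reduces to a simple rule: any deny sibling forces the whole group to deny, otherwise any allow sibling forces the whole group to allow. A minimal standalone sketch of that resolution step follows (the plain set-based types and names are illustrative, not the real graph_type_view_ machinery):

// Illustrative only: resolve one color for a group of sibling tensor-list ops.
#include <cassert>
#include <set>
#include <vector>

// If any sibling is deny, every sibling becomes deny; otherwise, if any
// sibling is allow, every sibling becomes allow; otherwise nothing changes.
void ForceColorMatch(const std::vector<int>& siblings, std::set<int>* allow_set,
                     std::set<int>* deny_set) {
  bool any_deny = false, any_allow = false;
  for (int idx : siblings) {
    if (deny_set->count(idx)) { any_deny = true; break; }
    if (allow_set->count(idx)) any_allow = true;
  }
  if (!any_deny && !any_allow) return;
  for (int idx : siblings) {
    if (any_deny) {
      allow_set->erase(idx);
      deny_set->insert(idx);
    } else {
      allow_set->insert(idx);
    }
  }
}

int main() {
  std::set<int> allow = {1};
  std::set<int> deny = {2};
  ForceColorMatch({0, 1, 2}, &allow, &deny);  // 2 is deny, so all become deny.
  assert(deny == std::set<int>({0, 1, 2}) && allow.empty());
  return 0;
}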
@ -23,7 +23,7 @@ limitations under the License.
|
||||
namespace tensorflow {
namespace grappler {

// Represents the four lists of ops: the allow list, gray list, black list, and
// Represents the four lists of ops: the allow list, infer list, deny list, and
// clear list. These lists determine which ops are converted to fp16/bf16
// (referred to as 'f16' for short) and which ops stay as fp32.
class AutoMixedPrecisionLists {
@ -36,13 +36,13 @@ class AutoMixedPrecisionLists {
  virtual gtl::FlatSet<string> AllowList() = 0;
  // Returns the set of ops that can run in f16 and are considered numerically-
  // safe (for execution in f16), but which may be made unsafe by an upstream
  // blacklist op.
  virtual gtl::FlatSet<string> GrayList() = 0;
  // denylist op.
  virtual gtl::FlatSet<string> InferList() = 0;
  // Returns the set of ops that are considered numerically-dangerous (i.e.,
  // unsafe for execution in f16) and whose effects may also be observed in
  // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to
  // the Exp).
  virtual gtl::FlatSet<string> BlackList() = 0;
  virtual gtl::FlatSet<string> DenyList() = 0;
  // Returns the set of ops that do not have numerically-significant effects
  // (i.e., they are always considered safe for execution in f16 precision), and
  // can run in f16.
@ -51,10 +51,11 @@ class AutoMixedPrecisionLists {
 protected:
  // Adds or removes ops from list if certain environmental variables are set.
  static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
    CHECK(list_name == "ALLOWLIST" || list_name == "GRAYLIST" ||  // Crash OK.
          list_name == "BLACKLIST" || list_name == "CLEARLIST" ||
    CHECK(list_name == "ALLOWLIST" || list_name == "INFERLIST" ||  // Crash OK.
          list_name == "DENYLIST" || list_name == "CLEARLIST" ||
          // TODO(reedwm): for bkwds compat; remove when no longer necessary:
          list_name == "WHITELIST");
          list_name == "WHITELIST" || list_name == "GRAYLIST" ||
          list_name == "BLACKLIST");
    string add_env_var =
        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
    string remove_env_var =
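As a usage sketch of the UpdateList hook above: a process can add or remove ops via environment variables before the rewrite runs. The "_ADD" suffix appears in this hunk; the "_REMOVE" suffix and the op name "MyCustomMatMul" below are assumptions made purely for illustration (only "Sum" is taken from the deny lists shown further down).

// Illustrative only; the *_REMOVE variable name is an assumption.
#include <cstdlib>  // setenv (POSIX)

int main() {
  // Ask the rewrite to also treat a hypothetical op as allow-listed...
  setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_ALLOWLIST_ADD",
         "MyCustomMatMul", /*overwrite=*/1);
  // ...and (assumed suffix) drop "Sum" from the deny list.
  setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_DENYLIST_REMOVE", "Sum",
         /*overwrite=*/1);
  // The legacy names (WHITELIST/GRAYLIST/BLACKLIST) are still accepted by the
  // CHECK above for backwards compatibility.
  return 0;
}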
@ -154,7 +155,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
|
||||
return list;
|
||||
}
|
||||
|
||||
gtl::FlatSet<string> GrayList() override {
|
||||
gtl::FlatSet<string> InferList() override {
|
||||
if (IsPseudoFastMath()) {
|
||||
return gtl::FlatSet<string>{};
|
||||
}
|
||||
@ -204,11 +205,14 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
|
||||
"Tanh",
|
||||
"TanhGrad",
|
||||
};
|
||||
UpdateList("INFERLIST", &list);
|
||||
// For backwards compatibility, keeping the original env variable here.
|
||||
// TODO(reedwm): This should be removed if we don't have active users.
|
||||
UpdateList("GRAYLIST", &list);
|
||||
return list;
|
||||
}
|
||||
|
||||
gtl::FlatSet<string> BlackList() override {
|
||||
gtl::FlatSet<string> DenyList() override {
|
||||
if (IsPseudoFastMath()) {
|
||||
return gtl::FlatSet<string>{};
|
||||
}
|
||||
@ -224,6 +228,9 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
|
||||
"SparseSoftmaxCrossEntropyWithLogits",
|
||||
"Sum",
|
||||
};
|
||||
UpdateList("DENYLIST", &list);
|
||||
// For backwards compatibility, keeping the original env variable here.
|
||||
// TODO(reedwm): This should be removed if we don't have active users.
|
||||
UpdateList("BLACKLIST", &list);
|
||||
return list;
|
||||
}
|
||||
@ -344,7 +351,7 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
|
||||
AutoMixedPrecisionListsMkl() {}
|
||||
|
||||
// Only ops which are supported by MKL in bfloat16 should be added to the
|
||||
// allow list, gray list, or clear list.
|
||||
// allow list, infer list, or clear list.
|
||||
gtl::FlatSet<string> AllowList() override {
|
||||
auto list = gtl::FlatSet<string>{"Conv2D",
|
||||
"Conv2DBackpropFilter",
|
||||
@ -360,10 +367,13 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
|
||||
"BatchMatMulV2"};
|
||||
|
||||
UpdateList("ALLOWLIST", &list);
|
||||
// For backwards compatibility, keeping the original env variable here.
|
||||
// TODO(reedwm): This should be removed if we don't have active users.
|
||||
UpdateList("WHITELIST", &list);
|
||||
return list;
|
||||
}
|
||||
|
||||
gtl::FlatSet<string> GrayList() override {
|
||||
gtl::FlatSet<string> InferList() override {
|
||||
auto list = gtl::FlatSet<string>{
|
||||
"Add",
|
||||
"AddN",
|
||||
@ -384,11 +394,14 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
|
||||
"Mul",
|
||||
"Sub",
|
||||
};
|
||||
UpdateList("INFERLIST", &list);
|
||||
// For backwards compatibility, keeping the original env variable here.
|
||||
// TODO(reedwm): This should be removed if we don't have active users.
|
||||
UpdateList("GRAYLIST", &list);
|
||||
return list;
|
||||
}
|
||||
|
||||
gtl::FlatSet<string> BlackList() override {
|
||||
gtl::FlatSet<string> DenyList() override {
|
||||
auto list = gtl::FlatSet<string>{
|
||||
"Exp",
|
||||
"Expm1",
|
||||
@ -401,6 +414,9 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
|
||||
"SparseSoftmaxCrossEntropyWithLogits",
|
||||
"Sum",
|
||||
};
|
||||
UpdateList("DENYLIST", &list);
|
||||
// For backwards compatibility, keeping the original env variable here.
|
||||
// TODO(reedwm): This should be removed if we don't have active users.
|
||||
UpdateList("BLACKLIST", &list);
|
||||
return list;
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ class AutoMixedPrecisionTest : public GrapplerTest {
|
||||
return AddNode(name, op, inputs, attributes, graph);
|
||||
}
|
||||
|
||||
void TestSimpleUnaryGrayOp(
|
||||
void TestSimpleUnaryInferOp(
|
||||
double input_min, double input_max, double atol, double rtol,
|
||||
const std::function<Output(const tensorflow::Scope&, Output)>&
|
||||
test_op_factory) {
|
||||
@ -170,8 +170,8 @@ class AutoMixedPrecisionTest : public GrapplerTest {
|
||||
GenerateIdentityMatrix<DT_FLOAT>(size, size));
|
||||
Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
|
||||
Output gry1 = test_op_factory(s.WithOpName("gry1"), allow1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, eye);
|
||||
Output infer1 = test_op_factory(s.WithOpName("infer1"), allow1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, eye);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch1"};
|
||||
@ -191,7 +191,7 @@ class AutoMixedPrecisionTest : public GrapplerTest {
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
|
||||
DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch, feed);
|
||||
@ -209,10 +209,10 @@ class AutoMixedPrecisionTest : public GrapplerTest {
|
||||
TEST_F(AutoMixedPrecisionTest, NoOp) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
|
||||
Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
|
||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
|
||||
Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
|
||||
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
||||
|
||||
GrapplerItem item;
|
||||
@ -230,9 +230,9 @@ TEST_F(AutoMixedPrecisionTest, NoOp) {
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
@ -284,16 +284,16 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
|
||||
TEST_F(AutoMixedPrecisionTest, Simple) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
|
||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
|
||||
Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
|
||||
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
|
||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
|
||||
Output gry2 = ops::Log(s.WithOpName("gry2"), clr3);
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), gry2);
|
||||
Output blk2 = ops::SparseMatMul(s.WithOpName("blk2"), clr4, clr4);
|
||||
Output clr5 = ops::Relu(s.WithOpName("clr5"), blk2);
|
||||
Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
|
||||
Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
|
||||
Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
|
||||
|
||||
GrapplerItem item;
|
||||
@ -310,16 +310,16 @@ TEST_F(AutoMixedPrecisionTest, Simple) {
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("Ta").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("Tb").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
@ -374,13 +374,13 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
|
||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||
Output blk1 = ops::Exp(s.WithOpName("blk1"), gry1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), blk1);
|
||||
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
|
||||
Output deny1 = ops::Exp(s.WithOpName("deny1"), infer1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), deny1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
|
||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
|
||||
Output blk2 = ops::Exp(s.WithOpName("blk2"), clr3);
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
|
||||
Output deny2 = ops::Exp(s.WithOpName("deny2"), clr3);
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"allow1", "clr2", "clr3"};
|
||||
@ -398,12 +398,12 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
@ -419,11 +419,11 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1);
|
||||
Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
|
||||
Output allow2 =
|
||||
ops::MatMul(s.WithOpName("allow2").WithDevice(
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0"),
|
||||
gry1, gry1);
|
||||
infer1, infer1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
||||
|
||||
@ -443,7 +443,7 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
|
||||
|
||||
@ -521,9 +521,9 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
||||
s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
|
||||
fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
|
||||
.x_backprop;
|
||||
Output gry1 = ops::Add(s.WithOpName("gry1"), fbn1, bng1);
|
||||
Output infer1 = ops::Add(s.WithOpName("infer1"), fbn1, bng1);
|
||||
Output allow2 =
|
||||
ops::Conv2D(s.WithOpName("allow2"), gry1, weight, {1, 1, 1, 1}, "SAME",
|
||||
ops::Conv2D(s.WithOpName("allow2"), infer1, weight, {1, 1, 1, 1}, "SAME",
|
||||
ops::Conv2D::DataFormat("NHWC"));
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
|
||||
|
||||
@ -547,7 +547,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
||||
EXPECT_EQ(output_view.GetNode("bng1")->op(), "FusedBatchNormGradV2");
|
||||
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
@ -563,10 +563,10 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
|
||||
Output gry1 =
|
||||
ops::AddN(s.WithOpName("gry1"),
|
||||
Output infer1 =
|
||||
ops::AddN(s.WithOpName("infer1"),
|
||||
{clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
|
||||
|
||||
GrapplerItem item;
|
||||
@ -587,7 +587,7 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
|
||||
for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
|
||||
EXPECT_EQ(type, DT_HALF);
|
||||
}
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
@ -633,17 +633,17 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) {
|
||||
TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
|
||||
Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
|
||||
Output ent1 =
|
||||
ops::internal::Enter(s.WithOpName("ent1"), blk1, "loop1").output;
|
||||
ops::internal::Enter(s.WithOpName("ent1"), deny1, "loop1").output;
|
||||
// Note that the second input is later replaced with "nxt1".
|
||||
Output mrg1 = ops::Merge(s.WithOpName("mrg1"), {ent1, ent1}).output;
|
||||
// For simplicity, the loop condition is constant false.
|
||||
Output con1 = ops::Const(s.WithOpName("con1"), false, {});
|
||||
Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
|
||||
auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
|
||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), swt1.output_true);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), gry1, gry1);
|
||||
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), swt1.output_true);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), infer1, infer1);
|
||||
Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
|
||||
Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
|
||||
@ -671,14 +671,14 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
// Note that mrg1 gets painted black because it is between blk1 and gry1. This
|
||||
// forces nxt1 and mrg2 to be painted black as well (they would otherwise be
|
||||
// painted allow because they are clear and have a direct path to allow1).
|
||||
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
||||
// Note that mrg1 gets painted deny because it is between deny1 and infer1.
|
||||
// This forces nxt1 and mrg2 to be painted deny as well (they would otherwise
|
||||
// be painted allow because they are clear and have a direct path to allow1).
|
||||
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
|
||||
@ -711,8 +711,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
|
||||
Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
|
||||
shape, DT_FLOAT)
|
||||
.item;
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
|
||||
auto tl1w3 =
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
|
||||
Output tl1r2 =
|
||||
@ -748,7 +748,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
||||
@ -776,8 +776,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
|
||||
Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
|
||||
tl1w2.output_handle, shape, DT_FLOAT)
|
||||
.tensor;
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
|
||||
auto tl1w3 =
|
||||
ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
|
||||
Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
|
||||
@ -811,7 +811,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
||||
@ -835,8 +835,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
|
||||
Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
|
||||
shape, DT_FLOAT)
|
||||
.tensor;
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||
|
||||
// This tests that a allow-painted object node (tl2) will force an unpainted
|
||||
@ -863,7 +863,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
|
||||
@ -902,8 +902,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
|
||||
Output tl3r1 =
|
||||
ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
|
||||
.tensor;
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl3r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl3r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||
|
||||
GrapplerItem item;
|
||||
@ -922,7 +922,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
|
||||
const char* type_key = "element_dtype";
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl3")->attr().at(type_key).type(), DT_HALF);
|
||||
@ -967,22 +967,25 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
|
||||
tensorflow::Input shape = {32, 32};
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1);
|
||||
Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
|
||||
auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
|
||||
auto tl1w1 = ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, gry1);
|
||||
auto _gry1 = tensorflow::ops::AsNodeOut(s, gry1);
|
||||
auto tl1w1 =
|
||||
ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, infer1);
|
||||
auto _infer1 = tensorflow::ops::AsNodeOut(s, infer1);
|
||||
auto _tl1w1_handle = tensorflow::ops::AsNodeOut(s, tl1w1.output_handle);
|
||||
auto builder =
|
||||
tensorflow::NodeBuilder("Func1", "Func1", s.graph()->op_registry());
|
||||
tensorflow::Node* func1_op;
|
||||
TF_CHECK_OK(
|
||||
builder.Input(_tl1w1_handle).Input(_gry1).Finalize(s.graph(), &func1_op));
|
||||
TF_CHECK_OK(builder.Input(_tl1w1_handle)
|
||||
.Input(_infer1)
|
||||
.Finalize(s.graph(), &func1_op));
|
||||
Output func1_handle(func1_op, 0);
|
||||
Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), func1_handle,
|
||||
shape, DT_FLOAT)
|
||||
.tensor;
|
||||
auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
|
||||
auto tl2w1 = ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, gry1);
|
||||
auto tl2w1 =
|
||||
ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, infer1);
|
||||
Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
|
||||
tl2w1.output_handle, shape, DT_FLOAT)
|
||||
.tensor;
|
||||
@ -1004,7 +1007,7 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
|
||||
const char* type_key = "element_dtype";
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_HALF);
|
||||
@ -1069,7 +1072,7 @@ TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, EluOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
-5, 5, 1.0e-3, 1.0e-3,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Elu(scope, input);
|
||||
@ -1077,7 +1080,7 @@ TEST_F(AutoMixedPrecisionTest, EluOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, ErfOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
-5, 5, 1.0e-3, -1,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Erf(scope, input);
|
||||
@ -1085,7 +1088,7 @@ TEST_F(AutoMixedPrecisionTest, ErfOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, ErfcOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
-5, 5, 1.0e-3, -1,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Erfc(scope, input);
|
||||
@ -1093,7 +1096,7 @@ TEST_F(AutoMixedPrecisionTest, ErfcOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, InvOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
0.01, 10, -1, 1.0e-3,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Inv(scope, input);
|
||||
@ -1101,7 +1104,7 @@ TEST_F(AutoMixedPrecisionTest, InvOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, LogOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
0.01, 10, 1.0e-3, 2.0e-3,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Log(scope, input);
|
||||
@ -1109,7 +1112,7 @@ TEST_F(AutoMixedPrecisionTest, LogOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, Log1pOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
-0.99, 9, 1.0e-3, 5.0e-3,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Log1p(scope, input);
|
||||
@ -1117,7 +1120,7 @@ TEST_F(AutoMixedPrecisionTest, Log1pOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
-8, 8, -1, 1.0e-2,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::LogSoftmax(scope, input);
|
||||
@ -1125,7 +1128,7 @@ TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
0.01, 10, -1, 1.0e-3,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Reciprocal(scope, input);
|
||||
@ -1133,7 +1136,7 @@ TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
-5, 5, 1.0e-3, -1,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Sigmoid(scope, input);
|
||||
@ -1141,7 +1144,7 @@ TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
-8, 8, 2.0e-3, -1,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Softmax(scope, input);
|
||||
@ -1149,7 +1152,7 @@ TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
-5, 5, 1.0e-3, 1.0e-3,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Softplus(scope, input);
|
||||
@ -1157,7 +1160,7 @@ TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, SqrtOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
0, 10, 1.0e-3, 1.0e-3,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Sqrt(scope, input);
|
||||
@ -1165,7 +1168,7 @@ TEST_F(AutoMixedPrecisionTest, SqrtOp) {
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, TanhOp) {
|
||||
TestSimpleUnaryGrayOp(
|
||||
TestSimpleUnaryInferOp(
|
||||
-5, 5, 1.0e-3, -1,
|
||||
[](const tensorflow::Scope& scope, Output input) -> Output {
|
||||
return ops::Tanh(scope, input);
|
||||
@ -1229,16 +1232,16 @@ TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
|
||||
TEST_F(AutoMixedPrecisionMklTest, Simple) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
|
||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
|
||||
Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
|
||||
Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
|
||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
|
||||
Output blk2 = ops::Log(s.WithOpName("blk2"), clr3);
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
|
||||
Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4);
|
||||
Output clr5 = ops::Relu(s.WithOpName("clr5"), blk3);
|
||||
Output deny2 = ops::Log(s.WithOpName("deny2"), clr3);
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
|
||||
Output deny3 = ops::SparseMatMul(s.WithOpName("deny3"), clr4, clr4);
|
||||
Output clr5 = ops::Relu(s.WithOpName("clr5"), deny3);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
|
||||
|
||||
GrapplerItem item;
|
||||
@ -1255,16 +1258,16 @@ TEST_F(AutoMixedPrecisionMklTest, Simple) {
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Ta").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Tb").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Ta").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Tb").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
@ -1294,8 +1297,8 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
|
||||
Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
|
||||
shape, DT_FLOAT)
|
||||
.item;
|
||||
Output gry1 = ops::Mul(s.WithOpName("gry1"), tl1r1, tl1r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
Output infer1 = ops::Mul(s.WithOpName("infer1"), tl1r1, tl1r1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
|
||||
auto tl1w3 =
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
|
||||
Output tl1r2 =
|
||||
@ -1335,7 +1338,7 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
|
||||
DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
|
||||
DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
|
||||
DT_BFLOAT16);
|
||||
|
@ -36,8 +36,8 @@ namespace internal {
|
||||
// dynamically determined.
|
||||
constexpr int64 kTensorMaxSize = 64;
|
||||
|
||||
// All the nodes that should be blacklisted and not swapped.
bool IsBlacklisted(const NodeDef& node) {
// All the nodes that should be denylisted and not swapped.
bool IsDenylisted(const NodeDef& node) {
  return
      // Collective ops should not be swapped.
      IsCollective(node) ||
|
||||
@ -94,8 +94,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
|
||||
bool* is_candidate) {
|
||||
*is_candidate = false;
|
||||
|
||||
// Make sure we are not a blacklisted op.
|
||||
if (IsBlacklisted(node)) {
|
||||
// Make sure we are not a denylisted op.
|
||||
if (IsDenylisted(node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -215,7 +215,7 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
|
||||
|
||||
// Checks if a node is a candidate to pin to Host.
// The rough algorithm is as follows:
// 1] Check if node is blacklisted.
// 1] Check if node is denylisted.
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
//    ints, and pinned to Host).
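The three-step check reads naturally as an early-return chain; a hedged sketch of that shape follows (the helper predicates below are placeholders, not the real functions in this file):

// Illustrative shape of the candidate check; all predicates are placeholders.
#include <string>

struct Node { std::string op; };

bool IsDenylistedOp(const Node& node) { return node.op == "CollectiveReduce"; }
bool CanRunOnHost(const Node& node) { return node.op != "Conv2D"; }
bool PortsAreHostFriendly(const Node& node) { return true; }  // small ints, etc.

// 1] not denylisted, 2] can run on Host, 3] all ports Host "friendly".
bool IsHostCandidate(const Node& node) {
  if (IsDenylistedOp(node)) return false;
  if (!CanRunOnHost(node)) return false;
  return PortsAreHostFriendly(node);
}

int main() {
  Node n{"Shape"};
  return IsHostCandidate(n) ? 0 : 1;  // Small shape-like ops are typical candidates.
}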
@ -230,7 +230,7 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
|
||||
}
|
||||
|
||||
// Skip these node types.
|
||||
if (IsBlacklisted(node)) {
|
||||
if (IsDenylisted(node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
@ -176,6 +177,61 @@ class BasicBatchScheduler : public BatchScheduler<TaskType> {
|
||||
// parameter.
|
||||
int max_enqueued_batches = 10;
|
||||
|
||||
    // If true, an input task (i.e., input of `BasicBatchScheduler::Schedule`)
    // with a large size (i.e., larger than the largest value of
    // `allowed_batch_sizes`) will be split into multiple smaller batch tasks
    // and possibly put into different batches for processing. If false, each
    // input task is put into one batch as a whole for processing.
    //
    // API note:
    // The value of this option doesn't affect processing output given the same
    // input; it affects implementation details as stated below:
    // 1. It improves batching efficiency by eliminating unnecessary padding in
    // the following scenario: when an open batch has M slots while an input of
    // size N is scheduled (M < N), the input can be split to fill the remaining
    // slots of the open batch as opposed to padding.
    // 2. `max_batch_size` specifies the size limit of an input task, and
    // `max_execution_batch_size` specifies the size limit of a task to be
    // processed. An API user can give an input of size 128 when
    // `max_execution_batch_size` is 32; the implementation can split the input
    // of 128 into 4 x 32, schedule concurrent processing, and then return
    // concatenated results corresponding to 128.
    bool enable_large_batch_splitting = false;

    // `split_input_task_func` specifies how to split `input_task` into
    // `output_tasks`.
    //
    // `input_task`: a unit of task to be split.
    // `first_output_task_size`: task size of first output.
    // `max_batch_size`: Maximum size of each batch.
    // `output_tasks`: A list of output tasks after split.
    //
    // REQUIRED:
    // 1) All `output_tasks` should be non-empty tasks.
    // 2) Sizes of `output_tasks` add up to size of `input_task`.
    //
    // NOTE:
    // Instantiations of `TaskType` may vary, so it's up to the caller to define
    // how (e.g., which members to access) to split input tasks.
    std::function<Status(std::unique_ptr<TaskType>* input_task,
                         int first_output_task_size, int input_batch_size_limit,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func;

    // The maximum size of each enqueued batch (i.e., in `batches_`).
    //
    // The scheduler may form batches of any size between 1 and this number
    // (inclusive). If there is a need to quantize the batch sizes, i.e. only
    // submit batches whose size is in a small set of allowed sizes, that can be
    // done by adding padding in the process-batch callback.
    //
    // REQUIRES:
    // - If enable_large_batch_splitting is true, `max_execution_batch_size` is
    //   less than or equal to `max_batch_size`.
    // - If enable_large_batch_splitting is false, `max_execution_batch_size` is
    //   equal to `max_batch_size`.
    int max_execution_batch_size = 10;
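To make the splitting policy concrete, here is a standalone sketch of the size arithmetic only. It is not the real `split_input_task_func` (whose `TaskType` is caller-defined), and it assumes `first_output_task_size` is the room left in the currently open batch; the REQUIRED invariants above, non-empty chunks whose sizes sum to the input size, then hold by construction.

// Illustrative only: chunk sizes for splitting one large input task.
#include <algorithm>
#include <cassert>
#include <vector>

std::vector<int> SplitSizes(int input_size, int first_output_task_size,
                            int max_execution_batch_size) {
  std::vector<int> sizes;
  int remaining = input_size;
  // First chunk fills the remaining slots of the currently open batch.
  int first = std::min(remaining, first_output_task_size);
  if (first > 0) {
    sizes.push_back(first);
    remaining -= first;
  }
  // Remaining chunks are capped at the execution batch size.
  while (remaining > 0) {
    int chunk = std::min(remaining, max_execution_batch_size);
    sizes.push_back(chunk);
    remaining -= chunk;
  }
  return sizes;
}

int main() {
  // Input of 128 with max_execution_batch_size 32 (as in the comment above):
  // with no partially filled open batch, this yields 4 x 32.
  std::vector<int> sizes = SplitSizes(128, 32, 32);
  assert(sizes == std::vector<int>({32, 32, 32, 32}));
  int total = 0;
  for (int s : sizes) total += s;
  assert(total == 128);  // Sizes of output tasks add up to the input size.
  return 0;
}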
|
||||
// The following options are typically only overridden by test code.
|
||||
|
||||
// The environment to use.
|
||||
@ -231,6 +287,12 @@ Status BasicBatchScheduler<TaskType>::Create(
|
||||
options.batch_timeout_micros;
|
||||
shared_scheduler_queue_options.max_enqueued_batches =
|
||||
options.max_enqueued_batches;
|
||||
shared_scheduler_queue_options.enable_large_batch_splitting =
|
||||
options.enable_large_batch_splitting;
|
||||
shared_scheduler_queue_options.split_input_task_func =
|
||||
options.split_input_task_func;
|
||||
shared_scheduler_queue_options.max_execution_batch_size =
|
||||
options.max_execution_batch_size;
|
||||
std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue;
|
||||
TF_RETURN_IF_ERROR(shared_scheduler->AddQueue(shared_scheduler_queue_options,
|
||||
process_batch_callback,