Merge remote-tracking branch 'upstream/master' into xtensa-fusion-f1

This commit is contained in:
Advait Jain 2020-12-08 21:45:17 -08:00
commit 6b24408403
281 changed files with 6002 additions and 4504 deletions

View File

@ -61,6 +61,7 @@
* Added support for saved model's session initializer through
`TFLiteConverter.from_saved_model`.
* Added dynamic range quantization support for the BatchMatMul op.
* Added DEPTH_TO_SPACE support in Post training quantization.
* Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
only supports float32 input.
* TFLite Supports SingatureDef:

View File

@ -145,6 +145,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Add module aliases
if hasattr(_current_module, 'keras'):

View File

@ -155,6 +155,8 @@ if _running_from_pip_package():
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.

View File

@ -684,7 +684,10 @@ tf_cc_test(
name = "c_api_experimental_test",
size = "medium",
srcs = ["c_api_experimental_test.cc"],
data = ["testdata/tf_record"],
data = [
"testdata/tf_record",
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
],
linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
@ -704,6 +707,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -37,7 +37,9 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
@ -630,6 +632,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
// Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
} // namespace tensorflow
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
@ -743,3 +748,45 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable) {
opts->opts.validate_colocation_constraints = enable;
}
// Load a Pluggable Device library.
// On success, returns the handle to library in result and return OK from the
// function. Otherwise return nullptr in result and error Status from the
// function.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
status->status = tensorflow::errors::Unimplemented(
"PluggableDevice plugin functionality is not supported on mobile");
return nullptr;
#else
TF_Library* lib_handle = new TF_Library;
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<std::string, void*>* loaded_libs =
new std::unordered_map<std::string, void*>();
tensorflow::Env* env = tensorflow::Env::Default();
{
tensorflow::mutex_lock lock(mu);
auto it = loaded_libs->find(library_filename);
if (it != loaded_libs->end()) {
lib_handle->lib_handle = it->second;
} else {
status->status =
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
if (!status->status.ok()) {
delete lib_handle;
return nullptr;
}
}
return lib_handle;
}
#endif
}
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
delete lib_handle;
}

View File

@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable);
// Load the library specified by library_filename and register the pluggable
// device and related kernels present in that library. This function is not
// supported on embedded on mobile and embedded platforms and will fail if
// called.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, returns the newly created library handle and places OK in status.
// The caller owns the library handle.
//
// On failure, returns nullptr and places an error status in status.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
const char* library_filename, TF_Status* status);
// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
TF_Library* lib_handle);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@ -234,5 +235,22 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
TF_DeleteTensor(tensor_1X6);
}
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
// Load the library.
TF_Status* status = TF_NewStatus();
string lib_path =
tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
"tensorflow", "c", "experimental", "stream_executor", "test",
"test_pluggable_device.so"));
TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
ASSERT_EQ(TF_OK, code) << status_msg;
TF_DeletePluggableDeviceLibraryHandle(lib);
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}
} // namespace
} // namespace tensorflow

View File

@ -213,7 +213,11 @@ TF_Function* TF_GraphToFunctionWithControlOutputs(
TF_DeleteFunction(tf_function);
return nullptr;
}
tf_function->graph_with_debug_info = &fn_body->graph;
for (const Node* n : fn_body->graph.nodes()) {
tf_function->stack_traces[n->name()] = n->GetStackTrace();
}
return tf_function;
}

View File

@ -157,9 +157,7 @@ struct TF_DeviceList {
struct TF_Function {
tensorflow::FunctionDef fdef;
// Graph with nodes with debug stack traces.
const tensorflow::Graph* graph_with_debug_info = nullptr;
tensorflow::StackTracesMap stack_traces;
};
struct TF_ApiDefMap {

View File

@ -749,8 +749,8 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithDebugInfo(
function->fdef, function->graph_with_debug_info);
status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
function->fdef, function->stack_traces);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,

View File

@ -111,11 +111,11 @@ class ImmediateExecutionContext : public AbstractContext {
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Same as `AddFunctionDef`, and additionally saves a pointer to the Graph
// which has nodes containing stack traces for the nodes in `fdef`. Assumes
// `graph` is alive while the function is alive.
virtual Status AddFunctionDefWithDebugInfo(const FunctionDef& fdef,
const Graph* graph) = 0;
// Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
// the key of the function definition name (to be retrieved during function
// instantiation).
virtual Status AddFunctionDefWithStackTraces(
const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;
// Find and return a added function by its name.
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;

View File

@ -0,0 +1,17 @@
# Description:
# test for stream_executor
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_shared_object",
)
package(
licenses = ["notice"], # Apache 2.0
)
tf_cc_shared_object(
name = "test_pluggable_device.so",
srcs = ["test_pluggable_device.cc"],
visibility = ["//tensorflow/c:__subpackages__"],
deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
)

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
xla::StatusOr<pybind11::object> Bfloat16Dtype();
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_BFLOAT16_H_
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
TF_Status* const status) {
params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
params->platform->name = "GPU";
params->platform->type = "XGPU";
}

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
using tensorflow::errors::InvalidArgument;
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility.
@ -87,9 +88,25 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
TF_SetStatus(status, TF_OK, "");
}
#undef CASE
} // namespace
} // namespace tensorflow
namespace {
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
const char* attr_name,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
const tensorflow::AttrValue* attr =
::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
if (attr == nullptr) {
status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
"' has no attr named '", attr_name, "'.");
}
return attr;
}
} // namespace
void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
const char* attr_name,
const TF_DataType type,
@ -257,7 +274,81 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
cc_ctx->CtxFailure(s);
}
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
const char* attr_name,
int32_t* list_size,
int32_t* total_size,
TF_Status* status) {
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
if (!status->status.ok()) {
*list_size = -1;
*total_size = -1;
return;
}
switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
case tensorflow::AttrValue::kK: \
*list_size = -1; \
*total_size = size_expr; \
break;
SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
SINGLE_CASE(kI, TF_ATTR_INT, -1);
SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
SINGLE_CASE(kShape, TF_ATTR_SHAPE,
attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
#undef SINGLE_CASE
case tensorflow::AttrValue::kList:
*list_size = 0;
*total_size = -1;
#define LIST_CASE(field, attr_type, ...) \
if (attr->list().field##_size() > 0) { \
*list_size = attr->list().field##_size(); \
__VA_ARGS__; \
break; \
}
LIST_CASE(
s, TF_ATTR_STRING, *total_size = 0;
for (int i = 0; i < attr->list().s_size();
++i) { *total_size += attr->list().s(i).size(); });
LIST_CASE(i, TF_ATTR_INT);
LIST_CASE(f, TF_ATTR_FLOAT);
LIST_CASE(b, TF_ATTR_BOOL);
LIST_CASE(type, TF_ATTR_TYPE);
LIST_CASE(
shape, TF_ATTR_SHAPE, *total_size = 0;
for (int i = 0; i < attr->list().shape_size(); ++i) {
const auto& s = attr->list().shape(i);
*total_size += s.unknown_rank() ? 0 : s.dim_size();
});
LIST_CASE(tensor, TF_ATTR_TENSOR);
LIST_CASE(tensor, TF_ATTR_FUNC);
#undef LIST_CASE
break;
case tensorflow::AttrValue::kPlaceholder:
*list_size = -1;
*total_size = -1;
break;
case tensorflow::AttrValue::kFunc:
*list_size = -1;
*total_size = -1;
break;
case tensorflow::AttrValue::VALUE_NOT_SET:
status->status =
InvalidArgument("Attribute '", attr_name, "' has no value set");
break;
}
}
#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
const char* attr_name, \
c_type* val, TF_Status* status) { \
@ -269,10 +360,84 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
if (s.ok()) { \
*val = static_cast<c_type>(v); \
} \
} \
void TF_OpKernelConstruction_GetAttr##func##List( \
TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
int max_vals, TF_Status* status) { \
TF_SetStatus(status, TF_OK, ""); \
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
return; \
} \
status->status = \
tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
if (!status->status.ok()) return; \
const auto len = std::min(max_vals, attr->list().list_field##_size()); \
for (int i = 0; i < len; ++i) { \
vals[i] = static_cast<c_type>(attr->list().list_field(i)); \
} \
}
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
DEFINE_TF_GETATTR(Float, float, float, "float", f)
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)
void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
const char* attr_name, char* value,
size_t max_length,
TF_Status* status) {
std::string v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);
if (!status->status.ok()) return;
if (max_length <= 0) {
return;
}
std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
}
void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
const char* attr_name,
char** values, size_t* lengths,
int max_values, void* storage,
size_t storage_size,
TF_Status* status) {
std::vector<std::string> v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(v.size()));
char* p = static_cast<char*>(storage);
for (int i = 0; i < len; ++i) {
const std::string& s = v[i];
values[i] = p;
lengths[i] = s.size();
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
status->status = InvalidArgument(
"Not enough storage to hold the requested list of strings");
return;
}
memcpy(values[i], s.data(), s.size());
p += s.size();
}
}
bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
const char* attr_name, TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
return cc_ctx->HasAttr(attr_name);
}
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);

View File

@ -184,6 +184,24 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
// Returns the step ID of the given context.
TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);
// Get the list_size and total_size of the attribute `attr_name` of `oper`.
// list_size - the length of the list.
// total_size - total size of the list.
// (1) If attr_type == TF_ATTR_STRING
// then total_size is the cumulative byte size
// of all the strings in the list.
// (3) If attr_type == TF_ATTR_SHAPE
// then total_size is the number of dimensions
// of the shape valued attribute, or -1
// if its rank is unknown.
// (4) If attr_type == TF_ATTR_SHAPE
// then total_size is the cumulative number
// of dimensions of all shapes in the list.
// (5) Otherwise, total_size is undefined.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
int32_t* total_size, TF_Status* status);
// Interprets the named kernel construction attribute as a TF_DataType and
// places it into *val. *status is set to TF_OK.
//
@ -202,6 +220,112 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
TF_Status* status);
// Interprets the named kernel construction attribute as int64_t and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// int64, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
TF_Status* status);
// Interprets the named kernel construction attribute as float and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// float, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
TF_Status* status);
// Interprets the named kernel construction attribute as bool and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// bool, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
TF_Status* status);
// Interprets the named kernel construction attribute as string and
// places it into *val. `val` must
// point to an array of length at least `max_length` (ideally set to
// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
// attr_name, list_size, total_size)). *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// string, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
size_t max_length, TF_Status* status);
// Interprets the named kernel construction attribute as a TF_DataType array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as int32_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as int64_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as float array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as bool array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
int max_vals, TF_Status* status);
// Interprets the named kernel construction attribute as string array and fills
// in `vals` and `lengths`, each of which must point to an array of length at
// least `max_values`. *status is set to TF_OK. The elements of values will
// point to addresses in `storage` which must be at least `storage_size` bytes
// in length. Ideally, max_values would be set to list_size and `storage` would
// be at least total_size, obtained from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
size_t* lengths, int max_values, void* storage, size_t storage_size,
TF_Status* status);
// Return true if the kernel construction has the attr_name
TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);
// Returns the unique operation name for this OpKernel.
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
TF_OpKernelConstruction* ctx);

View File

@ -161,6 +161,337 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
ASSERT_TRUE(delete_called);
}
// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
// Registers two ops, each with a single attribute called 'Attr'.
// The attribute in one op will have a type 'type', the other
// will have list(type).
#define ATTR_TEST_REGISTER_OP(name, type) \
REGISTER_OP("TestKernelAttr" #name) \
.Attr("Attr: " #type) \
.SetShapeFn(tensorflow::shape_inference::UnknownShape); \
REGISTER_OP("TestKernelAttr" #name "List") \
.Attr("Attr: list(" #type ")") \
.SetShapeFn(tensorflow::shape_inference::UnknownShape)
ATTR_TEST_REGISTER_OP(String, string);
ATTR_TEST_REGISTER_OP(Int, int);
ATTR_TEST_REGISTER_OP(Float, float);
ATTR_TEST_REGISTER_OP(Bool, bool);
ATTR_TEST_REGISTER_OP(Type, type);
#undef ATTR_TEST_REGISTER_OP
// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
do { \
int32_t list_size, total_size; \
TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, \
&total_size, status); \
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); \
EXPECT_EQ(expected_list_size, list_size); \
EXPECT_EQ(expected_total_size, total_size); \
} while (0)
typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
class TestKernelAttr : public ::testing::Test {
public:
TestKernelAttr() {}
~TestKernelAttr() {}
std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
AttrValue v, Status* status) {
NodeDef def;
def.set_op(op_name);
def.set_name("FakeNode");
def.set_device("FakeDevice");
(*def.mutable_attr())["Attr"] = v;
return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
status);
}
void SetAttr(MyCreateFuncWithAttr MyCreateFuncAttr, const char* op_name,
AttrValue& v) {
TF_KernelBuilder* builder = TF_NewKernelBuilder(
op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder("FakeNode", builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernelWithAttr(op_name, v, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
kernel->Compute(nullptr);
ASSERT_TRUE(delete_called);
}
};
TEST_F(TestKernelAttr, String) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
std::unique_ptr<char[]> val(new char[5]);
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ 5);
TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
/*max_length*/ 5, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_s("bunny");
SetAttr(my_create_func, "TestKernelAttrString", v);
}
TEST_F(TestKernelAttr, StringList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
std::vector<string> list = {"bugs", "bunny", "duck"};
int list_total_size = 0;
for (const auto& s : list) {
list_total_size += s.size();
}
TF_Status* status = TF_NewStatus();
std::unique_ptr<char*[]> values(new char*[list.size()]);
std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
std::unique_ptr<char[]> storage(new char[list_total_size]);
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
/*expected_total_size*/ list_total_size);
TF_OpKernelConstruction_GetAttrStringList(
ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
list_total_size, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
for (size_t i = 0; i < list.size(); ++i) {
EXPECT_EQ(list[i].size(), lens[i]) << i;
EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
<< i;
}
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<StringPiece>({"bugs", "bunny", "duck"});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrStringList", v);
}
TEST_F(TestKernelAttr, Int) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
int64_t val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1234, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_i(1234);
SetAttr(my_create_func, "TestKernelAttrInt", v);
}
TEST_F(TestKernelAttr, IntList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const int64_t list[] = {1, 2, 3, 4};
const size_t list_size = TF_ARRAYSIZE(list);
int64_t values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<int64>({1, 2, 3, 4});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrIntList", v);
}
TEST_F(TestKernelAttr, Float) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
float val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FLOAT_EQ(2.718, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_f(2.718);
SetAttr(my_create_func, "TestKernelAttrFloat", v);
}
TEST_F(TestKernelAttr, FloatList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const float list[] = {1.414, 2.718, 3.1415};
const size_t list_size = TF_ARRAYSIZE(list);
float values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<float>({1.414, 2.718, 3.1415});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrFloatList", v);
}
TEST_F(TestKernelAttr, Bool) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
unsigned char val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_b(1);
SetAttr(my_create_func, "TestKernelAttrBool", v);
}
TEST_F(TestKernelAttr, BoolList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const unsigned char list[] = {1, 0, 1, 0};
const size_t list_size = TF_ARRAYSIZE(list);
unsigned char values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in = gtl::ArraySlice<bool>({1, 0, 1, 0});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrBoolList", v);
}
TEST_F(TestKernelAttr, Type) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
TF_DataType val;
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(TF_FLOAT, val);
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
v.set_type(DT_FLOAT);
SetAttr(my_create_func, "TestKernelAttrType", v);
}
TEST_F(TestKernelAttr, TypeList) {
auto my_create_func = [](TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
const size_t list_size = TF_ARRAYSIZE(list);
TF_DataType values[list_size];
TF_Status* status = TF_NewStatus();
EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
/*expected_total_size*/ -1);
TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_TRUE(
std::equal(std::begin(list), std::end(list), std::begin(values)));
TF_DeleteStatus(status);
return static_cast<void*>(s);
};
AttrValue v;
auto attr_in =
gtl::ArraySlice<DataType>({DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128});
SetAttrValue(attr_in, &v);
SetAttr(my_create_func, "TestKernelAttrTypeList", v);
}
#undef EXPECT_TF_SIZE
class DummyDevice : public DeviceBase {
public:
explicit DummyDevice(Env* env) : DeviceBase(env) {}

View File

@ -65,6 +65,7 @@ MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(OrOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);
@ -81,6 +82,7 @@ MAP_HLO_TO_LHLO(SqrtOp);
MAP_HLO_TO_LHLO(SubOp);
MAP_HLO_TO_LHLO(TanhOp);
MAP_HLO_TO_LHLO(TransposeOp);
MAP_HLO_TO_LHLO(XorOp);
#undef MAP_HLO_TO_LHLO

View File

@ -481,6 +481,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types,
@ -580,6 +589,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
loc, result_types, args, b);
}
} // namespace impl
struct HloOpToStdScalarOp {

View File

@ -629,6 +629,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::MulOp>,
HloToLhloOpConverter<mhlo::NegOp>,
HloToLhloOpConverter<mhlo::NotOp>,
HloToLhloOpConverter<mhlo::OrOp>,
HloToLhloOpConverter<mhlo::RealOp>,
HloToLhloOpConverter<mhlo::RemOp>,
HloToLhloOpConverter<mhlo::RsqrtOp>,
@ -644,6 +645,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::SubOp>,
HloToLhloOpConverter<mhlo::TanhOp>,
HloToLhloOpConverter<mhlo::TransposeOp>,
HloToLhloOpConverter<mhlo::XorOp>,
HloToLhloReduceOpConverter,
HloToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter,

View File

@ -927,12 +927,14 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::ExpOp>,
PointwiseToLinalgConverter<lmhlo::FloorOp>,
PointwiseToLinalgConverter<lmhlo::ImagOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
PointwiseToLinalgConverter<lmhlo::LogOp>,
PointwiseToLinalgConverter<lmhlo::MaxOp>,
PointwiseToLinalgConverter<lmhlo::MinOp>,
PointwiseToLinalgConverter<lmhlo::MulOp>,
PointwiseToLinalgConverter<lmhlo::NegOp>,
PointwiseToLinalgConverter<lmhlo::NotOp>,
PointwiseToLinalgConverter<lmhlo::OrOp>,
PointwiseToLinalgConverter<lmhlo::RealOp>,
PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
@ -945,7 +947,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
PointwiseToLinalgConverter<lmhlo::SubOp>,
PointwiseToLinalgConverter<lmhlo::TanhOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
PointwiseToLinalgConverter<lmhlo::XorOp>,
ReduceConverter,
ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<lmhlo::ReverseOp>,
@ -1042,12 +1044,14 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::NotOp, false>,
PointwiseToLinalgConverter<mhlo::OrOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
@ -1055,11 +1059,12 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
PointwiseToLinalgConverter<mhlo::SignOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
PointwiseToLinalgConverter<mhlo::XorOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);

View File

@ -42,11 +42,11 @@ namespace {
sep fn(SqrtOp) sep fn(TanhOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \
fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \
sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp) \
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \
fn(AddOp) sep fn(AndOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) \
sep fn(MaxOp) sep fn(MinOp) sep fn(MulOp) sep fn(OrOp) sep fn(PowOp) \
sep fn(RemOp) sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \

View File

@ -316,6 +316,20 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// CHECK-LABEL: func @and
func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
@ -389,6 +403,20 @@ func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
// -----
// CHECK-LABEL: func @or
func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
@ -480,7 +508,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// CHECK-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
@ -492,6 +521,20 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// -----
// CHECK-LABEL: func @xor
func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// Dynamic shape binary element-wise operation.
// CHECK-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {

View File

@ -194,6 +194,30 @@ func @integer_and(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: func @integer_or
func @integer_or(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: or
%0 = "mhlo.or"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @integer_xor
func @integer_xor(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: xor
%0 = "mhlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {

View File

@ -509,7 +509,7 @@ Operation* BuildVariableOp(const tflite::TensorT& tensor,
return op.getOperation();
}
auto op = builder.create<tfl::ConstOp>(loc, value);
if (!tensor.quantization->min.empty()) {
if (tensor.quantization && !tensor.quantization->min.empty()) {
if (auto stats_op =
ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) {
return stats_op;

View File

@ -1977,6 +1977,7 @@ cc_library(
hdrs = ["utils/bridge_logger.h"],
deps = [
":dump_mlir_util",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",

View File

@ -5028,6 +5028,26 @@ Table initializer that takes two tensors for keys and values respectively.
TF_DerivedOperandTypeAttr Tkey = TF_DerivedOperandTypeAttr<1>;
}
def TF_InplaceAddOp : TF_Op<"InplaceAdd", [AllTypesMatch<["x", "y"]>, NoSideEffect]> {
let summary = "Adds v into specified rows of x.";
let description = [{
Computes y = x; y[i, :] += v; return y.
}];
let arguments = (ins
TF_Tensor:$x,
TF_Int32Tensor:$i,
TF_Tensor:$v
);
let results = (outs
TF_Tensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
let summary = "Updates specified rows 'i' with values 'v'.";
@ -5374,6 +5394,37 @@ def TF_IteratorV2Op : TF_Op<"IteratorV2", []> {
);
}
def TF_KthOrderStatisticOp : TF_Op<"KthOrderStatistic", [NoSideEffect]> {
let summary = "Computes the Kth order statistic of a data set. The current";
let description = [{
implementation uses a binary search requiring exactly 32 passes over
the input data. The running time is linear with respect to input
size. The median-of-medians algorithm is probably faster, but is
difficult to implement efficiently in XLA. The implementation imposes
a total ordering on floats. The ordering is consistent with the usual
partial order. Positive NaNs are greater than positive
infinity. Negative NaNs are less than negative infinity. NaNs with
distinct payloads are treated as distinct. Subnormal numbers are
preserved (not flushed to zero). Positive infinity is greater than all
numbers. Negative infinity is less than all numbers. Positive is
greater than negative zero. There are less than k values greater than
the kth order statistic. There are at least k values greater than or
equal to the Kth order statistic. The semantics are not the same as
top_k_unique.
}];
let arguments = (ins
TF_Float32Tensor:$input,
I64Attr:$k
);
let results = (outs
TF_Float32Tensor:$output
);
}
def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> {
let summary = "L2 Loss.";
@ -6505,6 +6556,27 @@ iterator in `iterator` to the first element of `dataset`.
let results = (outs);
}
def TF_MakeUniqueOp : TF_Op<"MakeUnique", [NoSideEffect]> {
let summary = [{
Make all elements in the non-Batch dimension unique, but \"close\" to
}];
let description = [{
their initial value. Never returns a sub-normal number. Never returns
zero. The sign of each input element is always identical to the sign
of the corresponding output element. Behavior for infinite elements is
undefined. Behavior for subnormal elements is undefined.
}];
let arguments = (ins
TF_Float32Tensor:$input
);
let results = (outs
TF_Float32Tensor:$output
);
}
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
let summary = [{
Multiply the matrix "a" by the matrix "b".
@ -15234,6 +15306,36 @@ array([[1, 2, 3, 1, 2, 3],
let hasFolder = 1;
}
def TF_TopKUniqueOp : TF_Op<"TopKUnique", [NoSideEffect]> {
let summary = "Returns the TopK unique values in the array in sorted order.";
let description = [{
The running time is proportional to the product of K and the input
size. Sorting the whole array is more efficient for sufficiently large
values of K. The median-of-medians algorithm is probably faster, but
difficult to implement efficiently in XLA. If there are fewer than K
unique numbers (not NANs), the results are padded with negative
infinity. NaNs are never returned. Subnormal numbers are flushed to
zero. If an element appears at multiple indices, the highest index is
returned. If a TopK element never appears in the input due to padding
values, the indices are padded with negative one. If a padding value
appears in the input and padding is needed, the highest index of the
padding value will be returned. The semantics are not the same as
kth_order_statistic.
}];
let arguments = (ins
TF_Float32Tensor:$input,
I64Attr:$k
);
let results = (outs
TF_Float32Tensor:$topk,
TF_Int32Tensor:$topk_indices
);
}
def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
let summary = [{
Finds values and indices of the `k` largest elements for the last dimension.
@ -15269,6 +15371,29 @@ If two elements are equal, the lower-index element appears first.
let verifier = [{ return Verify(*this); }];
}
def TF_TopKWithUniqueOp : TF_Op<"TopKWithUnique", [NoSideEffect]> {
let summary = "Returns the TopK values in the array in sorted order.";
let description = [{
This is a combination of MakeUnique and TopKUnique. The returned top-K will
have its lower bits replaced by iota, thus it will be close to the original
value but not exactly the same. The running time is proportional to the product
of K and the input size. NaNs are never returned. Subnormal numbers are flushed
to zero.
}];
let arguments = (ins
TF_Float32Tensor:$input,
I64Attr:$k
);
let results = (outs
TF_Float32Tensor:$topk,
TF_Int32Tensor:$topk_indices
);
}
def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
let summary = "Shuffle dimensions of x according to a permutation.";

View File

@ -532,7 +532,11 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
auto ranked_type = type.dyn_cast<RankedTensorType>();
if (!ranked_type) return {};
auto output_type = getType().cast<ShapedType>();
// DenseIntElementsAttr::get requires the output type be ranked with static
// shape.
auto output_type = getType().dyn_cast<RankedTensorType>();
if (!output_type || !output_type.hasStaticShape()) return {};
int32_t rank = ranked_type.getRank();
return DenseIntElementsAttr::get(output_type, rank);
}

View File

@ -904,6 +904,20 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
return %0 : tensor<i32>
}
// CHECK-LABEL: testRankOfRankedTensorUnrankedOutput
func @testRankOfRankedTensorUnrankedOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<*xi32> {
// Regression test to make sure we don't crash in this case.
%0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
}
// CHECK-LABEL: testRankOfRankedTensorDynamicShapeOutput
func @testRankOfRankedTensorDynamicShapeOutput(%arg0 : tensor<4x3x2xf32>) -> tensor<?xi32> {
// Regression test to make sure we don't crash in this case.
%0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// CHECK-LABEL: @foldFill
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
%0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>

View File

@ -1627,8 +1627,10 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
return success();
}
int64_t producer = producer_or.ValueOrDie();
// TODO(jpienaar): Clean up propagate_caller_callee_constants if it is no
// longer needed.
ShapeInference context(producer, module.getContext(),
/*propagate_caller_callee_constants=*/true);
/*propagate_caller_callee_constants=*/false);
if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
context.enqueue(main);
for (auto func : module.getOps<FuncOp>()) context.enqueue(func);

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include <atomic>
#include "absl/strings/str_split.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Operation.h" // from @llvm-project
@ -23,17 +26,30 @@ limitations under the License.
namespace tensorflow {
// Counter is used as a prefix for filenames.
static std::atomic<int> log_counter(0);
BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope,
bool print_after_only_on_change)
: mlir::PassManager::IRPrinterConfig(print_module_scope,
print_after_only_on_change) {}
print_after_only_on_change) {
const char* log_pass_patterns = getenv("MLIR_BRIDGE_LOG_PASS_PATTERNS");
if (log_pass_patterns) {
log_pass_patterns_ =
absl::StrSplit(log_pass_patterns, ',', absl::SkipWhitespace());
}
}
// Logs op to file with name of format `mlir_bridge-pass_name-file_suffix.mlir`.
// Logs op to file with name of format
// `<log_counter>_mlir_bridge_<pass_name>_<file_suffix>.mlir`.
inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
mlir::Pass* pass, mlir::Operation* op,
llvm::StringRef file_suffix) {
std::string name =
llvm::formatv("mlir_bridge_{0}_{1}", pass->getName(), file_suffix).str();
std::string pass_name = pass->getName().str();
// Add 4-digit counter as prefix so the order of the passes is obvious.
std::string name = llvm::formatv("{0,0+4}_mlir_bridge_{1}_{2}", log_counter++,
pass_name, file_suffix);
std::unique_ptr<llvm::raw_ostream> os;
std::string filepath;
@ -44,13 +60,30 @@ inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass,
mlir::Operation* operation,
PrintCallbackFn print_callback) {
Log(print_callback, pass, operation, "before");
if (should_print(pass)) Log(print_callback, pass, operation, "before");
}
void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass,
mlir::Operation* operation,
PrintCallbackFn print_callback) {
Log(print_callback, pass, operation, "after");
if (should_print(pass)) Log(print_callback, pass, operation, "after");
}
bool BridgeLoggerConfig::should_print(mlir::Pass* pass) {
if (log_pass_patterns_.empty()) return true;
std::string pass_name = pass->getName().str();
for (const auto& pattern : log_pass_patterns_) {
if (pass_name.find(pattern) != std::string::npos) {
// pattern matches pass
return true;
}
}
// no pattern matches pass
VLOG(2) << "Not logging pass " << pass_name
<< " because it does not match any pattern in "
"MLIR_BRIDGE_LOG_PASS_PATTERNS";
return false;
}
void BridgeTimingConfig::printTiming(PrintCallbackFn printCallback) {

View File

@ -23,7 +23,11 @@ limitations under the License.
namespace tensorflow {
// Logger for logging/dumping MLIR modules before and after passes in bridge
// targeting TPUs.
// targeting TPUs. The passes being logged can be restricted via environment
// variable `MLIR_BRIDGE_LOG_PASS_PATTERNS` which is interpreted as a comma-
// separated list of strings, and only passes whose name contains any of those
// strings as a substring are logged (no regex support). If
// `MLIR_BRIDGE_LOG_PASS_PATTERNS` is not defined, then all passes are logged.
class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
public:
explicit BridgeLoggerConfig(bool print_module_scope = false,
@ -42,6 +46,14 @@ class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
// with the stream to dump into.
void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
PrintCallbackFn print_callback) override;
private:
bool should_print(mlir::Pass *pass);
// Only print passes that match any of these patterns. A pass matches a
// pattern if its name contains the pattern as a substring. If
// `log_pass_patterns_` is empty, print all passes.
std::vector<std::string> log_pass_patterns_;
};
// Logger for logging/dumping pass pipeline timings after completion.

View File

@ -516,6 +516,11 @@ class ConvertToHloModule {
//
// TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
LogicalResult Run() {
auto main = module_.lookupSymbol<mlir::FuncOp>("main");
if (!main)
return module_.emitError(
"conversion requires module with `main` function");
for (auto func : module_.getOps<FuncOp>()) {
if (func.empty()) continue;
if (failed(RunOnFunction(func))) return failure();
@ -539,8 +544,11 @@ class ConvertToHloModule {
xla::XlaComputation* result);
::xla::HloModuleProto ConsumeMainProto() {
return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
.proto();
auto main = module_.lookupSymbol<mlir::FuncOp>("main");
// This is an invariant check as Run returns failure if there is no main
// function and so the main proto shouldn't be consumed in that case.
CHECK(main) << "requires module to have main function"; // Crash Ok.
return lowered_computation_[main].proto();
}
// Lower function call to HLO call instruction

View File

@ -174,6 +174,13 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
return %0: tensor<4xi32>
}
// CHECK-LABEL: func @bitwise_xor
func @bitwise_xor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: mhlo.xor
%0 = "tf.BitwiseXor"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}
// CHECK-LABEL: func @bitwise_and
func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: mhlo.and

View File

@ -835,7 +835,14 @@ func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
}
// CHECK-LABEL: func @floordiv_unranked
func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NOT: tf.FloorDiv
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0: tensor<*xf32>
}
// CHECK-LABEL: func @floordiv_int
func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: tf.FloorDiv
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
@ -894,7 +901,7 @@ func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
// CHECK-LABEL: func @floormod_unranked
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: tf.FloorMod
// CHECK-NOT: tf.FloorMod
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
}

View File

@ -0,0 +1,7 @@
// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s
// CHECK: conversion requires module with `main`
func @non_main() {
%0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
return
}

View File

@ -114,7 +114,7 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
// Performs a substitution of FloorDiv, pseudo code below:
//
// return floor(div(x, y))
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
(HLO_FloorOp
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
[(IEEEFloatTensor $l)]>;
@ -166,7 +166,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
// Requires static shaped inputs to create constant splats and computation of
// broadcast attributes.
def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
(HLO_SelectOp
(HLOClient_BroadcastAndOp
(HLOClient_BroadcastCompareOp
@ -193,14 +193,15 @@ def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
//===----------------------------------------------------------------------===//
class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
: Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
: Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
(ToOp $l, $r, (BinBroadcastDimensions $l, $r)),
[(SignedIntTensor $l)]>;
foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp],
[TF_LogicalOrOp, HLOClient_BroadcastOrOp],
[TF_BitwiseAndOp, HLOClient_BroadcastAndOp],
[TF_BitwiseOrOp, HLOClient_BroadcastOrOp],
[TF_BitwiseAndOp, HLOClient_BroadcastAndOp]] in
[TF_BitwiseXorOp, HLOClient_BroadcastXorOp]] in
def : DirectLogicalBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
//===----------------------------------------------------------------------===//

View File

@ -154,9 +154,11 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::IgammaOp>(),
TypeID::get<TF::IgammacOp>(),
TypeID::get<TF::IgammaGradAOp>(),
TypeID::get<TF::InplaceAddOp>(),
TypeID::get<TF::InTopKV2Op>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::KthOrderStatisticOp>(),
TypeID::get<TF::LRNOp>(),
TypeID::get<TF::LRNGradOp>(),
TypeID::get<TF::LeakyReluGradOp>(),
@ -170,6 +172,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::LogicalOrOp>(),
TypeID::get<TF::LogOp>(),
TypeID::get<TF::LowerBoundOp>(),
TypeID::get<TF::MakeUniqueOp>(),
TypeID::get<TF::MatMulOp>(),
TypeID::get<TF::MatrixDiagV3Op>(),
TypeID::get<TF::MatrixInverseOp>(),
@ -248,6 +251,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::TensorScatterAddOp>(),
TypeID::get<TF::TensorScatterSubOp>(),
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
TypeID::get<TF::TopKUniqueOp>(),
TypeID::get<TF::TopKWithUniqueOp>(),
TypeID::get<TF::TransposeOp>(),
TypeID::get<TF::TridiagonalSolveOp>(),
TypeID::get<TF::TruncateDivOp>(),

View File

@ -121,7 +121,7 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
// Run all HLO passes to produce an optimized module.
auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement(
std::move(hlo_module), backend->default_stream_executor(),
backend->memory_allocator(), optimize_xla_hlo);
optimize_xla_hlo, {backend->memory_allocator()});
TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(),
"running XLA pass pipeline");
std::unique_ptr<HloModule> optimized_hlo_module =

View File

@ -115,6 +115,16 @@ class ExecutableBuildOptions {
return *this;
}
// Thread pool for parallel compilation.
tensorflow::thread::ThreadPool* compile_thread_pool() const {
return compile_thread_pool_;
}
ExecutableBuildOptions& set_run_backend_only(
tensorflow::thread::ThreadPool* compile_thread_pool) {
compile_thread_pool_ = compile_thread_pool;
return *this;
}
private:
int device_ordinal_ = -1;
Shape result_layout_;
@ -128,6 +138,7 @@ class ExecutableBuildOptions {
absl::optional<DeviceAssignment> device_assignment_;
bool alias_passthrough_params_ = false;
bool run_backend_only_ = false;
tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
};
} // namespace xla

View File

@ -3347,6 +3347,8 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
// contant False if dimension is static.
// - Reduce: Convert to reduce or.
// - Constant: Convert to constant False.
// - Reshape, slice, transpose, pad:
// Convert into predicate type with same opcode.
// - Other ops: Not supported.
// Create the instruction for the new handle.
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
@ -3449,6 +3451,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
case HloOpcode::kBroadcast:
case HloOpcode::kConcatenate:
case HloOpcode::kReshape:
case HloOpcode::kPad:
break;
case HloOpcode::kGetDimensionSize: {
int64 dimension = instr_proto->dimensions(0);

View File

@ -26,8 +26,8 @@ static const char kCpuPlatformName[] = "cpu";
CpuDevice::CpuDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state)
: PjRtDevice(id, std::move(local_device_state),
/*device_kind=*/kCpuPlatformName) {}
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
/*device_kind=*/kCpuPlatformName) {}
StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
for (int i = 0; i < client->device_count(); ++i) {
se::StreamExecutorConfig config;
config.ordinal = i;
@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
devices.push_back(std::move(device));
}
return std::make_unique<PjRtClient>(
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
kCpuName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);
/*gpu_run_options=*/nullptr));
}
} // namespace xla

View File

@ -23,7 +23,7 @@ limitations under the License.
namespace xla {
class CpuDevice : public PjRtDevice {
class CpuDevice : public PjRtStreamExecutorDevice {
public:
CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state);
};

View File

@ -35,9 +35,9 @@ namespace xla {
namespace {
// A custom PjRtClient that overrides the device assignment method.
class GpuClient : public xla::PjRtClient {
class GpuClient : public xla::PjRtStreamExecutorClient {
public:
using xla::PjRtClient::PjRtClient;
using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;
xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
@ -55,7 +55,8 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
return assignment;
}
// Fallback to default global device assignment if we can't run locally.
return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
num_partitions);
}
// Builds an xla::LocalClient for the GPU platform.
@ -225,9 +226,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
return result.first->second;
}
std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
for (auto& local_device : local_device_states) {
int device_ordinal = local_device->device_ordinal();
const se::DeviceDescription& description =
@ -243,7 +244,7 @@ std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
Status BuildDistributedDevices(
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
std::vector<std::unique_ptr<PjRtDevice>>* devices,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
LocalTopologyProto local_topology;
local_topology.set_node_id(node_id);
@ -306,8 +307,8 @@ Status BuildDistributedDevices(
GpuDevice::GpuDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int node_id)
: PjRtDevice(id, std::move(local_device_state), std::move(device_kind),
node_id) {}
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
std::move(device_kind), node_id) {}
StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
bool asynchronous, const GpuAllocatorConfig& allocator_config,
@ -322,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
auto host_memory_allocator =
GetGpuHostAllocator(local_device_states.front()->executor());
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
if (distributed_client) {
TF_RETURN_IF_ERROR(BuildDistributedDevices(

View File

@ -25,7 +25,7 @@ limitations under the License.
namespace xla {
class GpuDevice : public PjRtDevice {
class GpuDevice : public PjRtStreamExecutorDevice {
public:
GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int node_id);

View File

@ -26,8 +26,8 @@ static const char kInterpreterPlatformName[] = "interpreter";
InterpreterDevice::InterpreterDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state)
: PjRtDevice(id, std::move(local_device_state),
/*device_kind=*/kInterpreterPlatformName) {}
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
/*device_kind=*/kInterpreterPlatformName) {}
StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
@ -41,7 +41,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
se::StreamExecutor* executor =
client->backend().stream_executor(0).ValueOrDie();
auto device_state = absl::make_unique<LocalDeviceState>(
@ -51,11 +51,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
absl::make_unique<InterpreterDevice>(0, std::move(device_state));
devices.push_back(std::move(device));
return std::make_unique<PjRtClient>(
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
"interpreter", client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);
/*gpu_run_options=*/nullptr));
}
} // namespace xla

View File

@ -23,7 +23,7 @@ limitations under the License.
namespace xla {
class InterpreterDevice : public PjRtDevice {
class InterpreterDevice : public PjRtStreamExecutorDevice {
public:
InterpreterDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state);

View File

@ -114,21 +114,22 @@ limitations under the License.
namespace xla {
PjRtPlatformId PjRtDevice::platform_id() const {
PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
return client_->platform_id();
}
const std::string& PjRtDevice::platform_name() const {
const std::string& PjRtStreamExecutorDevice::platform_name() const {
return client_->platform_name();
}
StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const {
StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
const {
if (local_device_state_) {
return local_device_state_.get();
}
return InvalidArgument("Device %s is not a local device.", DebugString());
}
std::string PjRtDevice::DebugString() const {
std::string PjRtStreamExecutorDevice::DebugString() const {
return absl::StrCat(platform_name(), ":", id());
}
@ -153,14 +154,15 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
devices[replica].size(), replica, devices[0].size());
}
for (int partition = 0; partition < devices[replica].size(); ++partition) {
if (devices[0][0]->platform_id() !=
devices[replica][partition]->platform_id()) {
if (devices[0][0]->client()->platform_id() !=
devices[replica][partition]->client()->platform_id()) {
return InvalidArgument(
"Device assignment passed to Compile() must have devices of a "
"single kind, got %s for replica 0 partition 0 and %s for replica "
"%d partition %d.",
devices[0][0]->platform_name(),
devices[replica][partition]->platform_name(), replica, partition);
devices[0][0]->client()->platform_name(),
devices[replica][partition]->client()->platform_name(), replica,
partition);
}
xla_assignment(replica, partition) = devices[replica][partition]->id();
}
@ -182,9 +184,9 @@ class CpuAllocator : public tensorflow::Allocator {
}
};
PjRtClient::PjRtClient(
PjRtStreamExecutorClient::PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
@ -193,7 +195,7 @@ PjRtClient::PjRtClient(
platform_name_(std::move(platform_name)),
client_(client),
host_memory_allocator_(std::move(host_memory_allocator)),
devices_(std::move(devices)),
owned_devices_(std::move(devices)),
host_id_(host_id),
owned_allocator_(std::move(allocator)),
should_stage_host_to_device_transfers_(
@ -211,12 +213,14 @@ PjRtClient::PjRtClient(
host_memory_allocator_ = std::make_unique<CpuAllocator>();
}
for (const std::unique_ptr<PjRtDevice>& device : devices_) {
for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
owned_devices_) {
devices_.push_back(device.get());
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
<< "Duplicate device id: " << device->id();
if (device->IsLocalDevice()) {
int idx = device->local_device_id();
if (device->IsAddressable()) {
int idx = device->local_hardware_id();
if (idx >= local_devices_.size()) {
local_devices_.resize(idx + 1);
}
@ -230,13 +234,14 @@ PjRtClient::PjRtClient(
}
}
StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
return client_->backend().computation_placer()->AssignDevices(num_replicas,
num_partitions);
}
std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
std::unique_ptr<HloCostAnalysis>
PjRtStreamExecutorClient::GetHloCostAnalysis() {
return absl::make_unique<HloCostAnalysis>(
client_->backend().compiler()->ShapeSizeBytesFunction());
}
@ -346,12 +351,13 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
return InvalidArgument("Can't make a buffer from an empty tuple");
}
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape, client->allocator(), local_device->device_ordinal()));
se_client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape, se_client->allocator(),
local_device->device_ordinal()));
if (local_device->allocation_model() ==
LocalDeviceState::kComputeSynchronized) {
if (copy_stream == nullptr) {
@ -543,18 +549,21 @@ void PjRtBuffer::ScopedHold::AddToInput(
bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
<< " device: " << device->DebugString();
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostBuffer");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
<< shape.ToString() << " device: " << device->DebugString();
if (shape.IsTuple()) {
return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
}
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState());
int64 size = ShapeUtil::ByteSizeOf(shape);
TransferManager* transfer_manager = client()->backend().transfer_manager();
@ -708,20 +717,23 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
return py_buffer;
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) {
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
PjRtDevice* device) {
return CreateUninitializedBuffer(shape, device, nullptr);
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event) {
tensorflow::profiler::TraceMe traceme(
"PjRtClient::CreateUninitializedBuffer");
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
"PjRtStreamExecutorClient::CreateUninitializedBuffer");
VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
<< shape.ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState());
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape,
@ -733,13 +745,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
definition_event);
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostLiteral");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
<< literal.shape().ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState());
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
@ -792,7 +807,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
return py_buffer;
}
void PjRtClient::MakeCrossHostReceiveBuffers(
void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) {
if (shapes.empty()) {
@ -801,7 +816,9 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
return;
}
auto local_device_or = device->GetLocalDeviceState();
auto local_device_or =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->GetLocalDeviceState();
if (!local_device_or.ok()) {
notifier(local_device_or.status());
return;
@ -828,27 +845,30 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
}
// Transfer the given literal to the infeed queue of the given local device.
Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const {
Status PjRtStreamExecutorDevice::TransferToInfeed(
const LiteralSlice& literal) const {
// Only support infeed to local device.
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferToInfeedLocal(
literal, local_device->device_ordinal());
}
StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
const Shape& shape) const {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferFromOutfeedLocal(
shape, local_device->device_ordinal());
}
StatusOr<PjRtDevice*> PjRtClient::LookupLocalDevice(int local_device_id) const {
StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
int local_hardware_id) const {
for (auto* device : local_devices_) {
if (local_device_id == device->local_device_id()) {
if (local_hardware_id == device->local_hardware_id()) {
return device;
}
}
return InvalidArgument("No matching device found for local_device_id %d",
local_device_id);
return InvalidArgument("No matching device found for local_hardware_id %d",
local_hardware_id);
}
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@ -873,7 +893,8 @@ PjRtBuffer::~PjRtBuffer() {
}
int64 PjRtBuffer::OnDeviceSizeInBytes() const {
return client_->client()
return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->client()
->backend()
.transfer_manager()
->GetByteSizeRequirement(on_device_shape_);
@ -919,7 +940,9 @@ StatusOr<std::shared_ptr<TrackedDeviceBuffer>> PjRtBuffer::Release(
// the final set of usage events.
events = device_buffer->LockUseAndTransferUsageEvents();
}
LocalDeviceState* local_device_state = device_->local_device_state();
LocalDeviceState* local_device_state =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state();
if (wait_for_operations_to_complete) {
// Block the host until all usage events have completed. Usage events
// dominate definition events, so this also waits for the buffer to be
@ -1080,7 +1103,9 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
}
ScopedHold device_buffer(this, ScopedHold::kUsage);
std::shared_ptr<HostValue> host_value;
LocalDeviceState* local_device = device_->local_device_state();
LocalDeviceState* local_device =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state();
se::Stream* stream = local_device->GetDeviceToHostStream();
const xla::Layout& host_layout =
layout.has_value() ? layout.value() : on_host_shape_.layout();
@ -1122,12 +1147,16 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
host_value->value = std::make_shared<Literal>(host_shape);
ShapedBuffer shaped_buffer =
device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
stream, shaped_buffer, host_value->value.get(),
[host_value](Status done_status) {
host_value->status = done_status;
host_value->ready.Notify();
});
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->client()
->backend()
.transfer_manager()
->TransferLiteralFromDevice(stream, shaped_buffer,
host_value->value.get(),
[host_value](Status done_status) {
host_value->status = done_status;
host_value->ready.Notify();
});
auto usage_event = std::make_shared<BufferSequencingEvent>();
StatusOr<EventPool::Handle> event_or =
@ -1156,7 +1185,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
CopyToHostAsyncInternal(discard_cached_copy, layout));
if (host_value == nullptr) {
@ -1241,8 +1270,9 @@ PjRtBuffer::CopyToDeviceHelper(
// StallStreamOnError only makes sure the destination device is ok, so
// make sure that the src buffer remains valid until after any transfers
// have completed.
device_->local_device_state()->ThenRelease(transfer_stream,
src_device_buffer);
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state()
->ThenRelease(transfer_stream, src_device_buffer);
}
return copy_event_or.status();
}
@ -1265,14 +1295,20 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
return dst_device->client()->BufferFromHostBuffer(
literal->untyped_data(), literal->shape(),
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr,
dst_device);
}
TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
dst_device->GetLocalDeviceState());
TF_ASSIGN_OR_RETURN(
LocalDeviceState * dst_local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
->GetLocalDeviceState());
LocalDeviceState* transfer_local_device =
client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
: dst_local_device;
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->EnqueueD2DTransfersOnSrcStream()
? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state()
: dst_local_device;
CHECK_EQ(dst_local_device->allocation_model(),
transfer_local_device->allocation_model());
@ -1310,7 +1346,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
// alternative is to ensure, before freeing the buffer, that the compute
// stream is synchronized past the transfer, but it seems better to hold onto
// the buffer too long than to stall the compute stream.
RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
RecordUsage(std::move(src_device_buffer),
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state(),
transfer_local_device, event, transfer_stream,
/*prefer_to_retain_reference=*/true);
@ -1318,7 +1356,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
}
Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) {
return client_->CopyToRemoteDevice(this, serialized_descriptor);
return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->CopyToRemoteDevice(this, serialized_descriptor);
}
Status PjRtBuffer::BlockHostUntilReady() {
@ -1332,7 +1371,9 @@ Status PjRtBuffer::BlockHostUntilReady() {
}
device_buffer = device_buffer_;
}
LocalDeviceState* local_device_state = device_->local_device_state();
LocalDeviceState* local_device_state =
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state();
std::unique_ptr<se::Stream> stream;
for (auto& event : device_buffer->definition_events()) {
if (!event->IsComplete()) {
@ -1378,9 +1419,13 @@ StatusOr<TupleHandle> MakeTupleHelper(
Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
se::DeviceMemoryAllocator* allocator = client->allocator();
se::DeviceMemoryAllocator* allocator =
tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
->client()
->backend()
.transfer_manager();
se::Stream* stream = local_device->host_to_device_stream();
TF_ASSIGN_OR_RETURN(
se::OwningDeviceMemory root_table_memory,
@ -1444,14 +1489,6 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
/*prefer_to_retain_reference=*/false);
return pjrt_buffer;
}
static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
auto it = client.id_to_device().find(device_id);
CHECK(it != client.id_to_device().end())
<< "Unknown device id: " << device_id;
return it->second;
}
} // namespace
PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
@ -1459,7 +1496,8 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
std::vector<PjRtDevice*> addressable_devices,
PjRtStreamExecutorClient* client)
: client_(client),
device_assignment_(std::move(device_assignment)),
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
@ -1482,7 +1520,7 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
<< device_assignment_->ToString();
CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
CHECK_LE(addressable_devices_.size(), client_->local_device_count())
CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
<< "Inconsistent local device count.";
num_partitions = device_assignment_->computation_count();
}
@ -1584,7 +1622,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const {
std::vector<ExecutionInput> execution_inputs;
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
// Lift tuple_handle outside the conditional so that the event it returns is
// not destroyed until after the loop below that waits on events.
absl::optional<TupleHandle> tuple_handle;
@ -1607,8 +1645,10 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
execution_input.MutableBuffers()->begin();
ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
execution_input.MutableBuffers()->end();
device_buffers[i].AddToInput(&input_iterator, iterator_end,
&execution_input, client_->allocator());
device_buffers[i].AddToInput(
&input_iterator, iterator_end, &execution_input,
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->allocator());
CHECK(input_iterator == iterator_end);
}
}
@ -1628,8 +1668,10 @@ StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment) const {
int device_ordinal = device->local_device_state()->device_ordinal();
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state()
->device_ordinal();
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
tensorflow::profiler::TraceMeConsumer activity(
"LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
run_id.ToInt());
@ -1765,7 +1807,7 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
std::shared_ptr<BufferSequencingEvent> definition_event,
PjRtDevice* device) const {
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
outputs.reserve(tuple_count);
@ -1802,7 +1844,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
if (device == nullptr) {
CHECK(device_assignment_ != nullptr);
const int device_id = (*device_assignment_)(replica, partition);
device = LookupDevice(*client_, device_id);
TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
device_assignment = device_assignment_;
} else {
CHECK(device_assignment_ == nullptr);
@ -1814,7 +1856,9 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
}
CHECK_EQ(device->host_id(), client_->host_id());
int device_ordinal = device->local_device_state()->device_ordinal();
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state()
->device_ordinal();
tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
VLOG(3) << "Replica " << replica << ", partition " << partition
<< " mapped to device ordinal for execution: " << device_ordinal;
@ -1836,7 +1880,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
ScopedShapedBuffer result_buffer =
result_buffer_or_status.ConsumeValueOrDie();
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
se::Stream* stream = device_state->compute_stream();
StatusOr<EventPool::Handle> event_or =
device_state->event_pool().ThenAllocateAndRecordEvent(stream);
@ -1922,7 +1966,9 @@ PjRtStreamExecutorExecutable::Execute(
const int replica = addressable_device_logical_ids_[i].replica;
const int partition = addressable_device_logical_ids_[i].partition;
PjRtDevice* device = addressable_devices_[i];
const LocalDeviceState& device_state = *device->local_device_state();
const LocalDeviceState& device_state =
*tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state();
device_state.execute_thread()->Schedule([&, replica, partition, i] {
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
run_id, options);
@ -2131,9 +2177,9 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
} // namespace
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
const XlaComputation& computation, CompileOptions options) {
tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
ExecutableBuildOptions& build_options = options.executable_build_options;
if (!build_options.device_allocator()) {
@ -2153,14 +2199,15 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
num_partitions = 1;
} else {
if (!build_options.has_device_assignment()) {
VLOG(2) << "PjRtClient::Compile using default device_assignment.";
VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
"device_assignment.";
TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
GetDefaultDeviceAssignment(build_options.num_replicas(),
build_options.num_partitions()));
build_options.set_device_assignment(device_assignment);
}
VLOG(2) << "PjRtClient::Compile device_assignment:\n"
VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
<< build_options.device_assignment().ToString();
num_replicas = build_options.device_assignment().replica_count();
num_partitions = build_options.device_assignment().computation_count();
@ -2234,7 +2281,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment)(replica, partition);
PjRtDevice* device = LookupDevice(*this, device_id);
TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
if (device->host_id() != host_id()) {
VLOG(3) << "Non-local device: " << device_id;
continue;
@ -2254,7 +2301,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
if (build_options.device_ordinal() < 0) {
build_options.set_device_ordinal(
addressable_devices.front()->local_device_state()->device_ordinal());
addressable_devices.front()->local_hardware_id());
}
}

View File

@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@ -67,16 +68,56 @@ class PjRtClient;
class PjRtDevice {
public:
explicit PjRtDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int host_id = 0)
virtual ~PjRtDevice() {}
// Return the client that owns this device.
virtual PjRtClient* client() const = 0;
// Whether client can issue command to this device.
virtual bool IsAddressable() const = 0;
// The ID of this device. IDs are unique among devices of this type
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
virtual int id() const = 0;
// The task ID of this device according to TpuTopology. This is not the same
// as PjRtClient::host_id() in a multi-task setting, where each client can see
// devices from all tasks, but only a subset of them are addressable and have
// the same task_id as the client.
virtual int host_id() const = 0;
// Opaque hardware ID, e.g., the CUDA device number, useful for identifying
// which GPU when interacting with non-JAX code. In general, not guaranteed to
// be dense, and -1 if undefined.
virtual int local_hardware_id() const = 0;
// A vendor-dependent string that uniquely identifies the kind of device,
// e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
// compatible compilation.
virtual const std::string& device_kind() const = 0;
virtual std::string DebugString() const = 0;
// Transfer the given literal to the infeed queue.
virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
// Transfer and return a value of the given shape from the outfeed queue.
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
};
class PjRtStreamExecutorDevice : public PjRtDevice {
public:
explicit PjRtStreamExecutorDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int host_id = 0)
: id_(id),
local_device_id_(
device_ordinal_(
local_device_state ? local_device_state->device_ordinal() : -1),
local_device_state_(std::move(local_device_state)),
host_id_(host_id),
device_kind_(std::move(device_kind)) {}
virtual ~PjRtDevice() {}
~PjRtStreamExecutorDevice() override {}
// Must set client exactly once.
void SetClient(PjRtClient* client) {
@ -84,14 +125,25 @@ class PjRtDevice {
client_ = client;
}
// Task ID. This is always 0 on single-task setup.
int host_id() const override { return host_id_; }
// Return `platform_id` from client.
PjRtPlatformId platform_id() const;
// Return `platform_name` from client.
const std::string& platform_name() const;
PjRtClient* client() const override { return client_; }
// The ID of this device. IDs are unique among devices of this type
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
int id() const { return id_; }
int id() const override { return id_; }
bool IsLocalDevice() const { return local_device_id_ != -1; }
bool IsAddressable() const override { return device_ordinal_ != -1; }
int local_device_id() const { return local_device_id_; }
int local_hardware_id() const override { return device_ordinal_; }
// If this is a device local to this host, returns a LocalDeviceState object
// that can be used to manipulate the device. Returns nullptr if the device is
@ -105,32 +157,21 @@ class PjRtDevice {
// is not local to this host.
StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
// The ID of this device's host. This is always 0 on single-host platforms.
int host_id() const { return host_id_; }
// Return `platform_id` from client.
PjRtPlatformId platform_id() const;
// Return `platform_name` from client.
const std::string& platform_name() const;
// A vendor-dependent string that uniquely identifies the kind of device.
const std::string& device_kind() const { return device_kind_; }
const std::string& device_kind() const override { return device_kind_; }
virtual std::string DebugString() const;
PjRtClient* client() const { return client_; }
std::string DebugString() const override;
// Transfer the given literal to the infeed queue of the given localdevice.
virtual Status TransferToInfeed(const LiteralSlice& literal) const;
Status TransferToInfeed(const LiteralSlice& literal) const override;
// Transfer and return a value of the given shape from the outfeed of the
// given device.
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
private:
const int id_;
const int local_device_id_; // -1 means not local.
const int device_ordinal_; // -1 means not local.
const std::unique_ptr<LocalDeviceState> local_device_state_;
const int host_id_;
const std::string device_kind_;
@ -178,86 +219,62 @@ class PjRtExecutable;
// alive as long as any of the other runtime objects are alive.
class PjRtClient {
public:
// `allocator` may null, in which case the platform default allocator is used.
explicit PjRtClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
virtual ~PjRtClient() = default;
// TODO(zhangqiaorjc): Rename to task_id.
// Return the task id of this client. In single-task setting, always 0.
virtual int host_id() const = 0;
// Return the number of devices in the entire computation. In multi-headed
// client setting, some are addressable by this client, some are not. In a
// single-client setting, this is equal to the number of addressable devices.
virtual int device_count() const = 0;
// Return number of addressable devices. Addressable devices are those that
// the client can issue commands to.
virtual int addressable_device_count() const = 0;
// Return all devices in the entire computation, including addressable and
// non-addressable devices.
virtual absl::Span<PjRtDevice* const> devices() const = 0;
// TODO(zhangqiaorjc): Rename to addressable_devices.
// Return only addressable devices.
virtual absl::Span<PjRtDevice* const> local_devices() const = 0;
// Lookup any PjRtDevice for a given PjRtDevice::id().
virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
// Return an addressable PjRtDevice for a given
// PjRtDevice::local_hardware_id().
virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const = 0;
// Return an ID that identifies the platform (CPU/GPU/TPU).
virtual PjRtPlatformId platform_id() const = 0;
// Returns a string that identifies the platform (CPU/GPU/TPU).
virtual const std::string& platform_name() const = 0;
// Return a device-specific default device assignment, e.g., GPU and TPU may
// be different.
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const;
int device_count() const { return devices_.size(); }
int local_device_count() const { return local_devices_.size(); }
const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
return devices_;
}
const std::vector<PjRtDevice*>& local_devices() const {
return local_devices_;
}
const std::map<int, PjRtDevice*>& id_to_device() const {
return id_to_device_;
}
int host_id() const { return host_id_; }
PjRtPlatformId platform_id() const { return platform_id_; }
const std::string& platform_name() const { return platform_name_; }
LocalDeviceState& device_state(int device_ordinal) const {
return *local_devices_.at(device_ordinal)->local_device_state();
}
// Return a local PjRtDevice for a given `local_device_id`.
virtual StatusOr<PjRtDevice*> LookupLocalDevice(int local_device_id) const;
LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
gpu::GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
}
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_;
}
// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
// Generates a unique fingerprint for `executable`.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const {
return absl::optional<std::string>();
}
int num_replicas, int num_partitions) const = 0;
// Returns a backend-specific HLO cost analysis visitor.
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;
// Compile `computation` with given `options`.
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
const XlaComputation& computation, CompileOptions options) = 0;
// Generates a unique fingerprint for `executable`, may be absl::nullopt.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const = 0;
// Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be speficied that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
// future backends and so callers should avoid wherever possible.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device);
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
const Shape& shape, PjRtDevice* device) = 0;
// Describes the semantics the caller to BufferFromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
@ -289,13 +306,13 @@ class PjRtClient {
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device);
std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device);
const LiteralSlice& literal, PjRtDevice* device) = 0;
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device'. `shapes` must be the exact
@ -308,18 +325,140 @@ class PjRtClient {
// buffers will become ready until *all* of the sends have completed.
virtual void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier);
PjRtCrossHostRecvNotifier&& notifier) = 0;
virtual StatusOr<ChannelHandle> CreateChannelHandle() {
// Create ChannelHandles for XLA send/recv.
virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
};
class PjRtStreamExecutorClient : public PjRtClient {
public:
// `allocator` may null, in which case the platform default allocator is used.
explicit PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
~PjRtStreamExecutorClient() override = default;
int host_id() const override { return host_id_; }
int device_count() const override { return devices_.size(); }
int addressable_device_count() const override {
return local_devices_.size();
}
absl::Span<PjRtDevice* const> devices() const override { return devices_; }
absl::Span<PjRtDevice* const> local_devices() const override {
return local_devices_;
}
StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
auto it = id_to_device_.find(device_id);
if (it != id_to_device_.end()) {
return it->second;
}
return InvalidArgument("No matching device found for device_id %d",
device_id);
}
StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const override;
PjRtPlatformId platform_id() const override { return platform_id_; }
const std::string& platform_name() const override { return platform_name_; }
// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override;
// Generates a unique fingerprint for `executable`.
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const override {
return absl::optional<std::string>();
}
// Returns a backend-specific HLO cost analysis visitor.
std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
// Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be speficied that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
// future backends and so callers should avoid wherever possible.
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) override;
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) override;
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device'. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) override;
StatusOr<ChannelHandle> CreateChannelHandle() override {
return client()->CreateChannelHandle();
}
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
return client()->CreateDeviceToHostChannelHandle();
}
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
return client()->CreateHostToDeviceChannelHandle();
}
LocalDeviceState& device_state(int device_ordinal) const {
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
local_devices_.at(device_ordinal))
->local_device_state();
}
LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
gpu::GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
}
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_;
}
protected:
friend class PjRtBuffer;
virtual void EnqueueCrossHostReceive(
@ -342,7 +481,9 @@ class PjRtClient {
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::unique_ptr<PjRtDevice>> devices_;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
// Pointers to `owned_devices_`.
std::vector<PjRtDevice*> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices.
std::map<int, PjRtDevice*> id_to_device_;
// Local devices indexed by local device ordinal.
@ -509,7 +650,7 @@ class PjRtBuffer {
private:
friend class PjRtBuffer;
friend class PjRtClient;
friend class PjRtStreamExecutorClient;
// Helper struct that makes it possible to move a ScopedHold through a
// closure.
@ -769,7 +910,7 @@ class PjRtExecutable {
virtual PjRtClient* client() const = 0;
// Unique name for this executable, e.g., HloModule name.
virtual const string& name() const = 0;
virtual const std::string& name() const = 0;
virtual int num_replicas() const = 0;
@ -791,6 +932,7 @@ class PjRtExecutable {
virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
const = 0;
// An addressable_device is one which the client can issue commands to.
// addressable_devices()[i] is the Device to which
// addressable_device_logical_ids()[i] is assigned.
virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
@ -833,13 +975,14 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
std::vector<PjRtDevice*> addressable_devices,
PjRtStreamExecutorClient* client);
~PjRtStreamExecutorExecutable() override = default;
PjRtClient* client() const override { return client_; }
PjRtStreamExecutorClient* client() const override { return client_; }
const string& name() const override;
const std::string& name() const override;
int num_replicas() const override {
return executables_[0]->build_options().num_replicas();
@ -898,7 +1041,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
}
private:
friend class PjRtClient;
friend class PjRtStreamExecutorClient;
// Initializes information about which arguments to which executables must be
// donated due to aliases that were specified by the computation.
Status SetUpDonation(bool tuple_inputs);
@ -933,7 +1076,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
// Create shared pointers so we can free them after the execution: with
// asynchronous execution, the process being executed can outlive the
// executable itself.
PjRtClient* const client_;
PjRtStreamExecutorClient* const client_;
// One executable per partition.
std::vector<std::shared_ptr<LocalExecutable>> executables_;
// Per-executable set of parameters that have any aliased buffers and thus

View File

@ -94,10 +94,11 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
return Status::OK();
}
class PjRtTpuClient : public PjRtClient {
class PjRtTpuClient : public PjRtStreamExecutorClient {
public:
PjRtTpuClient(LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id);
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id);
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
@ -108,14 +109,14 @@ class PjRtTpuClient : public PjRtClient {
const PjRtExecutable& executable) const override;
};
PjRtTpuClient::PjRtTpuClient(LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices,
int host_id)
: PjRtClient(kTpuName, client, std::move(devices), host_id,
/*allocator=*/nullptr,
/*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr) {}
PjRtTpuClient::PjRtTpuClient(
LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
: PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
/*allocator=*/nullptr,
/*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr) {}
StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
@ -128,7 +129,8 @@ StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
num_partitions);
}
// Fallback to default global device assignment if we can't run locally.
return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
num_partitions);
}
StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
@ -152,10 +154,10 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
return absl::optional<std::string>(tpu_executable->fingerprint());
}
StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
LocalClient* client,
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
tf_tpu::TpuTopologyExternal topology =
tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();

View File

@ -26,14 +26,14 @@ limitations under the License.
namespace xla {
class PjRtTpuDevice : public PjRtDevice {
class PjRtTpuDevice : public PjRtStreamExecutorDevice {
public:
PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
std::unique_ptr<LocalDeviceState> local_device_state,
int host_id, const std::array<int, 3>& coords,
std::string device_kind)
: PjRtDevice(core.Id(), std::move(local_device_state),
std::move(device_kind), host_id),
: PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
std::move(device_kind), host_id),
core_(core),
coords_(coords) {}

View File

@ -97,13 +97,13 @@ cc_library(
name = "types",
srcs = ["types.cc"],
hdrs = ["types.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":bfloat16",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@ -113,6 +113,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/core:lib",
"//tensorflow/python:bfloat16_lib",
"//third_party/py/numpy:headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
@ -158,42 +159,6 @@ cc_library(
],
)
cc_library(
name = "bfloat16",
srcs = ["bfloat16.cc"],
hdrs = ["bfloat16.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core/platform:bfloat16",
"//tensorflow/core/platform:logging",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@com_google_absl//absl/strings",
"@pybind11",
],
)
py_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.py"],
main = "bfloat16_test.py",
python_version = "PY3",
tags = ["no_oss"],
deps = [
":xla_client",
":xla_extension",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
] + xla_py_test_deps(),
)
cc_library(
name = "py_client",
srcs = [
@ -206,6 +171,7 @@ cc_library(
"py_client.h",
"py_executable.h",
],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -232,6 +198,7 @@ cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
hdrs = ["dlpack.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -263,6 +230,7 @@ cc_library(
name = "jax_jit",
srcs = ["jax_jit.cc"],
hdrs = ["jax_jit.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -292,6 +260,7 @@ cc_library(
name = "ops",
srcs = ["ops.cc"],
hdrs = ["ops.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -356,6 +325,7 @@ cc_library(
name = "outfeed_receiver_py",
srcs = ["outfeed_receiver_py.cc"],
hdrs = ["outfeed_receiver_py.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -379,12 +349,14 @@ cc_library(
name = "pytree",
srcs = ["pytree.cc"],
hdrs = ["pytree.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":types",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
@ -435,6 +407,7 @@ cc_library(
name = "xla_compiler",
srcs = ["xla_compiler.cc"],
hdrs = ["xla_compiler.h"],
compatible_with = [],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -481,7 +454,6 @@ pybind_extension(
features = ["-use_header_modules"],
module_name = "xla_extension",
deps = [
":bfloat16",
":dlpack",
":jax_jit",
":ops",
@ -534,6 +506,7 @@ pybind_extension(
# without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does
# not require Tensorflow.
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep
"//tensorflow/python:bfloat16_lib",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor:platform",
] + select({

File diff suppressed because it is too large Load Diff

View File

@ -1,440 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for the bfloat16 Python type."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import itertools
import math
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.xla.python import xla_client
bfloat16 = xla_client.bfloat16
def numpy_assert_allclose(a, b, **kwargs):
a = a.astype(np.float32) if a.dtype == bfloat16 else a
b = b.astype(np.float32) if b.dtype == bfloat16 else b
return np.testing.assert_allclose(a, b, **kwargs)
epsilon = float.fromhex("1.0p-7")
# Values that should round trip exactly to float and back.
FLOAT_VALUES = [
0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
float("inf"),
float("-inf"),
float("nan")
]
class Bfloat16Test(parameterized.TestCase):
"""Tests the non-numpy Python methods of the bfloat16 type."""
def testRoundTripToFloat(self):
for v in FLOAT_VALUES:
np.testing.assert_equal(v, float(bfloat16(v)))
def testRoundTripNumpyTypes(self):
for dtype in [np.float16, np.float32, np.float64]:
np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
np.testing.assert_equal(
np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
def testRoundTripToInt(self):
for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
self.assertEqual(v, int(bfloat16(v)))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + dtype.__name__,
"dtype": dtype
} for dtype in [bfloat16, np.float16, np.float32, np.float64]))
def testRoundTripToNumpy(self, dtype):
for v in FLOAT_VALUES:
np.testing.assert_equal(v, bfloat16(dtype(v)))
np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
if dtype != bfloat16:
np.testing.assert_equal(
np.array(FLOAT_VALUES, dtype),
bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
def testStr(self):
self.assertEqual("0", str(bfloat16(0.0)))
self.assertEqual("1", str(bfloat16(1.0)))
self.assertEqual("-3.5", str(bfloat16(-3.5)))
self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("inf", str(bfloat16(float("inf"))))
self.assertEqual("-inf", str(bfloat16(float("-inf"))))
self.assertEqual("nan", str(bfloat16(float("nan"))))
def testRepr(self):
self.assertEqual("0", repr(bfloat16(0)))
self.assertEqual("1", repr(bfloat16(1)))
self.assertEqual("-3.5", repr(bfloat16(-3.5)))
self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
self.assertEqual("inf", repr(bfloat16(float("inf"))))
self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
self.assertEqual("nan", repr(bfloat16(float("nan"))))
def testHash(self):
self.assertEqual(0, hash(bfloat16(0.0)))
self.assertEqual(0x3f80, hash(bfloat16(1.0)))
self.assertEqual(0x7fc0, hash(bfloat16(float("nan"))))
# Tests for Python operations
def testNegate(self):
for v in FLOAT_VALUES:
np.testing.assert_equal(-v, float(-bfloat16(v)))
def testAdd(self):
np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
# Test type promotion against Numpy scalar values.
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float16(2.25)))
self.assertEqual(np.float32, type(np.float16(3.5) + bfloat16(2.25)))
self.assertEqual(np.float32, type(bfloat16(3.5) + np.float32(2.25)))
self.assertEqual(np.float32, type(np.float32(3.5) + bfloat16(2.25)))
self.assertEqual(np.float64, type(bfloat16(3.5) + np.float64(2.25)))
self.assertEqual(np.float64, type(np.float64(3.5) + bfloat16(2.25)))
self.assertEqual(np.float64, type(bfloat16(3.5) + float(2.25)))
self.assertEqual(np.float64, type(float(3.5) + bfloat16(2.25)))
self.assertEqual(np.float32,
type(bfloat16(3.5) + np.array(2.25, np.float32)))
self.assertEqual(np.float32,
type(np.array(3.5, np.float32) + bfloat16(2.25)))
def testSub(self):
np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
np.testing.assert_equal(
float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
def testMul(self):
np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
def testDiv(self):
self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
np.testing.assert_equal(
float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
np.testing.assert_equal(
float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
def testLess(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
def testLessEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
def testGreater(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
def testGreaterEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
def testEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
def testNotEqual(self):
for v in FLOAT_VALUES:
for w in FLOAT_VALUES:
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
def testNan(self):
a = np.isnan(bfloat16(float("nan")))
self.assertTrue(a)
numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
a = np.array([bfloat16(1.34375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=bfloat16)
b = np.array(
[bfloat16(1.3359375),
bfloat16(1.4375),
bfloat16(float("nan"))],
dtype=bfloat16)
numpy_assert_allclose(
a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
def testSort(self):
values_to_sort = np.float32(FLOAT_VALUES)
sorted_f32 = np.sort(values_to_sort)
sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))
np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
BinaryOp = collections.namedtuple("BinaryOp", ["op"])
UNARY_UFUNCS = [
np.negative, np.positive, np.absolute, np.fabs, np.rint, np.sign,
np.conjugate, np.exp, np.exp2, np.expm1, np.log, np.log10, np.log1p,
np.log2, np.sqrt, np.square, np.cbrt, np.reciprocal, np.sin, np.cos, np.tan,
np.arcsin, np.arccos, np.arctan, np.sinh, np.cosh, np.tanh, np.arcsinh,
np.arccosh, np.arctanh, np.deg2rad, np.rad2deg, np.floor, np.ceil, np.trunc
]
BINARY_UFUNCS = [
np.add, np.subtract, np.multiply, np.divide, np.logaddexp, np.logaddexp2,
np.floor_divide, np.power, np.remainder, np.fmod, np.heaviside, np.arctan2,
np.hypot, np.maximum, np.minimum, np.fmax, np.fmin, np.copysign
]
BINARY_PREDICATE_UFUNCS = [
np.equal, np.not_equal, np.less, np.greater, np.less_equal,
np.greater_equal, np.logical_and, np.logical_or, np.logical_xor
]
class Bfloat16NumPyTest(parameterized.TestCase):
"""Tests the NumPy integration of the bfloat16 type."""
def testDtype(self):
self.assertEqual(bfloat16, np.dtype(bfloat16))
def testDeepCopyDoesNotAlterHash(self):
# For context, see https://github.com/google/jax/issues/4651. If the hash
# value of the type descriptor is not initialized correctly, a deep copy
# can change the type hash.
dtype = np.dtype(bfloat16)
h = hash(dtype)
_ = copy.deepcopy(dtype)
self.assertEqual(h, hash(dtype))
def testArray(self):
x = np.array([[1, 2, 3]], dtype=bfloat16)
self.assertEqual(bfloat16, x.dtype)
self.assertEqual("[[1 2 3]]", str(x))
np.testing.assert_equal(x, x)
numpy_assert_allclose(x, x)
self.assertTrue((x == x).all())
def testComparisons(self):
x = np.array([401408, 7, -32], dtype=np.float32)
bx = x.astype(bfloat16)
y = np.array([82432, 7, 0], dtype=np.float32)
by = y.astype(bfloat16)
np.testing.assert_equal(x == y, bx == by)
np.testing.assert_equal(x != y, bx != by)
np.testing.assert_equal(x < y, bx < by)
np.testing.assert_equal(x > y, bx > by)
np.testing.assert_equal(x <= y, bx <= by)
np.testing.assert_equal(x >= y, bx >= by)
def testEqual2(self):
a = np.array([401408], bfloat16)
b = np.array([82432], bfloat16)
self.assertFalse(a.__eq__(b))
def testCasts(self):
for dtype in [
np.float16, np.float32, np.float64, np.int8, np.int16, np.int32,
np.int64, np.complex64, np.complex128, np.uint8, np.uint16, np.uint32,
np.uint64, np.intc, np.int_, np.longlong, np.uintc, np.ulonglong
]:
x = np.array([[1, 2, 3]], dtype=dtype)
y = x.astype(bfloat16)
z = y.astype(dtype)
self.assertTrue(np.all(x == y))
self.assertEqual(bfloat16, y.dtype)
self.assertTrue(np.all(x == z))
self.assertEqual(dtype, z.dtype)
def testConformNumpyComplex(self):
for dtype in [np.complex64, np.complex128]:
x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
y_np = x.astype(np.float32)
y_tf = x.astype(bfloat16)
numpy_assert_allclose(y_np, y_tf, atol=2e-2)
z_np = y_np.astype(dtype)
z_tf = y_tf.astype(dtype)
numpy_assert_allclose(z_np, z_tf, atol=2e-2)
def testArange(self):
np.testing.assert_equal(
np.arange(100, dtype=np.float32).astype(bfloat16),
np.arange(100, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
np.arange(-0., -7., -0.25, dtype=bfloat16))
np.testing.assert_equal(
np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
np.arange(-16384., 16384., 64., dtype=bfloat16))
# pylint: disable=g-complex-comprehension
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in UNARY_UFUNCS))
def testUnaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_UFUNCS))
def testBinaryUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7, 10).astype(bfloat16)
y = rng.randn(4, 1, 7, 10).astype(bfloat16)
numpy_assert_allclose(
op(x, y).astype(np.float32),
op(x.astype(np.float32), y.astype(np.float32)),
rtol=1e-2)
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in BINARY_PREDICATE_UFUNCS))
def testBinaryPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
np.testing.assert_equal(
op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
@parameterized.named_parameters(({
"testcase_name": "_" + op.__name__,
"op": op
} for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
def testPredicateUfunc(self, op):
rng = np.random.RandomState(seed=42)
shape = (3, 7, 10)
posinf_flips = rng.rand(*shape) < 0.1
neginf_flips = rng.rand(*shape) < 0.1
nan_flips = rng.rand(*shape) < 0.1
vals = rng.randn(*shape)
vals = np.where(posinf_flips, np.inf, vals)
vals = np.where(neginf_flips, -np.inf, vals)
vals = np.where(nan_flips, np.nan, vals)
vals = vals.astype(bfloat16)
np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
def testDivmod(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randn(4, 1, 7).astype(bfloat16)
o1, o2 = np.divmod(x, y)
e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
numpy_assert_allclose(o1, e1, rtol=1e-2)
numpy_assert_allclose(o2, e2, rtol=1e-2)
def testModf(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
o1, o2 = np.modf(x)
e1, e2 = np.modf(x.astype(np.float32))
numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
def testLdexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
y = rng.randint(-50, 50, (1, 7))
numpy_assert_allclose(
np.ldexp(x, y).astype(np.float32),
np.ldexp(x.astype(np.float32), y),
rtol=1e-2,
atol=1e-6)
def testFrexp(self):
rng = np.random.RandomState(seed=42)
x = rng.randn(3, 7).astype(bfloat16)
mant1, exp1 = np.frexp(x)
mant2, exp2 = np.frexp(x.astype(np.float32))
np.testing.assert_equal(exp1, exp2)
numpy_assert_allclose(mant1, mant2, rtol=1e-2)
def testNextAfter(self):
one = np.array(1., dtype=bfloat16)
two = np.array(2., dtype=bfloat16)
zero = np.array(0., dtype=bfloat16)
nan = np.array(np.nan, dtype=bfloat16)
np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
np.testing.assert_equal(np.nextafter(one, one), one)
smallest_denormal = float.fromhex("1.0p-133")
np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
for a, b in itertools.permutations([0., -0., nan], 2):
np.testing.assert_equal(
np.nextafter(
np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
np.nextafter(
np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
if __name__ == "__main__":
absltest.main()

View File

@ -214,11 +214,9 @@ StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
}
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
const se::Platform* platform =
device.local_device_state()->executor()->platform();
if (platform->id() == se::host::kHostPlatformId) {
if (device.client()->platform_id() == kCpuId) {
return kDLCPU;
} else if (platform->id() == se::cuda::kCudaPlatformId) {
} else if (device.client()->platform_id() == kGpuId) {
return kDLGPU;
}
return InvalidArgument("Device %s cannot be used as a DLPack device.",
@ -228,7 +226,7 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
DLContext context;
TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
context.device_id = device.local_device_id();
context.device_id = device.local_hardware_id();
return context;
}
@ -241,14 +239,14 @@ StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
"DLPack CPU device type mismatch with PjRtClient platform %s",
client.platform_name());
}
return client.LookupLocalDevice(context.device_id);
return client.LookupAddressableDevice(context.device_id);
case kDLGPU:
if (client.platform_id() != kGpuId) {
return InvalidArgument(
"DLPack GPU device type mismatch with PjRtClient platform %s",
client.platform_name());
}
return client.LookupLocalDevice(context.device_id);
return client.LookupAddressableDevice(context.device_id);
default:
return InvalidArgument("Unknown/unsupported DLPack device type %d",
context.device_type);
@ -297,7 +295,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
pack->tensor.manager_ctx = pack.get();
pack->tensor.deleter = DLPackTensorDeleter;
TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
dt.ctx.device_id = buffer->buffer()->device()->local_device_id();
dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
TF_ASSIGN_OR_RETURN(dt.dtype,
PrimitiveTypeToDLDataType(

View File

@ -230,8 +230,8 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
callback_ = callback;
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
for (const auto& client : clients) {
for (const auto& device : client->devices()) {
devices_.push_back(device.get());
for (auto device : client->devices()) {
devices_.push_back(device);
}
}
CHECK_GT(devices_.size(), 0);
@ -342,11 +342,7 @@ StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
const PjRtDevice* device, const Shape& shape) {
std::shared_ptr<Literal> literal_shared;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
TF_ASSIGN_OR_RETURN(Literal literal,
local_device->client()->TransferFromOutfeedLocal(
shape, local_device->device_ordinal()));
TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));
return absl::make_unique<Literal>(std::move(literal));
}

View File

@ -86,8 +86,8 @@ StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
}
StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
if (buffer_->device()->local_device_state()->executor()->platform_kind() !=
se::PlatformKind::kCuda) {
// TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
if (buffer_->client()->platform_id() != kGpuId) {
return InvalidArgument(
"__cuda_array_interface__ is only defined for NVidia GPU buffers.");
}

View File

@ -37,9 +37,10 @@ PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
std::vector<ClientAndPtr<PjRtDevice>> devices;
devices.reserve(pjrt_client_->devices().size());
for (const auto& device : pjrt_client_->devices()) {
devices.push_back(WrapWithClient(shared_from_this(), device.get()));
auto span = pjrt_client_->devices();
devices.reserve(span.size());
for (PjRtDevice* device : span) {
devices.push_back(WrapWithClient(shared_from_this(), device));
}
return devices;
}
@ -64,9 +65,9 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
result[r].resize(num_partitions);
for (int p = 0; p < num_partitions; ++p) {
int device_id = device_assignment(r, p);
auto iter = pjrt_client_->id_to_device().find(device_id);
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
result[r][p] = WrapWithClient(shared_from_this(), iter->second);
TF_ASSIGN_OR_RETURN(PjRtDevice * device,
pjrt_client_->LookupDevice(device_id));
result[r][p] = WrapWithClient(shared_from_this(), device);
}
}
return result;
@ -80,9 +81,9 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
std::vector<ClientAndPtr<PjRtDevice>> result;
for (int i = 0; i < num_replicas; ++i) {
int device_id = device_assignment(i, 0);
auto iter = pjrt_client_->id_to_device().find(device_id);
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
result.push_back(WrapWithClient(shared_from_this(), iter->second));
TF_ASSIGN_OR_RETURN(PjRtDevice * device,
pjrt_client_->LookupDevice(device_id));
result.push_back(WrapWithClient(shared_from_this(), device));
}
return result;
}
@ -95,8 +96,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
device = pjrt_client_->local_devices().front();
}
CHECK(device != nullptr);
auto iter = pjrt_client_->id_to_device().find(device->id());
if (iter->second != device) {
TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
pjrt_client_->LookupDevice(device->id()));
if (found_device != device) {
return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
device->DebugString(),
pjrt_client_->platform_name());

View File

@ -97,7 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
const std::string& platform_name() const {
return pjrt_client_->platform_name();
}
int local_device_count() const { return pjrt_client_->local_device_count(); }
int addressable_device_count() const {
return pjrt_client_->addressable_device_count();
}
int device_count() const { return pjrt_client_->device_count(); }
int host_id() const { return pjrt_client_->host_id(); }

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/xla/python/types.h"
namespace xla {
@ -106,59 +107,66 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
}
}
void PyTreeDef::FlattenInto(py::handle handle,
std::vector<py::object>& leaves) {
void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
absl::optional<py::function> leaf_predicate) {
Node node;
int start_num_nodes = traversal_.size();
int start_num_leaves = leaves.size();
node.kind = GetKind(handle, &node.custom);
if (node.kind == Kind::kNone) {
// Nothing to do.
} else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
for (py::handle entry : tuple) {
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle);
node.arity = list.size();
for (py::handle entry : list) {
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
if (PyList_Sort(keys.ptr())) {
throw std::runtime_error("Dictionary key sort failed.");
}
for (py::handle key : keys) {
FlattenInto(dict[key], leaves);
}
node.arity = dict.size();
node.node_data = std::move(keys);
} else if (node.kind == Kind::kCustom) {
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
if (out.size() != 2) {
throw std::runtime_error(
"PyTree custom to_iterable function should return a pair");
}
node.node_data = out[1];
node.arity = 0;
for (py::handle entry : py::cast<py::iterable>(out[0])) {
++node.arity;
FlattenInto(entry, leaves);
}
} else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) {
FlattenInto(entry, leaves);
}
if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
} else {
assert(node.kind == Kind::kLeaf);
leaves.push_back(pybind11::reinterpret_borrow<py::object>(handle));
node.kind = GetKind(handle, &node.custom);
auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
FlattenInto(child, leaves, leaf_predicate);
};
if (node.kind == Kind::kNone) {
// Nothing to do.
} else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
for (py::handle entry : tuple) {
recurse(entry);
}
} else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle);
node.arity = list.size();
for (py::handle entry : list) {
recurse(entry);
}
} else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
if (PyList_Sort(keys.ptr())) {
throw std::runtime_error("Dictionary key sort failed.");
}
for (py::handle key : keys) {
recurse(dict[key]);
}
node.arity = dict.size();
node.node_data = std::move(keys);
} else if (node.kind == Kind::kCustom) {
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
if (out.size() != 2) {
throw std::runtime_error(
"PyTree custom to_iterable function should return a pair");
}
node.node_data = out[1];
node.arity = 0;
for (py::handle entry : py::cast<py::iterable>(out[0])) {
++node.arity;
recurse(entry);
}
} else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) {
recurse(entry);
}
} else {
assert(node.kind == Kind::kLeaf);
leaves.push_back(py::reinterpret_borrow<py::object>(handle));
}
}
node.num_nodes = traversal_.size() - start_num_nodes + 1;
node.num_leaves = leaves.size() - start_num_leaves;
@ -166,10 +174,10 @@ void PyTreeDef::FlattenInto(py::handle handle,
}
/*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
PyTreeDef::Flatten(py::handle x) {
PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
std::vector<py::object> leaves;
auto tree = absl::make_unique<PyTreeDef>();
tree->FlattenInto(x, leaves);
tree->FlattenInto(x, leaves, leaf_predicate);
return std::make_pair(std::move(leaves), std::move(tree));
}
@ -618,7 +626,8 @@ std::string PyTreeDef::ToString() const {
void BuildPytreeSubmodule(py::module& m) {
py::module pytree = m.def_submodule("pytree", "Python tree library");
pytree.def("flatten", &PyTreeDef::Flatten);
pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
py::arg("leaf_predicate") = absl::nullopt);
pytree.def("tuple", &PyTreeDef::Tuple);
pytree.def("all_leaves", &PyTreeDef::AllLeaves);

View File

@ -85,11 +85,13 @@ class PyTreeDef {
// Flattens a Pytree into a list of leaves and a PyTreeDef.
static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
Flatten(pybind11::handle x);
Flatten(pybind11::handle x,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
// Recursive helper used to implement Flatten().
void FlattenInto(pybind11::handle handle,
std::vector<pybind11::object>& leaves);
void FlattenInto(
pybind11::handle handle, std::vector<pybind11::object>& leaves,
absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
// Tests whether the given list is a flat list of leaves.
static bool AllLeaves(const pybind11::iterable& x);

View File

@ -37,6 +37,7 @@ cc_library(
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core/framework:allocator",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:env",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/memory",

View File

@ -37,8 +37,8 @@ namespace xla {
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip)
: xla::PjRtDevice(id, /*local_device_state=*/nullptr,
/*device_kind=*/"Cloud TPU", host_id),
: xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr,
/*device_kind=*/"Cloud TPU", host_id),
coords_(coords),
core_on_chip_(core_on_chip) {}
@ -531,7 +531,7 @@ PyTpuExecutable::PyTpuExecutable(
<< "Inserting duplicate replica:" << replica;
executables_[replica] =
client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
addressable_device_logical_ids_.emplace_back(replica, partition);
local_logical_device_ids_.emplace_back(replica, partition);
local_devices_.push_back(device);
}
}
@ -711,8 +711,8 @@ PyTpuExecutable::ExecuteOnLocalDevices(
// long time and we want all cores to be scheduled in parallel.
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
&execute_semaphore]() {
const int replica = addressable_device_logical_ids_[i].first;
const int partition = addressable_device_logical_ids_[i].second;
const int replica = local_logical_device_ids_[i].first;
const int partition = local_logical_device_ids_[i].second;
RunId run_id;
auto result = ExecuteHelper(argument_handles, argument_handles[i],
replica, partition, run_id);

View File

@ -32,13 +32,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/threadpool.h"
namespace xla {
constexpr char kTpuPlatform[] = "tpu";
class TpuDevice : public PjRtDevice {
class TpuDevice : public PjRtStreamExecutorDevice {
public:
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip);
@ -298,9 +299,8 @@ class PyTpuExecutable {
return device_assignment_;
}
const std::vector<std::pair<int, int>>& addressable_device_logical_ids()
const {
return addressable_device_logical_ids_;
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
return local_logical_device_ids_;
}
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
@ -341,14 +341,16 @@ class PyTpuExecutable {
// The replica and partition indices of device_assignment_ to be run by this
// client. On single-host platforms without partitioning, this is all replicas
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
std::vector<std::pair<int, int>> addressable_device_logical_ids_;
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
// on multi-host platforms.
// If there are 4 replicas and 2 partitions on a single host platform, size of
// local_logical_device_ids_ is 4*2 = 8.
std::vector<std::pair<int, int>> local_logical_device_ids_;
// local_devices_[i] is the Device to which addressable_device_logical_ids_[i]
// is assigned. shared_ptrs instead of unique_ptrs to play well with the
// Python bindings (see xla.cc).
// local_devices_[i] is the Device to which local_logical_device_ids_[i] is
// assigned.
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
// (see xla.cc).
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
xla::Shape result_shape_;

View File

@ -186,7 +186,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
py::class_<PyTpuExecutable>(m, "TpuExecutable")
.def("local_logical_device_ids",
&PyTpuExecutable::addressable_device_logical_ids)
&PyTpuExecutable::local_logical_device_ids)
.def("local_devices", &PyTpuExecutable::local_devices)
.def_property_readonly("client", &PyTpuExecutable::client)
.def("size_of_generated_code_in_bytes",

View File

@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/types.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/python/lib/core/bfloat16.h"
namespace xla {
@ -81,8 +81,8 @@ xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
case U64:
return py::dtype::of<uint64>();
case BF16: {
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
return py::dtype::from_args(bfloat16);
py::handle bfloat16(tensorflow::Bfloat16Dtype());
return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
}
case F16:
return py::dtype("e"); // PEP 3118 code for "float16
@ -237,10 +237,11 @@ StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
// We requested an array of uint16 since NumPy doesn't know how
// to produce our custom bfloat16 type. Reinterpret the array as bfloat16
// before handing it back to the caller.
TF_ASSIGN_OR_RETURN(py::object bfloat16, Bfloat16Dtype());
py::handle bfloat16(tensorflow::Bfloat16Dtype());
bfloat16.inc_ref();
array = py::reinterpret_steal<py::array>(
PyArray_View(reinterpret_cast<PyArrayObject*>(array.ptr()),
reinterpret_cast<PyArray_Descr*>(bfloat16.release().ptr()),
reinterpret_cast<PyArray_Descr*>(bfloat16.ptr()),
static_cast<PyTypeObject*>(nullptr)));
}
return array;

View File

@ -40,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/python/dlpack.h"
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include "tensorflow/compiler/xla/python/ops.h"
@ -59,6 +58,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
@ -110,6 +110,8 @@ PYBIND11_MODULE(xla_extension, m) {
throw std::runtime_error("Unable to initialize Numpy API");
}
CHECK(tensorflow::RegisterNumpyBfloat16());
// Types
py::enum_<PrimitiveType>(m, "PrimitiveType")
.value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
@ -132,7 +134,8 @@ PYBIND11_MODULE(xla_extension, m) {
.value("OPAQUE_TYPE", OPAQUE_TYPE)
.value("TOKEN", TOKEN);
m.def("bfloat16_dtype", Bfloat16Dtype);
m.def("bfloat16_dtype",
[]() { return py::handle(tensorflow::Bfloat16Dtype()); });
// Must be before PyClient.compile.
BuildXlaCompilerSubmodule(m);
@ -149,7 +152,10 @@ PYBIND11_MODULE(xla_extension, m) {
.def_property_readonly("host_id", &PjRtDevice::host_id,
"Integer ID of this device's host.\n\n"
"This is always 0 except on multi-host platforms.")
.def_property_readonly("platform", &PjRtDevice::platform_name)
.def_property_readonly("platform",
[](const PjRtDevice& device) {
return device.client()->platform_name();
})
.def_property_readonly("device_kind", &PjRtDevice::device_kind)
.def_property_readonly(
"client",
@ -234,7 +240,7 @@ PYBIND11_MODULE(xla_extension, m) {
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
.def("device_count", &PyClient::device_count)
.def("local_device_count", &PyClient::local_device_count)
.def("local_device_count", &PyClient::addressable_device_count)
.def("devices", &PyClient::Devices)
.def("local_devices", &PyClient::LocalDevices)
.def("host_id", &PyClient::host_id)
@ -381,10 +387,10 @@ PYBIND11_MODULE(xla_extension, m) {
[](PyExecutable* exec) {
auto span = exec->addressable_device_logical_ids();
// Not on dispatch critical path, so ok to have heap allocation.
std::vector<std::pair<int, int>> addressable_device_logical_ids;
addressable_device_logical_ids.reserve(span.size());
std::vector<std::pair<int, int>> addressable_device_logic_ids;
addressable_device_logic_ids.reserve(span.size());
for (const auto& logical_device_id : span) {
addressable_device_logical_ids.push_back(std::make_pair(
addressable_device_logic_ids.push_back(std::make_pair(
logical_device_id.replica, logical_device_id.partition));
}
})

View File

@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"
namespace xla {
@ -158,6 +159,18 @@ class AotCompilationMetadata {
// platform.
class Compiler {
public:
struct CompileOptions {
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the
// compiler may allocate buffers on the device and then run variants of a
// given algorithm over those buffers, to see which variant is fastest. Any
// space allocated will be deallocated before the compilation returns.
se::DeviceMemoryAllocator* device_allocator = nullptr;
// An optional thread pool for parallel compilation.
tensorflow::thread::ThreadPool* thread_pool = nullptr;
};
virtual ~Compiler() {}
// Returns the ID of the platform that this compiler targets.
@ -165,31 +178,24 @@ class Compiler {
// Runs Hlo passes to optimize the given Hlo module, returns the optimized
// module.
//
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the compiler
// may allocate buffers on the device and then run variants of a given
// algorithm over those buffers, to see which variant is fastest. Any space
// allocated should be deallocated before this function returns.
virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) = 0;
const CompileOptions& options) = 0;
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) {
return RunHloPasses(std::move(module), executor,
CompileOptions{device_allocator});
}
// Runs HLO passes to optimize the given HloModule, perform scheduling and
// buffer assignment, returns the optimized module and the buffer assignments.
// This interface is intentionally narrow.
//
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the compiler
// may allocate buffers on the device and then run variants of a given
// algorithm over those buffers, to see which variant is fastest. Any space
// allocated should be deallocated before this function returns.
virtual StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator,
bool optimize) {
se::StreamExecutor* executor, bool optimize,
const CompileOptions& options) {
return Unimplemented("This compiler does not support this method");
}
@ -201,24 +207,33 @@ class Compiler {
//
// The compiler may optionally specialize to the individual device
// (not just type of device) indicated by the executor.
//
// device_allocator is optional; see RunHloPasses.
virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) = 0;
const CompileOptions& options) = 0;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator) {
return RunBackend(std::move(module), executor,
CompileOptions{device_allocator});
}
// Compiles a set of HLO modules that can run in parallel, potentially
// communicating data between the modules, and returns a corresponding
// sequence of executable objects.
//
// device_allocator is optional; see RunHloPasses.
//
// TODO(b/68666782): Remove this method after adding support for multiple
// modules to RunHloPasses and RunBackends.
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) = 0;
const CompileOptions& options) = 0;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
return Compile(std::move(module_group), stream_exec,
CompileOptions{device_allocator});
}
// Returns the backend configurations that the backend will consider for the
// given HLO. Returns no configurations if the backend does not support

View File

@ -553,7 +553,7 @@ Status CreateHloProfilingArtifacts(
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
se::DeviceMemoryAllocator* /*device_allocator*/) {
const CompileOptions& /*options*/) {
std::unique_ptr<llvm::TargetMachine> jit_target_machine =
SimpleOrcJIT::InferTargetMachineForJIT(
CompilerTargetOptions(module->config()),
@ -566,12 +566,13 @@ StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
CpuCompiler::RunHloPassesAndBufferAssignement(
std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator, bool optimize) {
CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
se::StreamExecutor* executor,
bool optimize,
const CompileOptions& options) {
if (optimize) {
TF_ASSIGN_OR_RETURN(
module, RunHloPasses(std::move(module), executor, device_allocator));
TF_ASSIGN_OR_RETURN(module,
RunHloPasses(std::move(module), executor, options));
}
// Select an order for emitting the HLO instructions for each computation.
@ -632,7 +633,7 @@ struct OrcJITPostCompilationHook {
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* /*device_allocator*/) {
const CompileOptions& options) {
VLOG(1) << "Compiling: " << module->name();
XLA_SCOPED_LOGGING_TIMER(
absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));

View File

@ -134,18 +134,17 @@ class CpuCompiler : public LLVMCompiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator,
bool optimize) override;
se::StreamExecutor* executor, bool optimize,
const CompileOptions& options) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,

View File

@ -440,7 +440,7 @@ filegroup(
name = "nccl_collective_thunk_src",
srcs = if_nccl(
["nccl_collective_thunk.cc"],
["dummy_collective_thunk.cc"],
["nccl_collective_thunk_dummy.cc"],
),
)
@ -448,7 +448,7 @@ tf_cuda_library(
name = "nccl_collective_thunk",
srcs = if_cuda_or_rocm(
[":nccl_collective_thunk_src"],
["dummy_collective_thunk.cc"],
["nccl_collective_thunk_dummy.cc"],
),
hdrs = ["nccl_collective_thunk.h"],
deps = [
@ -480,7 +480,7 @@ filegroup(
name = "nccl_all_gather_thunk_src",
srcs = if_nccl(
["nccl_all_gather_thunk.cc"],
["dummy_all_gather_thunk.cc"],
["nccl_all_gather_thunk_dummy.cc"],
),
)
@ -488,7 +488,7 @@ tf_cuda_library(
name = "nccl_all_gather_thunk",
srcs = if_cuda_or_rocm(
[":nccl_all_gather_thunk_src"],
["dummy_all_gather_thunk.cc"],
["nccl_all_gather_thunk_dummy.cc"],
),
hdrs = ["nccl_all_gather_thunk.h"],
deps = [
@ -520,7 +520,7 @@ filegroup(
name = "nccl_all_reduce_thunk_src",
srcs = if_nccl(
["nccl_all_reduce_thunk.cc"],
["dummy_all_reduce_thunk.cc"],
["nccl_all_reduce_thunk_dummy.cc"],
),
)
@ -528,7 +528,7 @@ tf_cuda_library(
name = "nccl_all_reduce_thunk",
srcs = if_cuda_or_rocm(
[":nccl_all_reduce_thunk_src"],
["dummy_all_reduce_thunk.cc"],
["nccl_all_reduce_thunk_dummy.cc"],
),
hdrs = ["nccl_all_reduce_thunk.h"],
deps = [
@ -560,7 +560,7 @@ filegroup(
name = "nccl_all_to_all_thunk_src",
srcs = if_nccl(
["nccl_all_to_all_thunk.cc"],
["dummy_all_to_all_thunk.cc"],
["nccl_all_to_all_thunk_dummy.cc"],
),
)
@ -568,7 +568,7 @@ tf_cuda_library(
name = "nccl_all_to_all_thunk",
srcs = if_cuda_or_rocm(
[":nccl_all_to_all_thunk_src"],
["dummy_all_to_all_thunk.cc"],
["nccl_all_to_all_thunk_dummy.cc"],
),
hdrs = ["nccl_all_to_all_thunk.h"],
deps = [
@ -600,7 +600,7 @@ filegroup(
name = "nccl_test_utils_src",
srcs = if_nccl(
["nccl_test_utils.cc"],
["dummy_nccl_test_utils.cc"],
["nccl_test_utils_dummy.cc"],
),
)
@ -608,7 +608,7 @@ tf_cuda_library(
name = "nccl_test_utils",
srcs = if_cuda_or_rocm(
[":nccl_test_utils_src"],
["dummy_nccl_test_utils.cc"],
["nccl_test_utils_dummy.cc"],
),
hdrs = ["nccl_test_utils.h"],
deps = [
@ -1452,7 +1452,11 @@ cc_library(
"//tensorflow/stream_executor:stream_executor_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:BitReader",
"@llvm-project//llvm:BitWriter",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:TransformUtils",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
],
@ -1517,7 +1521,7 @@ cc_library(
"//tensorflow/stream_executor:stream_executor_headers",
"//tensorflow/stream_executor/cuda:cuda_diagnostics",
"//tensorflow/stream_executor/gpu:asm_compiler",
]),
]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"],
)
cc_library(

View File

@ -108,12 +108,17 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
llvm::Module* llvm_module,
GpuVersion gpu_version,
se::StreamExecutor* stream_exec) {
se::StreamExecutor* stream_exec,
bool relocatable) {
if (rocdl_dir_.empty()) {
// Compute rocdl_dir_ just once and cache it in this member.
rocdl_dir_ = GetROCDLDir(module->config());
}
if (relocatable) {
return Unimplemented("relocatable target binary is not implemented");
}
std::vector<uint8> hsaco;
{
XLA_SCOPED_LOGGING_TIMER(

View File

@ -41,7 +41,8 @@ class AMDGPUCompiler : public GpuCompiler {
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
GpuVersion gpu_version, se::StreamExecutor* stream_exec,
bool relocatable) override;
private:
// The parent directory of ROCm-Device-Libs IR libraries.

View File

@ -24,11 +24,15 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/SplitModule.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "tensorflow/compiler/xla/protobuf_util.h"
@ -114,11 +118,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/env_var.h"
@ -470,14 +476,14 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(
StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
const CompileOptions& options) {
// We dump the post-optimization HLO in RunBackend so no need to dump it here.
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
tensorflow::profiler::TraceMe activity(
[&] { return absl::StrCat("HLO Transforms:", module->name()); },
tensorflow::profiler::TraceMeLevel::kInfo);
TF_RETURN_IF_ERROR(
OptimizeHloModule(module.get(), stream_exec, device_allocator));
OptimizeHloModule(module.get(), stream_exec, options.device_allocator));
TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
@ -494,10 +500,10 @@ StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
GpuCompiler::RunHloPassesAndBufferAssignement(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator, bool optimize) {
bool optimize, const CompileOptions& options) {
if (optimize) {
TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module),
executor, device_allocator));
TF_ASSIGN_OR_RETURN(hlo_module,
RunHloPasses(std::move(hlo_module), executor, options));
}
std::unique_ptr<StreamAssignment> stream_assignment =
@ -641,24 +647,133 @@ static Status CompileModuleToLlvmIrImpl(
return Status::OK();
}
StatusOr<std::pair<std::string, std::vector<uint8>>>
GpuCompiler::CompileToTargetBinary(const HloModule& module,
std::unique_ptr<llvm::Module> llvm_module,
se::StreamExecutor* stream_exec,
const CompileOptions& options) {
using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
const auto compile_single_module =
[this, stream_exec, &module](
llvm::Module* llvm_module,
bool relocatable) -> StatusOr<BackendCompileResult> {
{
XLA_SCOPED_LOGGING_TIMER(
"GpuCompiler::RunBackend - Running LLVM verifier");
std::string err;
llvm::raw_string_ostream err_stream(err);
// verifyModule() returns true if the module is broken.
TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
<< "Invalid LLVM IR before optimizations:\n"
<< err_stream.str()
<< "\nThis probably indicates a bug in the HLO -> LLVM IR "
"lowering. "
"Rerun with --xla_dump_to to get the IR and looks for files "
"with "
"name containing: *"
<< FilenameFor(module, "", "") << "*";
}
GpuVersion gpu_version = GetGpuVersion(stream_exec);
return CompileTargetBinary(&module, llvm_module, gpu_version, stream_exec,
relocatable);
};
tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
if (!thread_pool) {
return compile_single_module(llvm_module.get(), /*relocatable=*/false);
}
// Test whether LinkModules is supported.
if (this->LinkModules(stream_exec, {}).status().code() ==
tensorflow::error::Code::UNIMPLEMENTED) {
return compile_single_module(llvm_module.get(), /*relocatable=*/false);
}
std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
int num_functions = 0;
for (llvm::Function& func : llvm_module->functions()) {
if (!func.isDeclaration() &&
func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
num_functions++;
}
}
llvm::SplitModule(
std::move(llvm_module),
std::max<unsigned>(
1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
[&](std::unique_ptr<llvm::Module> module) {
llvm_modules.push_back(std::move(module));
},
/*PreserveLocals=*/true);
std::vector<StatusOr<BackendCompileResult>> compile_results(
llvm_modules.size());
tensorflow::BlockingCounter counter(llvm_modules.size());
for (int i = 0; i < llvm_modules.size(); i++) {
thread_pool->Schedule([&compile_results, compile_single_module, i,
&llvm_modules, &counter] {
llvm::Module* original_module = llvm_modules[i].get();
llvm::LLVMContext context;
std::string buffer;
llvm::raw_string_ostream error(buffer);
llvm::DiagnosticPrinterRawOStream printer(error);
auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
void* Context) {
auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
diag_info.print(*printer);
};
context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
std::unique_ptr<llvm::Module> new_llvm_module;
{
std::string ir;
{
llvm::raw_string_ostream os(ir);
original_module->print(os, nullptr);
}
llvm::SMDiagnostic err;
new_llvm_module = llvm::parseAssemblyString(ir, err, context);
}
compile_results[i] =
compile_single_module(new_llvm_module.get(), /*relocatable=*/true);
counter.DecrementCount();
});
}
counter.Wait();
std::string ptx_snippets;
std::vector<std::vector<uint8>> submodule_compile_results;
for (auto& maybe_result : compile_results) {
TF_ASSIGN_OR_RETURN(auto result, maybe_result);
if (result.second.empty()) {
continue;
}
ptx_snippets += result.first;
ptx_snippets += "\n";
submodule_compile_results.push_back(result.second);
}
TF_ASSIGN_OR_RETURN(
std::vector<uint8> backend_result,
this->LinkModules(stream_exec, std::move(submodule_compile_results)));
return std::make_pair(ptx_snippets, backend_result);
}
StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
const CompileOptions& options) {
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend");
auto slow_compile_alarm = SlowCompilationAlarm();
TF_RET_CHECK(stream_exec != nullptr);
llvm::LLVMContext llvm_context;
std::string buffer;
llvm::raw_string_ostream error(buffer);
llvm::DiagnosticPrinterRawOStream printer(error);
auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
void* Context) {
auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
diag_info.print(*printer);
};
llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
GpuDeviceInfo gpu_device_info;
gpu_device_info.threads_per_block_limit =
@ -724,34 +839,16 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
std::string err;
llvm::raw_string_ostream err_stream(err);
// verifyModule() returns true if the module is broken.
TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
<< "Invalid LLVM IR before optimizations:\n"
<< err_stream.str()
<< "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
"Rerun with --xla_dump_to to get the IR and looks for files with "
"name containing: *"
<< FilenameFor(*module, "", "") << "*";
}
GpuVersion gpu_version = GetGpuVersion(stream_exec);
using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
CompileTargetBinary(module.get(), llvm_module.get(),
gpu_version, stream_exec));
CompileToTargetBinary(*module, std::move(llvm_module),
stream_exec, options));
if (DumpingEnabledForHloModule(*module)) {
DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
thunk_schedule->ToString());
}
GpuVersion gpu_version = GetGpuVersion(stream_exec);
auto* gpu_executable = new GpuExecutable(
backend_result.first, backend_result.second, gpu_version,
std::move(thunk_schedule), std::move(module),

View File

@ -53,14 +53,13 @@ class GpuCompiler : public LLVMCompiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<
std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> hlo_module,
se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator,
bool optimize) override;
se::StreamExecutor* executor, bool optimize,
const CompileOptions& options) override;
Status OptimizeHloModule(HloModule* hlo_module,
se::StreamExecutor* stream_exec,
@ -84,19 +83,23 @@ class GpuCompiler : public LLVMCompiler {
virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version,
se::StreamExecutor* stream_exec) = 0;
GpuVersion gpu_version, se::StreamExecutor* stream_exec,
bool relocatable) = 0;
Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
AotCompilationOptions const& options) override;
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileToTargetBinary(
const HloModule& module, std::unique_ptr<llvm::Module> llvm_module,
se::StreamExecutor* stream_exec, const CompileOptions& options);
se::Platform::Id PlatformId() const override { return platform_id_; }
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
@ -116,6 +119,12 @@ class GpuCompiler : public LLVMCompiler {
}
private:
virtual StatusOr<std::vector<uint8>> LinkModules(
se::StreamExecutor* stream_exec,
std::vector<std::vector<uint8>> modules) {
return Unimplemented("LinkModules is not implemented.");
}
se::Platform::Id platform_id_;
// The triple that represents our target.

View File

@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
namespace xla {
namespace gpu {
@ -299,7 +300,8 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
NVPTXCompiler::CompileTargetBinary(const HloModule* module,
llvm::Module* llvm_module,
GpuVersion gpu_version,
se::StreamExecutor* stream_exec) {
se::StreamExecutor* stream_exec,
bool relocatable) {
std::pair<int, int> compute_capability =
absl::get<std::pair<int, int>>(gpu_version);
@ -338,7 +340,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
stream_exec, ptx, compute_capability.first, compute_capability.second,
module->config());
module->config(), relocatable);
return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
std::move(cubin));
@ -346,7 +348,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
int cc_minor, const HloModuleConfig& hlo_module_config) {
int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable) {
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
tensorflow::profiler::TraceMe activity(
"PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
@ -361,7 +363,7 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
tensorflow::mutex_lock lock(mutex_);
std::tie(iter, inserted) = compilation_cache_.emplace(
std::piecewise_construct,
std::forward_as_tuple(ptx, cc_major, cc_minor),
std::forward_as_tuple(ptx, cc_major, cc_minor, relocatable),
std::forward_as_tuple());
cache_ptx = &iter->first.ptx;
cache_value = &iter->second;
@ -375,9 +377,13 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
if (inserted) {
CHECK(!cache_value->compilation_done);
if (!ptx.empty()) {
StatusOr<std::vector<uint8>> maybe_cubin =
se::CompileGpuAsm(stream_exec->device_ordinal(), cache_ptx->c_str(),
PtxOptsFromConfig(hlo_module_config));
auto ptxas_config = PtxOptsFromConfig(hlo_module_config);
if (relocatable) {
ptxas_config.extra_flags.push_back("-c");
}
StatusOr<std::vector<uint8>> maybe_cubin = se::CompileGpuAsm(
stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config);
if (maybe_cubin.ok()) {
cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
VLOG(2) << "Compiled PTX size:" << ptx.size()
@ -445,5 +451,17 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
return cache_value->cubin_data;
}
StatusOr<std::vector<uint8>> NVPTXCompiler::LinkModules(
se::StreamExecutor* stream_exec, std::vector<std::vector<uint8>> modules) {
std::vector<stream_executor::CubinOrPTXImage> images;
images.reserve(modules.size());
for (auto& module : modules) {
images.push_back({"", std::move(module)});
}
return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
stream_exec->implementation()->GpuContextHack()),
images);
}
} // namespace gpu
} // namespace xla

View File

@ -52,9 +52,14 @@ class NVPTXCompiler : public GpuCompiler {
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
GpuVersion gpu_version, se::StreamExecutor* stream_exec,
bool relocatable) override;
private:
StatusOr<std::vector<uint8>> LinkModules(
se::StreamExecutor* stream_exec,
std::vector<std::vector<uint8>> modules) override;
tensorflow::mutex mutex_;
// When compiling an HLO module, we need to find a path to the nvvm libdevice
@ -71,7 +76,7 @@ class NVPTXCompiler : public GpuCompiler {
// compiled cubin. If compilation was unsuccessful, returns an empty vector.
std::vector<uint8> CompileGpuAsmOrGetCachedResult(
se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
int cc_minor, const HloModuleConfig& hlo_module_config);
int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable);
// The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor}
// -> cubin so we don't recompile the same ptx twice. This is important for
@ -86,24 +91,32 @@ class NVPTXCompiler : public GpuCompiler {
// If compiling the ptx fails, we return an empty cubin, cross our fingers,
// and leave compilation up to the driver.
struct CompilationCacheKey {
CompilationCacheKey(std::string ptx, int cc_major, int cc_minor)
: ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {}
CompilationCacheKey(std::string ptx, int cc_major, int cc_minor,
bool relocatable)
: ptx(std::move(ptx)),
cc_major(cc_major),
cc_minor(cc_minor),
relocatable(relocatable) {}
string ptx;
int cc_major;
int cc_minor;
bool relocatable;
};
struct CompilationCacheHash {
size_t operator()(const CompilationCacheKey& key) const {
return tensorflow::Hash64Combine(
tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major),
key.cc_minor);
tensorflow::Hash64Combine(
tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx),
key.cc_major),
key.cc_minor),
key.relocatable);
}
};
struct CompilationCacheEq {
size_t operator()(const CompilationCacheKey& a,
const CompilationCacheKey& b) const {
return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor &&
a.ptx == b.ptx;
a.ptx == b.ptx && a.relocatable == b.relocatable;
}
};
struct CompilationCacheValue {

View File

@ -95,7 +95,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
se::DeviceMemoryAllocator* /*device_allocator*/) {
const CompileOptions& /*options*/) {
VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
return std::move(hlo_module);
@ -103,7 +103,7 @@ StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* /*device_allocator*/) {
const CompileOptions& /*options*/) {
TF_RET_CHECK(stream_exec != nullptr);
VLOG(1) << "Run backend " << hlo_module->name();
@ -128,7 +128,7 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
const CompileOptions& options) {
if (module_group->empty()) {
return std::vector<std::unique_ptr<Executable>>();
}
@ -141,12 +141,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
"Unexpected number of StreamExecutor's.");
}
auto hlo_modules = module_group->ConsumeModules();
TF_ASSIGN_OR_RETURN(auto module,
RunHloPasses(std::move(hlo_modules[0]), stream_exec[0][0],
device_allocator));
TF_ASSIGN_OR_RETURN(
auto executable,
RunBackend(std::move(module), stream_exec[0][0], device_allocator));
TF_ASSIGN_OR_RETURN(auto module, RunHloPasses(std::move(hlo_modules[0]),
stream_exec[0][0], options));
TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(module),
stream_exec[0][0], options));
std::vector<std::unique_ptr<Executable>> ret;
ret.push_back(std::move(executable));
return std::move(ret);

View File

@ -45,14 +45,14 @@ class InterpreterCompiler : public Compiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,

View File

@ -24,7 +24,7 @@ namespace xla {
StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) {
const CompileOptions& options) {
// Tensorflow tries to enable the following behaviors in all its threads:
//
// - Denormals are zero (DAZ): roughly, operations treat denormal floats as
@ -48,10 +48,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> LLVMCompiler::Compile(
TF_ASSIGN_OR_RETURN(modules[i],
RunHloPasses(std::move(modules[i]), stream_execs[i][0],
device_allocator));
options.device_allocator));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
RunBackend(std::move(modules[i]), stream_execs[i][0],
device_allocator));
options.device_allocator));
result.push_back(std::move(executable));
}

View File

@ -66,13 +66,14 @@ class LLVMCompiler : public Compiler {
// std::unique_ptr<HloModule> module,
// se::StreamExecutor* stream_exec,
// se::DeviceMemoryAllocator* device_allocator)
using Compiler::Compile;
using Compiler::RunBackend;
using Compiler::RunHloPasses;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
protected:
ModuleHook user_pre_optimization_hook_;

View File

@ -190,11 +190,12 @@ LocalService::CompileExecutables(
// single partition computations are built using `BuildExecutables`, fix it,
// and remove this special case (provided the performance if similar).
if (build_options.num_partitions() == 1) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
executor, build_options.device_allocator(),
build_options.run_backend_only()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
BuildExecutable(proto, std::move(module_config),
execute_backend_.get(), executor,
{build_options.device_allocator(),
build_options.compile_thread_pool()},
build_options.run_backend_only()));
std::vector<std::unique_ptr<Executable>> executables;
executables.push_back(std::move(executable));
return executables;
@ -206,10 +207,12 @@ LocalService::CompileExecutables(
std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
executor);
return BuildExecutables({&proto}, std::move(module_configs),
execute_backend_.get(), {executors},
build_options.device_allocator(),
build_options.run_backend_only());
return BuildExecutables(
/*module_protos=*/{&proto}, std::move(module_configs),
execute_backend_.get(), {executors},
Compiler::CompileOptions{build_options.device_allocator(),
build_options.compile_thread_pool()},
build_options.run_backend_only());
}
}

View File

@ -28,28 +28,24 @@ bool IsUnimplemented(StatusOr<T>& result) {
StatusOr<std::unique_ptr<HloModule>> FailoverCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
auto result =
primary_->RunHloPasses(module->Clone(), stream_exec, device_allocator);
const CompileOptions& options) {
auto result = primary_->RunHloPasses(module->Clone(), stream_exec, options);
if (IsUnimplemented(result)) {
VLOG(2) << "RunHloPasses resulted in " << result.status()
<< ", falling back to secondary backend";
return secondary_->RunHloPasses(std::move(module), stream_exec,
device_allocator);
return secondary_->RunHloPasses(std::move(module), stream_exec, options);
}
return result;
}
StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
auto result =
primary_->RunBackend(module->Clone(), stream_exec, device_allocator);
const CompileOptions& options) {
auto result = primary_->RunBackend(module->Clone(), stream_exec, options);
if (IsUnimplemented(result)) {
VLOG(2) << "RunBackend resulted in " << result.status()
<< ", falling back to secondary backend";
return secondary_->RunBackend(std::move(module), stream_exec,
device_allocator);
return secondary_->RunBackend(std::move(module), stream_exec, options);
}
return result;
}
@ -57,7 +53,7 @@ StatusOr<std::unique_ptr<Executable>> FailoverCompiler::RunBackend(
StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) {
const CompileOptions& options) {
std::vector<std::unique_ptr<Executable>> result;
std::vector<std::unique_ptr<HloModule>> modules =
module_group->ConsumeModules();
@ -67,17 +63,15 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
return Unimplemented(
"Model partitioning not implemented for the failover compiler!");
}
auto executable = [stream_execs, device_allocator, i,
auto executable = [stream_execs, &options, i,
this](std::unique_ptr<HloModule> module)
-> StatusOr<std::unique_ptr<Executable>> {
TF_ASSIGN_OR_RETURN(
auto processed_module,
primary_->RunHloPasses(std::move(module), stream_execs[i][0],
device_allocator));
TF_ASSIGN_OR_RETURN(
auto result,
primary_->RunBackend(std::move(processed_module), stream_execs[i][0],
device_allocator));
TF_ASSIGN_OR_RETURN(auto processed_module,
primary_->RunHloPasses(std::move(module),
stream_execs[i][0], options));
TF_ASSIGN_OR_RETURN(auto result,
primary_->RunBackend(std::move(processed_module),
stream_execs[i][0], options));
return result;
}(modules[i]->Clone());
@ -85,13 +79,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> FailoverCompiler::Compile(
VLOG(2) << "Compile resulted in " << executable.status()
<< ", falling back to secondary backend";
TF_ASSIGN_OR_RETURN(
modules[i],
secondary_->RunHloPasses(std::move(modules[i]), stream_execs[i][0],
device_allocator));
TF_ASSIGN_OR_RETURN(
executable,
secondary_->RunBackend(std::move(modules[i]), stream_execs[i][0],
device_allocator));
modules[i], secondary_->RunHloPasses(std::move(modules[i]),
stream_execs[i][0], options));
TF_ASSIGN_OR_RETURN(executable,
secondary_->RunBackend(std::move(modules[i]),
stream_execs[i][0], options));
}
if (!executable.ok()) {

View File

@ -51,16 +51,16 @@ class FailoverCompiler final : public Compiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,

View File

@ -87,16 +87,16 @@ class MlirCompilerImpl : public MlirCompiler {
public:
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) override;
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
@ -155,12 +155,12 @@ std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
StatusOr<std::unique_ptr<HloModule>> MlirCompilerImpl::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
const CompileOptions& options) {
// Until we find a reason to do something different, run the same passes
// that the normal GPU backend runs.
gpu::NVPTXCompiler xla_compiler;
TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec,
device_allocator));
options.device_allocator));
TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get()));
return std::move(module);
@ -454,7 +454,7 @@ StatusOr<std::unique_ptr<gpu::KernelThunk>> TransformKernelToXlaThunk(
StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
const CompileOptions& options) {
// Determine the HLO schedule, which is an ordering of HLO instructions. This
// is used by buffer assignment to enable buffer reuse, and the same ordering
// must also be used to determine the thunk launch schedule.
@ -595,7 +595,7 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_execs,
se::DeviceMemoryAllocator* device_allocator) {
const CompileOptions& options) {
return Unimplemented("Not yet implemented in MLIR compiler");
}

View File

@ -357,7 +357,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
const std::vector<const HloModuleProto*>& module_protos,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) {
const Compiler::CompileOptions& options, bool run_backend_only) {
VLOG(1) << StrFormat("BuildExecutable on service %p", this);
// Dump computation proto state if flag is set.
@ -387,17 +387,15 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
std::vector<std::unique_ptr<Executable>> executables;
if (!run_backend_only) {
TF_ASSIGN_OR_RETURN(
executables,
backend->compiler()->Compile(std::move(module_group),
std::move(executors), device_allocator));
TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
std::move(module_group),
std::move(executors), options));
} else {
auto modules = module_group->ConsumeModules();
for (std::unique_ptr<HloModule>& module : modules) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(std::move(module), executors[0][0],
device_allocator));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(
std::move(module), executors[0][0], options));
executables.push_back(std::move(executable));
}
}
@ -710,7 +708,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
BuildExecutables(module_protos, std::move(module_configs),
execute_backend_.get(), all_executors,
/*device_allocator=*/nullptr));
{/*device_allocator=*/nullptr}));
std::vector<Executable*> executable_ptrs;
executable_ptrs.reserve(executables.size());
for (const auto& executable : executables) {
@ -810,7 +808,7 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator,
se::StreamExecutor* executor, const Compiler::CompileOptions& options,
bool run_backend_only) {
VLOG(1) << StrFormat(
"BuildExecutable on service %p with serialized module proto: %s", this,
@ -822,14 +820,13 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
if (!run_backend_only) {
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
std::move(module), executor, options));
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(
std::move(module), executor, device_allocator));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(std::move(module), executor, options));
const auto& debug_opts = module_config->debug_options();
if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
@ -875,7 +872,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
BuildExecutable(arg->computation(), std::move(module_config),
execute_backend_.get(),
execute_backend_->default_stream_executor(),
/*device_allocator=*/nullptr));
{/*device_allocator=*/nullptr}));
*result->mutable_handle() = compilation_cache_.Insert(std::move(executable));

View File

@ -235,8 +235,7 @@ class Service : public ServiceInterface {
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor,
se::DeviceMemoryAllocator* device_allocator = nullptr,
se::StreamExecutor* executor, const Compiler::CompileOptions& options,
bool run_backend_only = false);
// Same as BuildExecutable() above, but builds a list of Executables for the
@ -245,8 +244,7 @@ class Service : public ServiceInterface {
const std::vector<const HloModuleProto*>& module_protos,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
se::DeviceMemoryAllocator* device_allocator,
bool run_backend_only = false);
const Compiler::CompileOptions& options, bool run_backend_only = false);
// Runs the given executable with the given arguments and register the result
// in the allocation tracker. The handle of the result from the tracker is

View File

@ -102,27 +102,8 @@ bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
}
}
// Returns a sharding where each tuple element is chosen as the more specific
// one of the corresponding elements in a and b. Requires a an b to have the
// same tuple nesting.
HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
const HloSharding& b) {
if (a.IsTuple()) {
HloSharding result = a;
CHECK(b.IsTuple());
CHECK_EQ(a.tuple_elements().size(), b.tuple_elements().size());
for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
result.tuple_elements()[i] = MergeForMoreSpecificSharding(
a.tuple_elements()[i], b.tuple_elements()[i]);
}
return result;
}
return IsShardingMoreSpecific(a, b) ? a : b;
}
// Tries to refine `to_merge` by combining with `old`. Returns if the final
// `to_merge` is more specific than `old`. May combine partial sharding in
// addition to MergeForMoreSpecificSharding().
// `to_merge` is more specific than `old`.
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
bool may_combine_partial_sharding) {
if (old.IsTuple()) {
@ -1093,8 +1074,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
}
auto sharding = instruction->operand(0)->sharding();
if (instruction->has_sharding()) {
sharding =
MergeForMoreSpecificSharding(sharding, instruction->sharding());
MergeSharding(instruction->sharding(), &sharding,
may_combine_partial_sharding);
}
return MaybeImproveInstructionSharding(std::move(sharding), instruction,
may_combine_partial_sharding);
@ -1320,6 +1301,12 @@ absl::optional<HloSharding> GetShardingFromUser(
return hlo_sharding_util::ReshapeSharding(
user.shape(), instruction.shape(), user.sharding());
}
case HloOpcode::kPad: {
if (&instruction != user.operand(0)) {
return absl::nullopt;
}
return user.sharding();
}
case HloOpcode::kSlice: {
return user.sharding();
}
@ -1673,8 +1660,10 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
// If instruction is a while, or the root or a parameter of a while body,
// then propagate its sharding to the while instruction, to its body root,
// and to its condition parameter.
std::function<void(HloInstruction*)> maybe_computation_propagation =
[&](HloInstruction* instruction) {
std::function<void(HloInstruction*, absl::flat_hash_set<HloInstruction*>*)>
maybe_computation_propagation = [&](HloInstruction* instruction,
absl::flat_hash_set<HloInstruction*>*
changed) {
auto propagate_to_instruction = [&](HloInstruction* search_inst) {
auto related_instructions = get_related_instructions(search_inst);
if (absl::c_count(related_instructions, instruction)) {
@ -1683,7 +1672,8 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
inst->sharding() != instruction->sharding()) {
VLOG(2) << "Add computation sharding: " << inst->name();
inst->set_sharding(instruction->sharding());
maybe_computation_propagation(inst);
changed->insert(inst);
maybe_computation_propagation(inst, changed);
}
}
}
@ -1785,6 +1775,14 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
for (const HloInstruction* instruction : instructions) {
already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
}
auto clear_cache = [&](HloInstruction* hlo) {
for (auto operand : hlo->operands()) {
already_inferred_from_users.erase(operand);
}
for (auto user : hlo->users()) {
already_inferred_from_operands.erase(user);
}
};
// First iterate the HLO graph in post order taking shardings from
// operands.
for (HloInstruction* instruction : instructions) {
@ -1799,12 +1797,11 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
any_changed = true;
VLOG(2) << "Add sharding (forward-pass): "
<< instruction->ToString();
maybe_computation_propagation(instruction);
for (auto operand : instruction->operands()) {
already_inferred_from_users.erase(operand);
}
for (auto user : instruction->users()) {
already_inferred_from_operands.erase(user);
absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
maybe_computation_propagation(instruction, &changed_in_comp_prop);
clear_cache(instruction);
for (auto hlo : changed_in_comp_prop) {
clear_cache(hlo);
}
changed_last_iter = true;
}
@ -1823,12 +1820,11 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
++inferred_from_user_counter;
any_changed = true;
VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
maybe_computation_propagation(*it);
for (auto operand : (*it)->operands()) {
already_inferred_from_users.erase(operand);
}
for (auto user : (*it)->users()) {
already_inferred_from_operands.erase(user);
absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
maybe_computation_propagation(*it, &changed_in_comp_prop);
clear_cache(*it);
for (auto hlo : changed_in_comp_prop) {
clear_cache(hlo);
}
changed_last_iter = true;
}

View File

@ -514,6 +514,26 @@ ENTRY %pad {
op::Sharding("{devices=[2,2]0,1,2,3}"));
}
TEST_F(ShardingPropagationTest, PadBackwardPass) {
const char* const hlo_string = R"(
HloModule module
ENTRY %pad {
%input = f32[11,17]{1,0} parameter(0)
%copy = f32[11,17]{1,0} copy(%input)
%pad_value = f32[] parameter(1)
%pad = f32[27,51]{1,0} pad(%copy, %pad_value), padding=2_4_1x1_1_2,
sharding={devices=[2,2]0,1,2,3}
ROOT %result = f32[27,51]{1,0} copy(%pad)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "copy"),
op::Sharding("{devices=[2,2]0,1,2,3}"));
}
TEST_F(ShardingPropagationTest, PartialReplicatedPadForwardPass) {
const char* const hlo_string = R"(
HloModule module
@ -856,40 +876,41 @@ TEST_F(ShardingPropagationTest, While) {
HloModule module
%cond {
%vars.cond = (u32[], f32[10]{0}) parameter(0)
%count.cond = u32[] get-tuple-element((u32[], f32[10]{0}) %vars.cond), index=0
%vars.cond = (u32[], f32[10,10]) parameter(0)
%count.cond = u32[] get-tuple-element((u32[], f32[10,10]) %vars.cond), index=0
%limit = u32[] constant(10)
ROOT %lt = pred[] compare(u32[] %count.cond, u32[] %limit), direction=LT
}
%body {
%vars = (u32[], f32[10]{0}) parameter(0)
%vars = (u32[], f32[10,10]) parameter(0)
%count = u32[] get-tuple-element(%vars), index=0
%acc = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %vars), index=1
%acc = f32[10,10] get-tuple-element((u32[], f32[10,10]) %vars), index=1
%one = u32[] constant(1)
%count.1 = u32[] add(u32[] %count, u32[] %one), sharding={replicated}
%acc.1 = f32[10]{0} add(f32[10]{0} %acc, f32[10]{0} %acc)
ROOT %tuple = (u32[], f32[10]{0}) tuple(u32[] %count.1, f32[10]{0} %acc.1)
%acc.1 = f32[10,10] add(f32[10,10] %acc, f32[10,10] %acc)
ROOT %tuple = (u32[], f32[10,10]) tuple(u32[] %count.1, f32[10,10] %acc.1)
}
ENTRY %entry {
%p0 = f32[10]{0} parameter(0)
%p0.copy = f32[10]{0} copy(f32[10]{0} %p0)
%p1 = f32[10]{0} parameter(1)
%p0 = f32[10,10] parameter(0)
%p0.copy = f32[10,10] copy(f32[10,10] %p0)
%p1 = f32[10,10] parameter(1)
%zero = u32[] constant(0)
%init = (u32[], f32[10]{0}) tuple(u32[] %zero, f32[10]{0} %p0.copy)
%while = (u32[], f32[10]{0}) while((u32[], f32[10]{0}) %init),
%init = (u32[], f32[10,10]) tuple(u32[] %zero, f32[10,10] %p0.copy)
%while = (u32[], f32[10,10]) while((u32[], f32[10,10]) %init),
body=%body, condition=%cond
%res = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %while), index=1
%prev = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %init), index=1
%res.1 = f32[10]{0} multiply(f32[10]{0} %res, %prev)
ROOT %res_tuple = (f32[10]{0}) tuple(f32[10]{0} %res.1)
%res = f32[10,10] get-tuple-element((u32[], f32[10,10]) %while), index=1
%prev = f32[10,10] get-tuple-element((u32[], f32[10,10]) %init), index=1
%res.1 = f32[10,10] multiply(f32[10,10] %res, %prev)
ROOT %res_tuple = (f32[10,10]) tuple(f32[10,10] %res.1)
})";
auto while_is_sharded = [this](HloModule* module,
const HloSharding& sharding) {
TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingPropagation().Run(module));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation(/*is_spmd=*/true).Run(module));
EXPECT_TRUE(changed);
auto while_instr = FindInstruction(module, "while");
EXPECT_NE(nullptr, while_instr);
@ -911,7 +932,7 @@ ENTRY %entry {
auto body_root = FindInstruction(module.get(), "tuple");
EXPECT_NE(nullptr, body_root);
auto sharding =
ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie();
ParseSharding("{{replicated}, {devices=[2,1]0,1}}").ConsumeValueOrDie();
body_root->set_sharding(sharding);
while_is_sharded(module.get(), sharding);
}
@ -921,11 +942,30 @@ ENTRY %entry {
ParseAndReturnVerifiedModule(hlo_string));
auto acc_1 = FindInstruction(module.get(), "acc.1");
EXPECT_NE(nullptr, acc_1);
acc_1->set_sharding(ParseSharding("{devices=[2]0,1}").ConsumeValueOrDie());
acc_1->set_sharding(
ParseSharding("{devices=[2,1]0,1}").ConsumeValueOrDie());
while_is_sharded(
module.get(),
ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie());
while_is_sharded(module.get(),
ParseSharding("{{replicated}, {devices=[2,1]0,1}}")
.ConsumeValueOrDie());
}
{
// Merge partial sharding from operand and body.
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
auto acc_1 = FindInstruction(module.get(), "acc.1");
EXPECT_NE(nullptr, acc_1);
acc_1->set_sharding(
ParseSharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")
.ConsumeValueOrDie());
auto p0 = FindInstruction(module.get(), "p0");
p0->set_sharding(
ParseSharding("{devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}")
.ConsumeValueOrDie());
while_is_sharded(module.get(),
ParseSharding("{{replicated}, {devices=[2,2]0,1,2,3}}")
.ConsumeValueOrDie());
}
}

View File

@ -182,6 +182,10 @@ class ConvolutionVisitor {
return permute_dims[id];
}
int64 ReverseDimLookUp(absl::Span<const int64> permute_dims, int64 id) {
return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id));
}
HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
HloInstruction* instr, int64 depth);
@ -215,9 +219,10 @@ class ConvolutionVisitor {
// Limit on batch size to apply this technique on.
int64 limit_on_batch_size_;
// We choose the new batch size to be a constant so that space-to-batch
// propagation through several convolutional layers is consistent.
static constexpr int64 kNewBatchSize = 8;
// We choose the new batch size to be kNumSplits times that of the old batch
// so that space-to-batch propagation through several convolutional layers is
// consistent.
static constexpr int64 kNumSplits = 8;
// Depth for searching reduce window
static constexpr int64 kReduceWindowSearchDepth = 10;
@ -301,17 +306,12 @@ bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
if (old_batch_size > limit_on_batch_size_) {
return false;
}
// We currently only cater to evenly divisible cases.
if (kNewBatchSize % old_batch_size != 0) {
return false;
}
VLOG(1) << "spatial size " << c.spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size;
// If the ratio is not within the 2X range, we can't Halo Pad from the next
// split.
if (c.halo_size > CeilOfRatio(c.spatial_size, num_splits)) {
if (c.halo_size > CeilOfRatio(c.spatial_size, kNumSplits)) {
return false;
}
VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
@ -323,6 +323,24 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
int64 activations_batch_dim, int64 old_batch_size, int64 low_padding,
int64 high_padding, int64 halo_size, int64 original_split_dim_size,
HloInstruction* pad_val) {
const int64 original_batch_size =
activations->shape().dimensions(activations_batch_dim) / kNumSplits;
if (original_batch_size > 1) {
std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
activations->shape().dimensions().end());
new_dimensions[activations_batch_dim] = kNumSplits;
new_dimensions.insert(new_dimensions.begin() + activations_batch_dim,
original_batch_size);
// Reshape the output of the new conv into the old convolutions shape.
TF_ASSIGN_OR_RETURN(activations,
MakeReshapeHlo(new_dimensions, activations));
spatial_dimension_to_split++;
activations_batch_dim++;
}
const int64 rank = activations->shape().rank();
const int64 spatial_split_size =
activations->shape().dimensions(spatial_dimension_to_split);
@ -415,6 +433,21 @@ StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
TF_ASSIGN_OR_RETURN(activations, MakeConcatHlo({activations, halo_region},
spatial_dimension_to_split));
}
if (original_batch_size > 1) {
std::vector<int64> new_dimensions(activations->shape().dimensions().begin(),
activations->shape().dimensions().end());
new_dimensions[activations_batch_dim] = original_batch_size * kNumSplits;
new_dimensions.erase(new_dimensions.begin() + activations_batch_dim - 1);
// Reshape the output of the new conv into the old convolutions shape.
TF_ASSIGN_OR_RETURN(activations,
MakeReshapeHlo(new_dimensions, activations));
spatial_dimension_to_split++;
activations_batch_dim++;
}
VLOG(1) << "HaloDuplicated activations " << activations->ToString();
return activations;
}
@ -424,17 +457,20 @@ ConvolutionVisitor::BringSpaceNextToBatch(
HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
int64& spatial_dimension_to_split, int64& activations_batch_dim,
bool is_backprop) {
std::vector<int64> transpose_dims;
ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
if (spatial_dimension_to_split != activations_batch_dim + 1) {
std::vector<int64> transpose_dims(activations->shape().rank());
if (spatial_dimension_to_split == activations_batch_dim + 1) {
absl::c_iota(transpose_dims, 0);
} else {
ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
int64 pushed_counter = 0;
int64 new_batch_dim, new_spatial_dim;
int64 dim_counter = 0;
for (int i = 0; i < activations->shape().rank(); ++i) {
if (i == activations_batch_dim) {
continue;
}
if (i == spatial_dimension_to_split) {
transpose_dims.push_back(activations_batch_dim);
transpose_dims[dim_counter++] = activations_batch_dim;
new_batch_dim = pushed_counter;
pushed_counter++;
new_spatial_dim = pushed_counter;
@ -452,7 +488,7 @@ ConvolutionVisitor::BringSpaceNextToBatch(
}
}
}
transpose_dims.push_back(i);
transpose_dims[dim_counter++] = i;
pushed_counter++;
}
@ -460,14 +496,14 @@ ConvolutionVisitor::BringSpaceNextToBatch(
spatial_dimension_to_split = new_spatial_dim;
TF_ASSIGN_OR_RETURN(activations,
MakeTransposeHlo(activations, transpose_dims));
}
if (is_backprop) {
new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
} else {
new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
if (is_backprop) {
new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
} else {
new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
}
dim_numbers = new_dim_numbers;
}
dim_numbers = new_dim_numbers;
return SpaceNextToBatchDetails{activations, transpose_dims};
}
@ -586,12 +622,23 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
VLOG(1) << "Checking if conv is supported for propagation "
<< consumer->ToString();
if (IsConvSuitableForSpaceToBatch(consumer)) {
for (int64 i = 0; i < consumer->operand_count(); ++i) {
auto old_producer = consumer->mutable_operand(i);
if (i == 0 && !old_to_new_instrs_.contains(old_producer)) {
return false;
}
if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
return false;
}
auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
// Make sure that the space dimension is the same across the producer
// and consumer.
if (consumer->convolution_dimension_numbers().input_spatial_dimensions(
get_chosen_spatial_dim(consumer)) != dim_map_val_op_0.second) {
return false;
}
// Make sure that the batch dimension is the same across the producer
// and consumer.
if (consumer->convolution_dimension_numbers().input_batch_dimension() !=
dim_map_val_op_0.first) {
return false;
}
return true;
}
@ -611,13 +658,35 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
VLOG(2) << "Checking for backprop filter conv operands "
<< consumer->operand_count();
if (!old_to_new_instrs_.contains(consumer->mutable_operand(1))) {
auto activations = consumer->mutable_operand(0);
auto kernel = consumer->mutable_operand(1);
if (!old_to_new_instrs_.contains(kernel)) {
VLOG(2) << "Backprop filter conv not ready for propagation because of "
"kernel is not space-to-batched";
return false;
}
if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
if (!old_to_new_instrs_.contains(activations)) {
const int64 lhs_batch = activations->shape().dimensions(
consumer->convolution_dimension_numbers().input_feature_dimension());
auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
const int64 old_batch_dim = dim_map_val_op_1.first;
auto second_operand = old_to_new_instrs_[kernel];
auto permute_dims_second_operand =
instr_to_dim_permute_map_[second_operand];
const int64 new_batch_dim =
DimLookUp(permute_dims_second_operand, old_batch_dim);
const int64 rhs_batch = second_operand->shape().dimensions(new_batch_dim);
// Because we want to convert activations into a space-to-batched version
// only for backprop filter convolutions, we want to make sure that the
// batch dimensions (feature dimensions, technically) are same sized.
// Since RHS is already space-to-batched, we need to account for it too.
if (rhs_batch != kNumSplits * lhs_batch) {
return false;
}
// If activations have not been propagated through, we can do
// space-to-batch on them provided kernel has been propagated.
VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
@ -625,10 +694,10 @@ bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
return true;
}
auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];
auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
auto first_operand = old_to_new_instrs_[activations];
auto dim_map_val_op_0 = instr_to_dim_map_[activations];
auto second_operand = old_to_new_instrs_[kernel];
auto dim_map_val_op_1 = instr_to_dim_map_[kernel];
auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
auto permute_dims_second_operand =
@ -1119,7 +1188,7 @@ StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
Window new_win;
for (int64 i = 0; i < consumer->window().dimensions().size(); ++i) {
auto dim = DimLookUp(permute_dims, i);
auto dim = ReverseDimLookUp(permute_dims, i);
new_win.add_dimensions();
new_win.mutable_dimensions(i)->set_stride(
consumer->window().dimensions(dim).stride());
@ -1339,7 +1408,9 @@ StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
const int64 new_space_size = new_shape.dimensions(new_space_dim);
const int64 old_batch_size = old_shape.dimensions(old_batch_dim);
const int64 old_space_size = old_shape.dimensions(old_space_dim);
CHECK_EQ(new_batch_size % old_batch_size, 0);
CHECK_EQ(new_batch_size % old_batch_size, 0)
<< " New batch size " << new_batch_size << " old batch size "
<< old_batch_size;
const int64 num_splits = new_batch_size / old_batch_size;
// Build a constant PRED to decide which elements in the split dimension
// are from halo.
@ -1394,8 +1465,10 @@ StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
CHECK(old_to_new_instrs_.contains(old_instr));
auto new_instr = old_to_new_instrs_[old_instr];
VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim "
<< old_space_dim << " new_instr " << new_instr->ToString()
<< " permute dims " << instr_to_dim_permute_map_.count(new_instr);
<< old_space_dim << " old_instr " << old_instr->ToString()
<< "\n new_instr " << new_instr->ToString() << " permute dims "
<< instr_to_dim_permute_map_.count(new_instr) << " old_batch_size "
<< old_batch_size;
CHECK(instr_to_dim_permute_map_.contains(new_instr));
auto permute_dims = instr_to_dim_permute_map_[new_instr];
const int64 batch_dim = DimLookUp(permute_dims, old_batch_dim);
@ -1565,6 +1638,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
c.spatial_dimension_to_split, activations_batch_dim));
activations_new = retval.instr;
std::vector<int64> trans_dims = retval.transpose_dims;
CHECK(!trans_dims.empty());
auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(activations_new->shape().element_type())));
@ -1578,8 +1652,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
VLOG(1) << "spatial size " << c.spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size;
const int64 num_splits = kNumSplits;
const int64 output_offsets = convolution->shape().dimensions(
permuted_conv_dims_numbers.output_spatial_dimensions(
get_chosen_spatial_dim(convolution)));
@ -1614,6 +1687,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
activations_new->shape().dimensions().end());
const int64 reshaped_space_size =
new_space_size * new_batch_size / old_batch_size;
VLOG(3) << "Increasing the spatial size while propagating new_batch_size "
<< new_batch_size << " old_batch_size " << old_batch_size;
new_dimensions[c.spatial_dimension_to_split] = reshaped_space_size;
new_dimensions[activations_batch_dim] = old_batch_size;
@ -1621,10 +1696,12 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_activations,
MakeReshapeHlo(new_dimensions, activations_new));
VLOG(3) << "First reshape done";
PaddingConfig padding_config =
MakeNoPaddingConfig(reshaped_activations->shape().dimensions_size());
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_high(spatial_split_size * new_batch_size -
->set_edge_padding_high(spatial_split_size * new_batch_size /
old_batch_size -
reshaped_space_size);
padding_config.mutable_dimensions(c.spatial_dimension_to_split)
->set_edge_padding_low(0);
@ -1647,6 +1724,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
reshaped_activations,
MakeReshapeHlo(reshape_back_dims, reshaped_activations));
VLOG(3) << "Second reshape done";
TF_ASSIGN_OR_RETURN(
activations_new,
HaloDuplicateWithSlice(
@ -1664,6 +1743,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
// additional space available, and adjust the required slice size (and
// thereby the halo size).
if (spatial_split_size < new_space_size) {
VLOG(3) << "Decreasing the spatial size while propagating";
const int64 additional_space_present = spatial_split_size % c.stride;
spatial_split_size = new_space_size;
slice_size =
@ -1758,6 +1838,7 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
activations = retval.instr;
std::vector<int64> transpose_dims = retval.transpose_dims;
CHECK(!transpose_dims.empty());
// Because we are splitting the spatial dimension, if convolution needed
// padding in the spatial dimension, we materialize it.
if (high_padding || low_padding) {
@ -1774,7 +1855,9 @@ ConvolutionVisitor::SplitSpace(HloInstruction* activations,
MakePadHlo(activations, padding, padding_config));
}
VLOG(1) << "Initial padded activations shape "
<< activations->shape().ToString();
<< activations->shape().ToString() << " old_batch_size "
<< old_batch_size << " activations_batch_dim "
<< activations_batch_dim;
// Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16]
// and 4 splits were needed, we first create [4, 4]. Next, to deal with halo
@ -1829,7 +1912,10 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
CHECK(old_to_new_instrs_.contains(kernel_old));
auto kernel_new = old_to_new_instrs_[kernel_old];
auto permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
HloInstruction* activations_new = nullptr;
bool activations_locally_space_to_batched = false;
// If activations were no space-to-batched, we space-to-batch them below.
if (!old_to_new_instrs_.contains(activations_old)) {
VLOG(1) << "Space-to-batching activations to enable space-to-depth";
@ -1838,28 +1924,34 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
instr_to_dim_map_[activations_old] =
std::make_pair(prev_feature_dim, prev_batch_dim);
int64 activations_batch_dim = original_conv_dims.input_feature_dimension();
const int64 old_batch_size =
activations_old->shape().dimensions(activations_batch_dim);
const int64 num_splits = kNewBatchSize / old_batch_size;
const int64 new_kernel_space_dim =
DimLookUp(permute_dims_kernel, kernel_space_dim);
const int64 new_kernel_split_dim_size =
kernel_new->shape().dimensions(kernel_space_dim);
kernel_new->shape().dimensions(new_kernel_space_dim);
const int64 needed_spatial_size = rhs_dilation * new_kernel_split_dim_size;
const int64 pad_size =
needed_spatial_size * num_splits - old_split_dim_size;
needed_spatial_size * kNumSplits - old_split_dim_size;
ConvolutionDimensionNumbers tmp_dim_numbers;
tmp_dim_numbers = original_conv_dims;
TF_ASSIGN_OR_RETURN(
auto retval,
SplitSpace(activations_old, tmp_dim_numbers, old_space_dim,
activations_batch_dim,
old_batch_dim,
/*high_padding=*/pad_size, /*low_padding=*/0,
needed_spatial_size, num_splits, /*is_backprop=*/true));
needed_spatial_size, kNumSplits, /*is_backprop=*/true));
old_to_new_instrs_[activations_old] = retval.first;
instr_to_dim_permute_map_[retval.first] = retval.second;
VLOG(3) << "Edited conv dims " << original_conv_dims.DebugString();
std::vector<int64> reversed_transpose_dims(retval.second.size());
for (int64 i = 0; i < retval.second.size(); ++i) {
reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
}
instr_to_dim_permute_map_[retval.first] = reversed_transpose_dims;
VLOG(3) << "New Activations " << retval.first->ToString();
activations_locally_space_to_batched = true;
}
CHECK(old_to_new_instrs_.contains(activations_old));
@ -1884,7 +1976,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
i, DimLookUp(permute_dims,
original_conv_dims.input_spatial_dimensions(i)));
permuted_conv_dims_numbers.set_kernel_spatial_dimensions(
i, DimLookUp(permute_dims,
i, DimLookUp(permute_dims_kernel,
original_conv_dims.kernel_spatial_dimensions(i)));
}
@ -1905,10 +1997,11 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
previous_spatial_dim_count, previous_chosen_spatial_dim_in_output);
const int64 kernel_input_feature_dim = DimLookUp(
permute_dims, original_conv_dims.kernel_input_feature_dimension());
permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension());
const int64 kernel_output_feature_dim = DimLookUp(
permute_dims, original_conv_dims.kernel_output_feature_dimension());
const int64 kernel_output_feature_dim =
DimLookUp(permute_dims_kernel,
original_conv_dims.kernel_output_feature_dimension());
permuted_conv_dims_numbers.set_kernel_input_feature_dimension(
kernel_input_feature_dim);
@ -1931,7 +2024,8 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
VLOG(1) << "Propagating on conv activations_batch_dim "
<< activations_batch_dim << " spatial_dimension_to_split "
<< spatial_dimension_to_split << " old_batch_size " << old_batch_size;
<< spatial_dimension_to_split << " old_batch_size " << old_batch_size
<< " new_split_dim_size " << new_split_dim_size;
TF_ASSIGN_OR_RETURN(
auto retval,
@ -1939,6 +2033,7 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
spatial_dimension_to_split, activations_batch_dim,
/*is_backprop=*/true));
std::vector<int64> transpose_dims = retval.transpose_dims;
CHECK(!transpose_dims.empty());
activations_new = retval.instr;
VLOG(1) << "Activations_new post BringSpaceNextToBatch "
@ -1949,13 +2044,15 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(activations_new->shape().element_type())));
// Select activations correctly by masking additional space.
TF_ASSIGN_OR_RETURN(
activations_new,
SelectValidPortion(activations_new, activations_old, select_val,
activations_batch_dim, spatial_dimension_to_split,
old_batch_dim, old_space_dim));
if (!activations_locally_space_to_batched) {
// Select activations correctly by masking additional space.
TF_ASSIGN_OR_RETURN(
activations_new,
SelectValidPortion(activations_new, activations_old, select_val,
activations_batch_dim, spatial_dimension_to_split,
old_batch_dim, old_space_dim));
}
VLOG(3) << "Selecting the valid kernel area";
// Select kernel correctly by masking additional space.
TF_ASSIGN_OR_RETURN(
kernel_new,
@ -2238,7 +2335,6 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
VLOG(1) << "spatial size " << c.spatial_size;
const int64 num_splits = kNewBatchSize / old_batch_size;
auto original_conv = convolution;
const int64 output_spatial_dim = dim_numbers.output_spatial_dimensions(
@ -2246,13 +2342,13 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
const int64 output_offsets =
convolution->shape().dimensions(output_spatial_dim);
const int64 output_offsets_per_split =
CeilOfRatio(output_offsets, num_splits);
CeilOfRatio(output_offsets, kNumSplits);
int64 spatial_split_size =
CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
// Keep increasing the split size so that overall size isn't smaller than the
// original spatial dimension.
while (spatial_split_size * num_splits - c.spatial_size < 0) {
while (spatial_split_size * kNumSplits - c.spatial_size < 0) {
spatial_split_size += c.stride;
}
@ -2276,12 +2372,12 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
const int64 slice_size = spatial_split_size + c.halo_size;
// Pad spatial dim.
const int64 pad_size = spatial_split_size * num_splits - c.spatial_size;
const int64 pad_size = spatial_split_size * kNumSplits - c.spatial_size;
VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
<< c.stride << " slice_size " << slice_size;
VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimension_to_split
<< " num_splits " << num_splits << " kernel_spatial_dim_size "
<< " num_splits " << kNumSplits << " kernel_spatial_dim_size "
<< c.kernel_spatial_dim_size;
int64 spatial_dimension_to_split = c.spatial_dimension_to_split;
TF_ASSIGN_OR_RETURN(
@ -2292,7 +2388,7 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
/*low_padding=*/c.base_dilation_factor == 1
? c.inherent_low_padding
: 0,
spatial_split_size, num_splits));
spatial_split_size, kNumSplits));
HloInstruction* batch_increased_reshape = retval.first;
convolution->SetupDerivedInstruction(batch_increased_reshape);

View File

@ -317,5 +317,27 @@ TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
}
}
TEST_F(DynamismInferenceTest, InferThroughPad) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
// Test the analysis on a gather.
auto operand1 = ConstantR1<int32>(&b, {1, 2});
auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
PaddingConfig padding_config;
padding_config.add_dimensions()->set_edge_padding_high(1);
// After pad the value is [constant, constant, parameter].
auto pad = Pad(operand1, parameter, padding_config);
ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
// Everything is constant, result is also contant.
EXPECT_FALSE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({0}));
EXPECT_FALSE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({1}));
EXPECT_TRUE(
ComputeDynamismLiteral(client, pad, &b).ValueOrDie().Get<bool>({2}));
}
}
} // namespace
} // namespace xla

View File

@ -57,7 +57,8 @@ class GpuDummyCompiler : public GpuCompiler {
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
GpuVersion gpu_version, se::StreamExecutor* stream_exec,
bool relocatable) {
if (user_post_optimization_hook_) {
user_post_optimization_hook_(*llvm_module);
}

View File

@ -20,9 +20,8 @@ op {
"A `Tensor` of type T. An alias of `x`. The content "
"of `y` is undefined if there are duplicates in `i`."
}
summary: <<END
Adds v into specified rows of x.
summary: "Adds v into specified rows of x."
description: <<END
Computes y = x; y[i, :] += v; return y.
END
}

View File

@ -1,8 +1,8 @@
op {
graph_op_name: "TopKUnique"
summary: "Returns the TopK unique values in the array in sorted order. The"
summary: "Returns the TopK unique values in the array in sorted order."
description: <<END
running time is proportional to the product of K and the input
The running time is proportional to the product of K and the input
size. Sorting the whole array is more efficient for sufficiently large
values of K. The median-of-medians algorithm is probably faster, but
difficult to implement efficiently in XLA. If there are fewer than K

View File

@ -1,10 +1,11 @@
op {
graph_op_name: "TopKWithUnique"
summary: "Returns the TopK values in the array in sorted order. This is a combination"
summary: "Returns the TopK values in the array in sorted order."
description: <<END
of MakeUnique and TopKUnique. The returned top-K will have its lower bits
replaced by iota, thus it will be close to the original value but not exactly
the same. The running time is proportional to the product of K and the input
size. NaNs are never returned. Subnormal numbers are flushed to zero.
This is a combination of MakeUnique and TopKUnique. The returned top-K will
have its lower bits replaced by iota, thus it will be close to the original
value but not exactly the same. The running time is proportional to the product
of K and the input size. NaNs are never returned. Subnormal numbers are flushed
to zero.
END
}

View File

@ -705,10 +705,10 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
return Status::OK();
}
Status EagerContext::AddFunctionDefWithDebugInfo(
const FunctionDef& fdef, const Graph* graph_with_debug_info) {
Status EagerContext::AddFunctionDefWithStackTraces(
const FunctionDef& fdef, const StackTracesMap& stack_traces) {
return AddFunctionDef(fdef, FunctionDefLibrary(),
/* add_to_local_only=*/false, graph_with_debug_info);
/* add_to_local_only=*/false, stack_traces);
}
Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
@ -719,7 +719,7 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
const FunctionDefLibrary& library,
const bool add_to_local_only,
const Graph* graph_with_debug_info) {
const StackTracesMap& stack_traces) {
bool is_first_ref = false;
{
mutex_lock l(cache_mu_);
@ -753,8 +753,7 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
is_first_ref = registered_function->RefCountIsOne();
}
if (is_first_ref) {
TF_RETURN_IF_ERROR(
func_lib_def_.AddFunctionDef(fdef, graph_with_debug_info));
TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
if (!add_to_local_only) {
return MaybeRegisterFunctionRemotely(fdef);

Some files were not shown because too many files have changed in this diff Show More